Skip to content

Commit

Permalink
628 bugfix (#1145)
Browse files Browse the repository at this point in the history
* Fix path to data

* Fix tests
  • Loading branch information
kasyanovse authored Aug 22, 2023
1 parent 634de80 commit 8cb0468
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion cases/multi_target_levels_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fedot.api.main import Fedot
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.utils import fedot_project_root

warnings.filterwarnings('ignore')

Expand Down Expand Up @@ -99,5 +100,5 @@ def run_multi_output_case(path, vis=False):


if __name__ == '__main__':
path_file = './data/lena_levels/multi_sample.csv'
path_file = fedot_project_root() / 'cases/data/lena_levels/multi_sample.csv'
run_multi_output_case(path_file, vis=True)
2 changes: 0 additions & 2 deletions cases/multi_ts_level_forecasting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error

from examples.simple.time_series_forecasting.ts_pipelines import ts_complex_ridge_smoothing_pipeline
Expand Down Expand Up @@ -52,7 +51,6 @@ def run_multi_ts_forecast(forecast_length, is_multi_ts):
pop_size=15,
max_arity=4,
cv_folds=None,
validation_blocks=None,
initial_assumption=init_pipeline
)
# fit model
Expand Down
9 changes: 7 additions & 2 deletions test/integration/composer/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import numpy as np
import pytest

from fedot.core.repository.tasks import TaskTypesEnum
from golem.core.dag.graph import Graph
from golem.core.optimisers.fitness import SingleObjFitness
from golem.core.optimisers.genetic.evaluation import MultiprocessingDispatcher
Expand Down Expand Up @@ -118,8 +120,11 @@ def test_collect_intermediate_metric(pipeline: Pipeline, input_data: InputData,
graph_gen_params = get_pipeline_generation_params()
metrics = [metric]

data_source = DataSourceSplitter(validation_blocks=1).build(input_data)
objective_eval = PipelineObjectiveEvaluate(MetricsObjective(metrics), data_source)
validation_blocks = 1 if input_data.task.task_type is TaskTypesEnum.ts_forecasting else None
data_source = DataSourceSplitter(validation_blocks=validation_blocks).build(input_data)
objective_eval = PipelineObjectiveEvaluate(MetricsObjective(metrics),
data_source,
validation_blocks=validation_blocks)
dispatcher = MultiprocessingDispatcher(graph_gen_params.adapter)
dispatcher.set_graph_evaluation_callback(objective_eval.evaluate_intermediate_metrics)
evaluate = dispatcher.dispatch(objective_eval)
Expand Down

0 comments on commit 8cb0468

Please sign in to comment.