Skip to content

Commit

Permalink
fix problems with lagged window
Browse files Browse the repository at this point in the history
  • Loading branch information
kasyanovse committed Oct 19, 2023
1 parent 9cd08a0 commit 7d38ba0
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/integration/models/test_custom_model_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_simple_pipeline(multi_data):
exog_list.append(PipelineNode(data_id))
if 'hist_' in data_id:
lagged_node = PipelineNode('lagged', nodes_from=[PipelineNode(data_id)])
lagged_node.parameters = {'window_size': 1}
lagged_node.parameters = {'window_size': 2}

hist_list.append(lagged_node)

Expand Down
16 changes: 16 additions & 0 deletions test/integration/pipelines/tuning/test_pipeline_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from time import time

import pytest

from fedot.core.repository.dataset_types import DataTypesEnum
from golem.core.tuning.hyperopt_tuner import get_node_parameters_for_hyperopt
from golem.core.tuning.iopt_tuner import IOptTuner
from golem.core.tuning.optuna_tuner import OptunaTuner
Expand Down Expand Up @@ -216,6 +218,20 @@ def run_pipeline_tuner(train_data,
cv=None,
iterations=3,
early_stopping_rounds=None, **kwargs):

if train_data.data_type in (DataTypesEnum.ts, DataTypesEnum.multi_ts):
forecast_length = train_data.task.task_params.forecast_length
folds = cv or 1
validation_blocks = 1
max_window = int(train_data.features.shape[0] / (folds + 1)) - (forecast_length * validation_blocks) - 1
ssp = {'window_size': {'hyperopt-dist': hp.uniformint, 'sampling-scope': [2, max_window], 'type': 'discrete'}}
if search_space.custom_search_space is None:
search_space.custom_search_space = {'lagged': ssp}
else:
search_space.custom_search_space['lagged'] = ssp
search_space.replace_default_search_space = True
search_space.parameters_per_operation = search_space.get_parameters_dict()

# Pipeline tuning
pipeline_tuner = TunerBuilder(train_data.task) \
.with_tuner(tuner) \
Expand Down
6 changes: 6 additions & 0 deletions test/integration/real_applications/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from datetime import timedelta

import numpy as np
Expand Down Expand Up @@ -45,8 +47,10 @@ def test_gapfilling_example():

def test_exogenous_ts_example():
path = fedot_project_root().joinpath('test/data/simple_sea_level.csv')
test = os.environ.pop('PYTEST_CURRENT_TEST')
run_exogenous_experiment(path_to_file=path,
len_forecast=50, with_exog=True)
os.environ['PYTEST_CURRENT_TEST'] = test


def test_nemo_multiple_points_example():
Expand Down Expand Up @@ -84,7 +88,9 @@ def test_api_example():
prediction = run_classification_example(timeout=1, with_tuning=with_tuning)
assert prediction is not None

test = os.environ.pop('PYTEST_CURRENT_TEST')
forecast = run_ts_forecasting_example(dataset='australia', timeout=2, with_tuning=with_tuning)
os.environ['PYTEST_CURRENT_TEST'] = test
assert forecast is not None

pareto = run_classification_multiobj_example(timeout=1, with_tuning=with_tuning)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def get_fitted_fedot(forecast_length, train_data, **kwargs):
'seed': 1,
'timeout': None,
'pop_size': 50,
'num_of_generations': 5}
'num_of_generations': 5,
'with_tuning': False}
params.update(kwargs)
fedot = Fedot(**params)
fedot.fit(train_data)
Expand Down
2 changes: 1 addition & 1 deletion test/integration/utilities/test_pipeline_import_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_export_without_path_correctly():


def test_data_model_types_forecasting_pipeline_fit():
train_data, test_data = get_ts_data(forecast_length=10)
train_data, test_data = get_ts_data(n_steps = 200, forecast_length=10)

pipeline = get_multiscale_pipeline()
pipeline.fit(train_data)
Expand Down

0 comments on commit 7d38ba0

Please sign in to comment.