feat: speed up models tests by parametrizing
Lopa10ko committed Jul 23, 2024
1 parent a7c4022 commit 04d9a15
Showing 1 changed file with 37 additions and 39 deletions.
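For readers skimming the diff: the change applies the standard pytest parametrization pattern. Instead of one test function that loops over every operation in the repository, `@pytest.mark.parametrize` generates one test case per operation, `ids=lambda x: x.id` names each case after its operation id, and operations in the skip list simply `return` early. Below is a minimal, self-contained sketch of the idea; the `Operation` dataclass and `OPERATIONS` list are illustrative stand-ins, not FEDOT's `OperationTypesRepository`.

from dataclasses import dataclass

import pytest


@dataclass
class Operation:
    """Stand-in for a repository entry; not a FEDOT class."""
    id: str
    supported: bool = True


OPERATIONS = [Operation('linear'), Operation('ridge'), Operation('custom', supported=False)]
TO_SKIP = {'custom'}


# Before: one test body loops over everything, so the first failure hides the rest.
def test_operations_loop_style():
    for operation in OPERATIONS:
        if operation.id in TO_SKIP:
            continue
        assert operation.supported


# After: one generated case per operation; `ids` keeps reports readable,
# e.g. test_operations_parametrized[ridge].
@pytest.mark.parametrize('operation', OPERATIONS, ids=lambda x: x.id)
def test_operations_parametrized(operation):
    if operation.id in TO_SKIP:
        return
    assert operation.supported

Beyond clearer per-operation reporting, independent parametrized cases can be distributed across workers (for example with pytest-xdist), which is a common way such a change speeds up a suite.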
76 changes: 37 additions & 39 deletions test/integration/models/test_model.py
@@ -461,56 +461,55 @@ def test_locf_forecast_correctly():
    assert np.array_equal(predict_forecast.predict, np.array([[110, 120, 130, 110]]))


-def test_models_does_not_fall_on_constant_data():
+@pytest.mark.parametrize('operation', OperationTypesRepository('all')._repo, ids=lambda x: x.id)
+def test_models_does_not_fall_on_constant_data(operation):
    """ Run models on constant data """
    # models that raise exception
    to_skip = ['custom', 'arima', 'catboost', 'catboostreg', 'cgru',
               'lda', 'fast_ica', 'decompose', 'class_decompose']
+    if operation.id in to_skip:
+        return

-    for operation in OperationTypesRepository('all')._repo:
-        if operation.id in to_skip:
-            continue
-        for task_type in operation.task_type:
-            for data_type in operation.input_types:
-                data = get_data_for_testing(task_type, data_type,
-                                            length=100, features_count=2,
-                                            random=False)
-                if data is not None:
-                    nodes_from = []
-                    if task_type is TaskTypesEnum.ts_forecasting:
-                        if 'non_lagged' not in operation.tags:
-                            nodes_from = [PipelineNode('lagged')]
-                    node = PipelineNode(operation.id, nodes_from=nodes_from)
-                    pipeline = Pipeline(node)
-                    pipeline.fit(data)
-                    assert pipeline.predict(data) is not None
+    for task_type in operation.task_type:
+        for data_type in operation.input_types:
+            data = get_data_for_testing(task_type, data_type,
+                                        length=100, features_count=2,
+                                        random=False)
+            if data is not None:
+                nodes_from = []
+                if task_type is TaskTypesEnum.ts_forecasting:
+                    if 'non_lagged' not in operation.tags:
+                        nodes_from = [PipelineNode('lagged')]
+                node = PipelineNode(operation.id, nodes_from=nodes_from)
+                pipeline = Pipeline(node)
+                pipeline.fit(data)
+                assert pipeline.predict(data) is not None


-def test_operations_are_serializable():
+@pytest.mark.parametrize('operation', OperationTypesRepository('all')._repo, ids=lambda x: x.id)
+def test_operations_are_serializable(operation):
    to_skip = ['custom', 'decompose', 'class_decompose']
+    if operation.id in to_skip:
+        return

-    for operation in OperationTypesRepository('all')._repo:
-        if operation.id in to_skip:
-            continue
-        for task_type in operation.task_type:
-            for data_type in operation.input_types:
-                data = get_data_for_testing(task_type, data_type,
-                                            length=100, features_count=2,
-                                            random=True)
-                if data is not None:
-                    try:
-                        nodes_from = []
-                        if task_type is TaskTypesEnum.ts_forecasting:
-                            if 'non_lagged' not in operation.tags:
-                                nodes_from = [PipelineNode('lagged')]
-                        node = PipelineNode(operation.id, nodes_from=nodes_from)
-                        pipeline = Pipeline(node)
-                        pipeline.fit(data)
-                        serialized = pickle.dumps(pipeline, pickle.HIGHEST_PROTOCOL)
-                        assert isinstance(serialized, bytes)
-                    except NotImplementedError:
-                        pass
+    for task_type in operation.task_type:
+        for data_type in operation.input_types:
+            data = get_data_for_testing(task_type, data_type,
+                                        length=100, features_count=2,
+                                        random=True)
+            if data is not None:
+                try:
+                    nodes_from = []
+                    if task_type is TaskTypesEnum.ts_forecasting:
+                        if 'non_lagged' not in operation.tags:
+                            nodes_from = [PipelineNode('lagged')]
+                    node = PipelineNode(operation.id, nodes_from=nodes_from)
+                    pipeline = Pipeline(node)
+                    pipeline.fit(data)
+                    serialized = pickle.dumps(pipeline, pickle.HIGHEST_PROTOCOL)
+                    assert isinstance(serialized, bytes)
+                except NotImplementedError:
+                    pass


def test_operations_are_fast():
@@ -534,7 +533,7 @@ def test_operations_are_fast():
            reference_time = tuple(map(min, zip(perfomance_values, reference_time)))

    for operation in OperationTypesRepository('all')._repo:
-        if (operation.id not in to_skip and operation.presets and FAST_TRAIN_PRESET_NAME in operation.presets):
+        if operation.id not in to_skip and operation.presets and FAST_TRAIN_PRESET_NAME in operation.presets:
            for _ in range(attempt):
                perfomance_values = get_operation_perfomance(operation, data_lengths)
                # if attempt is successful then stop
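The retry logic visible in this hunk (measure up to `attempt` times and stop at the first run that beats the reference) is a common way to keep a timing assertion tolerant of noise. A hedged sketch of that pattern follows; `measure()` and the threshold are hypothetical stand-ins, not FEDOT's `get_operation_perfomance`.

import random
import time


def measure() -> float:
    """Hypothetical timing probe; stands in for get_operation_perfomance."""
    start = time.perf_counter()
    sum(i * i for i in range(10_000 + random.randint(0, 5_000)))  # arbitrary workload
    return time.perf_counter() - start


def is_fast_enough(reference_time: float, attempt: int = 2) -> bool:
    """Retry the measurement; succeed as soon as one run beats the reference."""
    for _ in range(attempt):
        if measure() <= reference_time:
            return True
    return False


assert is_fast_enough(reference_time=10.0)  # generous threshold so the sketch passes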
@@ -548,7 +547,6 @@ def test_all_operations_are_documented():
    # All operations and presets should be listed in `docs/source/introduction/fedot_features/automation_features.rst`
    to_skip = {'custom', 'data_source_img', 'data_source_text', 'data_source_table', 'data_source_ts', 'exog_ts'}
    path_to_docs = fedot_project_root() / 'docs/source/introduction/fedot_features/automation_features.rst'
-    docs_lines = None

    with open(path_to_docs, 'r') as docs_:
        docs_lines = docs_.readlines()
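The final hunk touches `test_all_operations_are_documented`, which, per its comment, verifies that every operation and preset is mentioned in `docs/source/introduction/fedot_features/automation_features.rst`; the only change here is dropping the redundant `docs_lines = None` initialization. Purely as an illustration of that kind of docs-coverage check, and not the elided FEDOT assertion, a hypothetical sketch with a stand-in operation list:

from pathlib import Path

OPERATION_IDS = ['rf', 'ridge', 'lagged']   # stand-in; the real test walks OperationTypesRepository
TO_SKIP = {'custom', 'exog_ts'}


def undocumented_operations(path_to_docs: Path) -> list[str]:
    """Return operation ids that never appear in the documentation page."""
    docs_text = path_to_docs.read_text()
    return [op_id for op_id in OPERATION_IDS
            if op_id not in TO_SKIP and op_id not in docs_text]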
