Skip to content

Commit

Permalink
add initial pipeline to multiobjective test
Browse files Browse the repository at this point in the history
  • Loading branch information
IIaKyJIuH committed Jul 12, 2023
1 parent 058a0f7 commit a0fcb5f
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions test/integration/quality/test_quality_improvement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np

from fedot.api.main import Fedot
from fedot.core.pipelines.node import PipelineNode
from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.utils import fedot_project_root

warnings.filterwarnings("ignore")
Expand All @@ -19,13 +21,13 @@ def test_classification_quality_improvement():
problem = 'classification'
with_tuning = False

common_params= dict(problem=problem,
n_jobs=1,
use_pipelines_cache=False,
use_preprocessing_cache=False,
with_tuning=with_tuning,
logging_level=logging.DEBUG,
seed=seed)
common_params = dict(problem=problem,
n_jobs=1,
use_pipelines_cache=False,
use_preprocessing_cache=False,
with_tuning=with_tuning,
logging_level=logging.DEBUG,
seed=seed)

expected_baseline_quality = 0.750
baseline_model = Fedot(**common_params)
Expand Down Expand Up @@ -56,13 +58,22 @@ def test_multiobjective_improvement():
metrics = [quality_metric, complexity_metric]

timeout = 2
composer_params = dict(num_of_generations=10,
composer_params = dict(num_of_generations=5,
pop_size=3,
with_tuning=False,
preset='fast_train',
preset='best_quality',
metric=metrics)

initial_pipeline = Pipeline(
PipelineNode('logit',
nodesfrom=[
PipelineNode('rf', nodes_from=[PipelineNode('rf')]),
PipelineNode('rf', nodes_from=[PipelineNode('rf')])
])
)

auto_model = Fedot(problem=problem, timeout=timeout, seed=seed, logging_level=logging.DEBUG,
initial_assumption=initial_pipeline,
**composer_params, n_jobs=1, use_pipelines_cache=False, use_preprocessing_cache=False)
auto_model.fit(features=train_data_path, target='target')
auto_model.predict_proba(features=test_data_path)
Expand All @@ -71,7 +82,7 @@ def test_multiobjective_improvement():
quality_improved, complexity_improved = check_improvement(auto_model.history)

assert auto_metrics[quality_metric] > 0.75
assert auto_metrics[complexity_metric] >= 0.2
assert auto_metrics[complexity_metric] <= 0.2
assert quality_improved
assert complexity_improved

Expand Down

0 comments on commit a0fcb5f

Please sign in to comment.