Adding preprocessing copying to predefined models
aPovidlo committed Aug 20, 2024
1 parent f9f8acf commit fca7ef6
Showing 4 changed files with 26 additions and 10 deletions.
11 changes: 8 additions & 3 deletions fedot/api/api_utils/api_data.py
@@ -33,14 +33,19 @@ def __init__(self, task: Task, use_input_preprocessing: bool = True):
         self.task = task

         self._recommendations = {}
-        self.preprocessor = DummyPreprocessor()

         if use_input_preprocessing:
             self.preprocessor = DataPreprocessor()
+
             # Dictionary with recommendations (e.g. 'cut' for cutting dataset, 'label_encoded'
             # to encode features using label encoder). Parameters for transformation provided also
-            self._recommendations = {'cut': self.preprocessor.cut_dataset,
-                                     'label_encoded': self.preprocessor.label_encoding_for_fit}
+            self._recommendations = {
+                'cut': self.preprocessor.cut_dataset,
+                'label_encoded': self.preprocessor.label_encoding_for_fit
+            }
+
+        else:
+            self.preprocessor = DummyPreprocessor()

         self.log = default_log(self)

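The net effect of this hunk is that ApiDataProcessor keeps exactly one preprocessor instance, a DataPreprocessor when input preprocessing is enabled and a DummyPreprocessor otherwise, and the recommendation names dispatch to bound methods of that instance. A stand-in sketch of the pattern (classes and method bodies below are simplified placeholders, not the FEDOT implementations):

# Stand-in sketch of the selection pattern above; classes and method bodies
# are placeholders, not the real FEDOT preprocessors.
class DummyPreprocessor:
    """Does nothing; chosen when input preprocessing is switched off."""

class DataPreprocessor(DummyPreprocessor):
    def cut_dataset(self, data, border):
        return data[:border]                 # hypothetical 'cut' behaviour

    def label_encoding_for_fit(self, data):
        return data                          # hypothetical encoding hook

def build_processor(use_input_preprocessing: bool = True):
    recommendations = {}
    if use_input_preprocessing:
        preprocessor = DataPreprocessor()
        # Recommendation names map to bound methods of the single instance,
        # so anything fitted through them stays reachable via `preprocessor`.
        recommendations = {
            'cut': preprocessor.cut_dataset,
            'label_encoded': preprocessor.label_encoding_for_fit,
        }
    else:
        preprocessor = DummyPreprocessor()
    return preprocessor, recommendations

preprocessor, recommendations = build_processor(True)
print(sorted(recommendations))               # ['cut', 'label_encoded']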
15 changes: 12 additions & 3 deletions fedot/api/api_utils/predefined_model.py
@@ -8,26 +8,35 @@
 from fedot.core.pipelines.node import PipelineNode
 from fedot.core.pipelines.pipeline import Pipeline
 from fedot.core.pipelines.verification import verify_pipeline
+from fedot.preprocessing.base_preprocessing import BasePreprocessor


 class PredefinedModel:
     def __init__(self, predefined_model: Union[str, Pipeline], data: InputData, log: LoggerAdapter,
-                 use_input_preprocessing: bool = True):
+                 use_input_preprocessing: bool = True, api_preprocessor: BasePreprocessor = None):
         self.predefined_model = predefined_model
         self.data = data
         self.log = log
-        self.pipeline = self._get_pipeline(use_input_preprocessing)
+        self.pipeline = self._get_pipeline(use_input_preprocessing, api_preprocessor)

-    def _get_pipeline(self, use_input_preprocessing: bool = True) -> Pipeline:
+    def _get_pipeline(self, use_input_preprocessing: bool = True, api_preprocessor: BasePreprocessor = None) -> Pipeline:
         if isinstance(self.predefined_model, Pipeline):
             pipelines = self.predefined_model
         elif self.predefined_model == 'auto':
             # Generate initial assumption automatically
             pipelines = AssumptionsBuilder.get(self.data).from_operations().build(
                 use_input_preprocessing=use_input_preprocessing)[0]
+
+            if use_input_preprocessing and api_preprocessor is not None:
+                pipelines.preprocessor = api_preprocessor
+
         elif isinstance(self.predefined_model, str):
             model = PipelineNode(self.predefined_model)
             pipelines = Pipeline(model, use_input_preprocessing=use_input_preprocessing)
+
+            if use_input_preprocessing and api_preprocessor is not None:
+                pipelines.preprocessor = api_preprocessor
+
         else:
             raise ValueError(f'{type(self.predefined_model)} is not supported as Fedot model')
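With the new api_preprocessor argument, the preprocessor already held by the API data processor is copied onto the pipeline that wraps a predefined model, instead of that pipeline starting with its own fresh preprocessor. A hedged usage sketch mirroring what main.py does in the next hunk; the CSV path and log prefix are placeholders, and helper signatures such as InputData.from_csv and the default_log import path differ between FEDOT/GOLEM versions, so check them against your install:

# Hedged sketch: passing the API-side preprocessor into PredefinedModel by hand.
# Import paths follow the diff above; 'train.csv' and the log prefix are placeholders.
from fedot.api.api_utils.api_data import ApiDataProcessor
from fedot.api.api_utils.predefined_model import PredefinedModel
from fedot.core.data.data import InputData
from fedot.core.log import default_log
from fedot.core.repository.tasks import Task, TaskTypesEnum

task = Task(TaskTypesEnum.classification)
train_data = InputData.from_csv('train.csv', task=task)   # placeholder dataset

data_processor = ApiDataProcessor(task=task, use_input_preprocessing=True)

# Because api_preprocessor is forwarded, encoders/imputers fitted through the
# API data processor end up on the returned pipeline as well.
pipeline = PredefinedModel(
    'rf', train_data, default_log('predefined_example'),
    use_input_preprocessing=True,
    api_preprocessor=data_processor.preprocessor,
).fit()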
8 changes: 5 additions & 3 deletions fedot/api/main.py
@@ -169,9 +169,11 @@ def fit(self,
         with fedot_composer_timer.launch_fitting():
             if predefined_model is not None:
                 # Fit predefined model and return it without composing
-                self.current_pipeline = PredefinedModel(predefined_model, self.train_data, self.log,
-                                                        use_input_preprocessing=self.params.get(
-                                                            'use_input_preprocessing')).fit()
+                self.current_pipeline = PredefinedModel(
+                    predefined_model, self.train_data, self.log,
+                    use_input_preprocessing=self.params.get('use_input_preprocessing'),
+                    api_preprocessor=self.data_processor.preprocessor,
+                ).fit()
             else:
                 self.current_pipeline, self.best_models, self.history = self.api_composer.obtain_model(self.train_data)
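At the facade level nothing changes in how fit is called; the difference is that the pipeline returned for a predefined model now carries the preprocessor held by the API data processor. A typical call exercising this branch, where the dataset path, the 'target' column and the use_input_preprocessing keyword are placeholders/assumptions to check against the installed version:

# Facade-level example of the branch above: fitting a predefined model.
# 'train.csv' and the 'target' column are placeholders.
import pandas as pd
from fedot.api.main import Fedot

df = pd.read_csv('train.csv')
features, target = df.drop(columns=['target']), df['target']

model = Fedot(problem='classification', use_input_preprocessing=True)

# fit() takes the PredefinedModel branch shown above; the returned pipeline
# now holds the same preprocessor that the API data processor uses.
pipeline = model.fit(features=features, target=target, predefined_model='rf')
prediction = model.predict(features=features)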
@@ -162,7 +162,7 @@ def _reasonability_check(features):
         # For every column in table make check
         for column_id in range(0, columns_amount):
             column = features[:, column_id] if columns_amount > 1 else features.copy()
-            if len(np.unique(column)) > 2:
+            if len(set(column)) > 2:
                 non_bool_ids.append(column_id)
             else:
                 bool_ids.append(column_id)
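A plausible reason for swapping np.unique for set in this uniqueness check: np.unique sorts the column, which raises a TypeError on object columns that mix strings with a float NaN, whereas set only requires the values to be hashable. A small illustration with a made-up column:

# Why set() is more forgiving than np.unique() for this check:
# np.unique sorts, and sorting mixed str/float objects raises TypeError.
import numpy as np

column = np.array(['yes', 'no', np.nan], dtype=object)   # made-up column

print(len(set(column)) > 2)    # True: hashing handles mixed types
try:
    np.unique(column)
except TypeError as err:
    print(f'np.unique failed: {err}')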
