Skip to content

Commit

Permalink
hotfix: go for FeaturesType instead of InputData in a pipeline tuning (
Browse files Browse the repository at this point in the history
  • Loading branch information
Lopa10ko authored Jul 24, 2024
1 parent a7e4243 commit d33ca9e
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions fedot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def fit(self,
return self.current_pipeline

def tune(self,
input_data: Optional[InputData] = None,
input_data: Optional[FeaturesType] = None,
target: TargetType = 'target',
metric_name: Optional[Union[str, MetricCallable]] = None,
iterations: int = DEFAULT_TUNING_ITERATIONS_NUMBER,
timeout: Optional[float] = None,
Expand All @@ -212,7 +213,8 @@ def tune(self,
"""Method for hyperparameters tuning of current pipeline
Args:
input_data: data for tuning pipeline.
input_data: data for tuning pipeline in one of the supported formats.
target: data target values in one of the supported target formats.
metric_name: name of metric for quality tuning.
iterations: numbers of tuning iterations.
timeout: time for tuning (in minutes). If ``None`` or ``-1`` means tuning until max iteration reach.
Expand All @@ -227,7 +229,10 @@ def tune(self,
raise ValueError(NOT_FITTED_ERR_MSG)

with fedot_composer_timer.launch_tuning('post'):
input_data = input_data or self.train_data
if input_data is None:
input_data = self.train_data
else:
input_data = self.data_processor.define_data(features=input_data, target=target, is_predict=False)
cv_folds = cv_folds or self.params.get('cv_folds')
n_jobs = n_jobs or self.params.n_jobs

Expand Down

0 comments on commit d33ca9e

Please sign in to comment.