diff --git a/fedot/api/main.py b/fedot/api/main.py index e607eb088f..f389489acc 100644 --- a/fedot/api/main.py +++ b/fedot/api/main.py @@ -202,7 +202,8 @@ def fit(self, return self.current_pipeline def tune(self, - input_data: Optional[InputData] = None, + input_data: Optional[FeaturesType] = None, + target: TargetType = 'target', metric_name: Optional[Union[str, MetricCallable]] = None, iterations: int = DEFAULT_TUNING_ITERATIONS_NUMBER, timeout: Optional[float] = None, @@ -212,7 +213,8 @@ def tune(self, """Method for hyperparameters tuning of current pipeline Args: - input_data: data for tuning pipeline. + input_data: data for tuning pipeline in one of the supported formats. + target: data target values in one of the supported target formats. metric_name: name of metric for quality tuning. iterations: numbers of tuning iterations. timeout: time for tuning (in minutes). If ``None`` or ``-1`` means tuning until max iteration reach. @@ -227,7 +229,10 @@ def tune(self, raise ValueError(NOT_FITTED_ERR_MSG) with fedot_composer_timer.launch_tuning('post'): - input_data = input_data or self.train_data + if input_data is None: + input_data = self.train_data + else: + input_data = self.data_processor.define_data(features=input_data, target=target, is_predict=False) cv_folds = cv_folds or self.params.get('cv_folds') n_jobs = n_jobs or self.params.n_jobs