Commit be3faab: Requested YamLyubov's changes

aPovidlo committed Sep 11, 2023
1 parent 8680db2
Showing 3 changed files with 30 additions and 40 deletions.
40 changes: 15 additions & 25 deletions fedot/core/operations/evaluation/boostings.py
@@ -8,9 +8,11 @@
 from fedot.utilities.random import ImplementationRandomStateHandler
 
 
-class BoostingClassificationStrategy(EvaluationStrategy):
+class BoostingStrategy(EvaluationStrategy):
     __operations_by_types = {
-        'catboost': FedotCatBoostClassificationImplementation
+        'catboost': FedotCatBoostClassificationImplementation,
+
+        'catboostreg': FedotCatBoostRegressionImplementation
     }
 
     def __init__(self, operation_type: str, params: Optional[OperationParameters] = None):
@@ -32,43 +34,31 @@ def fit(self, train_data: InputData):
 
         return operation_implementation
 
+    def predict(self, trained_operation, predict_data: InputData) -> OutputData:
+        raise NotImplementedError()
+
+
+class BoostingClassificationStrategy(BoostingStrategy):
+    def __init__(self, operation_type: str, params: Optional[OperationParameters] = None):
+        super().__init__(operation_type, params)
+
     def predict(self, trained_operation, predict_data: InputData) -> OutputData:
         if self.output_mode in ['default', 'labels']:
             prediction = trained_operation.predict(predict_data)
 
-        elif self.output_mode in ['probs', 'full_probs'] and predict_data.task:
+        elif self.output_mode in ['probs', 'full_probs'] and predict_data.task == 'classification':
             prediction = trained_operation.predict_proba(predict_data)
 
         else:
-            raise ValueError(f'Output model {self.output_mode} is not supported')
+            raise ValueError(f'Output mode {self.output_mode} is not supported')
 
         return self._convert_to_output(prediction, predict_data)
 
 
-class BoostingRegressionStrategy(EvaluationStrategy):
-    __operations_by_types = {
-        'catboostreg': FedotCatBoostRegressionImplementation
-    }
-
+class BoostingRegressionStrategy(BoostingStrategy):
     def __init__(self, operation_type: str, params: Optional[OperationParameters] = None):
-        self.operation_impl = self._convert_to_operation(operation_type)
         super().__init__(operation_type, params)
 
-    def _convert_to_operation(self, operation_type: str):
-        if operation_type in self.__operations_by_types.keys():
-            return self.__operations_by_types[operation_type]
-
-        else:
-            raise ValueError(f'Impossible to obtain Boosting Strategy for {operation_type}')
-
-    def fit(self, train_data: InputData):
-        operation_implementation = self.operation_impl(self.params_for_fit)
-
-        with ImplementationRandomStateHandler(implementation=operation_implementation):
-            operation_implementation.fit(train_data)
-
-        return operation_implementation
-
     def predict(self, trained_operation, predict_data: InputData) -> OutputData:
         prediction = trained_operation.predict(predict_data)
 
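The net effect in boostings.py is a small template-method refactor: the shared __operations_by_types lookup, __init__ and fit() now live in a single BoostingStrategy base class, predict() is declared abstract there, and the classification and regression strategies only override prediction. The sketch below is a simplified stand-in rather than FEDOT code: it mirrors the class layout from the diff but stubs out EvaluationStrategy, OperationParameters, ImplementationRandomStateHandler and the implementation classes, so the names and data are illustrative only.

# Simplified stand-in for the refactored boostings.py: shared lookup and fit()
# in the base class, prediction logic in the subclasses.

class CatBoostClassifierStub:
    """Stub for FedotCatBoostClassificationImplementation."""
    def fit(self, train_data):
        return self

    def predict(self, predict_data):
        return 'labels'

    def predict_proba(self, predict_data):
        return 'probabilities'


class BoostingStrategy:
    _operations_by_types = {'catboost': CatBoostClassifierStub}

    def __init__(self, operation_type: str, output_mode: str = 'labels'):
        if operation_type not in self._operations_by_types:
            raise ValueError(f'Impossible to obtain Boosting Strategy for {operation_type}')
        self.operation_impl = self._operations_by_types[operation_type]
        self.output_mode = output_mode

    def fit(self, train_data):
        operation_implementation = self.operation_impl()
        operation_implementation.fit(train_data)
        return operation_implementation

    def predict(self, trained_operation, predict_data):
        raise NotImplementedError()


class BoostingClassificationStrategy(BoostingStrategy):
    def predict(self, trained_operation, predict_data):
        if self.output_mode in ('default', 'labels'):
            return trained_operation.predict(predict_data)
        if self.output_mode in ('probs', 'full_probs'):
            return trained_operation.predict_proba(predict_data)
        raise ValueError(f'Output mode {self.output_mode} is not supported')


class BoostingRegressionStrategy(BoostingStrategy):
    def predict(self, trained_operation, predict_data):
        return trained_operation.predict(predict_data)


strategy = BoostingClassificationStrategy('catboost', output_mode='probs')
model = strategy.fit(train_data=None)              # stand-in data
print(strategy.predict(model, predict_data=None))  # -> 'probabilities'

Keeping predict() abstract in the base class makes the task-specific part of the behaviour explicit, while the operation lookup and the fit() wrapper are written once instead of being duplicated per task.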
@@ -18,8 +18,7 @@ class FedotCatBoostImplementation(ModelImplementation):
     def __init__(self, params: Optional[OperationParameters] = None):
         super().__init__(params)
 
-        # TODO: Adding checking params compatibility with each other
-        self.params.update(**self.check_and_update_params(self.params.to_dict()))
+        self.check_and_update_params()
 
         self.model_params = {k: v for k, v in self.params.to_dict().items() if k not in self.__operation_params}
         self.model = None
@@ -48,15 +47,16 @@ def predict(self, input_data: InputData):
 
         return prediction
 
-    @staticmethod
-    def check_and_update_params(params):
-        params['thread_count'] = params['n_jobs']
+    def check_and_update_params(self):
+        n_jobs = self.params.get('n_jobs')
+        self.params.update({'thread_count': n_jobs})
 
-        if params['use_best_model'] or params['early_stopping_rounds'] and not params['use_eval_set']:
-            params['use_best_model'] = False
-            params['early_stopping_rounds'] = False
+        use_best_model = self.params.get('use_best_model')
+        early_stopping_rounds = self.params.get('early_stopping_rounds')
+        use_eval_set = self.params.get('use_eval_set')
 
-        return params
+        if use_best_model or early_stopping_rounds and not use_eval_set:
+            self.params.update(dict(use_best_model=False, early_stopping_rounds=False))
 
     @staticmethod
     def convert_to_pool(data: Optional[InputData]):
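check_and_update_params is now an instance method that reads from and writes back into self.params instead of round-tripping a plain dict through a staticmethod: n_jobs is mirrored into CatBoost's thread_count, and use_best_model / early_stopping_rounds are switched off when they cannot be honoured. Below is a minimal stand-in sketch of the new logic; the dict-backed params object replaces OperationParameters and the concrete values are made up. Note that in Python "and" binds tighter than "or", so the guard as written parses as use_best_model or (early_stopping_rounds and not use_eval_set); if the intent is to disable both options only when no eval set is supplied, explicit parentheses around the "or" would be needed.

# Stand-in for FedotCatBoostImplementation.check_and_update_params (illustrative only).

class CatBoostParamsStub:
    def __init__(self, **params):
        self.params = params  # stand-in for OperationParameters

    def check_and_update_params(self):
        # Mirror n_jobs into CatBoost's thread_count.
        n_jobs = self.params.get('n_jobs')
        self.params.update({'thread_count': n_jobs})

        use_best_model = self.params.get('use_best_model')
        early_stopping_rounds = self.params.get('early_stopping_rounds')
        use_eval_set = self.params.get('use_eval_set')

        # Parses as: use_best_model or (early_stopping_rounds and not use_eval_set)
        if use_best_model or early_stopping_rounds and not use_eval_set:
            self.params.update(dict(use_best_model=False, early_stopping_rounds=False))


impl = CatBoostParamsStub(n_jobs=4, use_best_model=True, early_stopping_rounds=50, use_eval_set=False)
impl.check_and_update_params()
print(impl.params)
# {'n_jobs': 4, 'use_best_model': False, 'early_stopping_rounds': False, 'use_eval_set': False, 'thread_count': 4}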
12 changes: 6 additions & 6 deletions fedot/core/pipelines/tuning/search_space.py
@@ -639,12 +639,12 @@ def get_parameters_dict(self):
                     'type': 'continuous'
                 },
                 'max_depth': {
-                    'hyperopt-dist': hp.loguniform,
-                    'sampling-scope': [1, 100],
+                    'hyperopt-dist': hp.uniformint,
+                    'sampling-scope': [4, 10],
                     'type': 'discrete'
                 },
                 'max_leaves': {
-                    'hyperopt-dist': hp.loguniform,
+                    'hyperopt-dist': hp.uniformint,
                     'sampling-scope': [1, 100],
                     'type': 'discrete'
                 },
@@ -681,12 +681,12 @@ def get_parameters_dict(self):
                     'type': 'continuous'
                 },
                 'max_depth': {
-                    'hyperopt-dist': hp.loguniform,
-                    'sampling-scope': [1, 100],
+                    'hyperopt-dist': hp.uniformint,
+                    'sampling-scope': [4, 10],
                     'type': 'discrete'
                 },
                 'max_leaves': {
-                    'hyperopt-dist': hp.loguniform,
+                    'hyperopt-dist': hp.uniformint,
                     'sampling-scope': [1, 100],
                     'type': 'discrete'
                 },
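The tuning-space change matters for both type and scale: hp.loguniform(label, low, high) draws exp(uniform(low, high)), so the old sampling scope [1, 100] for max_depth would have been treated as bounds in log space and would never yield integers, while hp.uniformint samples integers directly from the stated range. Below is a minimal sketch, assuming only the hyperopt package, of what the updated 'catboost'/'catboostreg' entries sample like; the FEDOT SearchSpace code that maps 'hyperopt-dist' and 'sampling-scope' onto these calls is omitted.

# Sample the updated CatBoost hyperparameter distributions with hyperopt.
from hyperopt import hp
from hyperopt.pyll import stochastic

catboost_space = {
    'max_depth': hp.uniformint('max_depth', 4, 10),      # new: integers in [4, 10]
    'max_leaves': hp.uniformint('max_leaves', 1, 100),   # new: integers in [1, 100]
}

# Old definition for comparison: exp(uniform(1, 100)) -- enormous floats, not tree depths.
old_max_depth = hp.loguniform('old_max_depth', 1, 100)

for _ in range(3):
    print(stochastic.sample(catboost_space))  # prints sampled configs, e.g. max_depth 7, max_leaves 42

Capping max_depth at 10 also lines up with CatBoost's own guidance that useful tree depths usually fall in the 4 to 10 range.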
