From 0bf8d30125699144384b37b853dbbe31418e46a2 Mon Sep 17 00:00:00 2001 From: Andrey Stebenkov Date: Tue, 5 Sep 2023 16:33:20 +0300 Subject: [PATCH] saving/loading catboost model --- .../models/boostings_implementations.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fedot/core/operations/evaluation/operation_implementations/models/boostings_implementations.py b/fedot/core/operations/evaluation/operation_implementations/models/boostings_implementations.py index 065e7d7a7b..07bb511134 100644 --- a/fedot/core/operations/evaluation/operation_implementations/models/boostings_implementations.py +++ b/fedot/core/operations/evaluation/operation_implementations/models/boostings_implementations.py @@ -1,3 +1,4 @@ +import os from typing import Optional import pandas as pd @@ -8,6 +9,7 @@ from fedot.core.data.data_split import train_test_data_setup from fedot.core.operations.evaluation.operation_implementations.implementation_interfaces import ModelImplementation from fedot.core.operations.operation_parameters import OperationParameters +from fedot.core.utils import default_fedot_data_dir class FedotCatBoostImplementation(ModelImplementation): @@ -73,6 +75,14 @@ def convert_to_pool(data: Optional[InputData]): feature_names=data.features_names.tolist() ) + def save_model(self, model_name: str = 'catboost'): + save_path = os.path.join(default_fedot_data_dir(), f'catboost/{model_name}.cbm') + self.model.save_model(save_path, format='cbm') + + def load_model(self, path): + self.model = CatBoostClassifier() + self.model.load_model(path) + class FedotCatBoostClassificationImplementation(FedotCatBoostImplementation): def __init__(self, params: Optional[OperationParameters] = None):