From 8382e1d06dcb815024eaf1a0bd765a9fac80d83a Mon Sep 17 00:00:00 2001 From: Lopa10ko Date: Fri, 17 May 2024 00:56:14 +0300 Subject: [PATCH] feat: add convergence warning interception, defaulting to additive ets --- .../models/ts_implementations/statsmodels.py | 35 +++++++++++++------ 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/fedot/core/operations/evaluation/operation_implementations/models/ts_implementations/statsmodels.py b/fedot/core/operations/evaluation/operation_implementations/models/ts_implementations/statsmodels.py index ee7be042a0..be57649372 100644 --- a/fedot/core/operations/evaluation/operation_implementations/models/ts_implementations/statsmodels.py +++ b/fedot/core/operations/evaluation/operation_implementations/models/ts_implementations/statsmodels.py @@ -1,3 +1,4 @@ +import warnings from copy import copy import numpy as np @@ -5,6 +6,7 @@ from statsmodels.genmod.families import Gamma, Gaussian, InverseGaussian from statsmodels.genmod.families.links import identity, inverse_power, inverse_squared, log as lg from statsmodels.genmod.generalized_linear_model import GLM +from statsmodels.tools.sm_exceptions import ConvergenceWarning from statsmodels.tsa.ar_model import AutoReg from statsmodels.tsa.exponential_smoothing.ets import ETSModel @@ -255,6 +257,14 @@ def __init__(self, params: OperationParameters): else: self.seasonal_periods = None + def _init_model(self, endog: np.ndarray): + self.model = ETSModel(endog=endog, + error=self.params.get('error'), + trend=self.params.get('trend'), + seasonal=self.params.get('seasonal'), + damped_trend=self.params.get('damped_trend') if self.params.get('trend') else None, + seasonal_periods=self.seasonal_periods) + def fit(self, input_data): endog = input_data.features.astype('float64') @@ -262,15 +272,18 @@ def fit(self, input_data): if self._check_and_correct_params(endog): self.log.info(f'Changed the following ETSModel parameters: {self.params.changed_parameters}') - self.model = ETSModel( - endog=endog, - error=self.params.get('error'), - trend=self.params.get('trend'), - seasonal=self.params.get('seasonal'), - damped_trend=self.params.get('damped_trend') if self.params.get('trend') else None, - seasonal_periods=self.seasonal_periods - ) - self.model = self.model.fit(disp=False) + try: + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=ConvergenceWarning) + self._init_model(endog) + # if convergence warning is caught, switch to default ETSModel + self.model.fit(disp=False) + except ConvergenceWarning as e: + self.params.update(**{'error': 'add', 'trend': None, 'seasonal': None}) + self._init_model(endog) + self.log.info(f'Switched to default ETSModel due to a convergence warning: {e}') + finally: + self.model = self.model.fit(disp=False) return self.model def predict(self, input_data): @@ -326,13 +339,13 @@ def _check_and_correct_params(self, endog: np.ndarray) -> bool: if np.any(endog <= 0): for component in ets_components: if self.params.get(component) == 'mul': - self.params.update(**{f'{component}': 'add'}) + self.params.update(**{component: 'add'}) params_changed = True if self.params.get('trend') == 'mul' \ and self.params.get('damped_trend') \ and not self.params.get('seasonal'): - self.params.update(trend='add') + self.params.update(**{'trend': 'add'}) params_changed = True return params_changed