Skip to content

Commit

Permalink
feat: add convergence warning interception, defaulting to additive ets
Browse files Browse the repository at this point in the history
  • Loading branch information
Lopa10ko committed May 16, 2024
1 parent 890492f commit 8382e1d
Showing 1 changed file with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings
from copy import copy

import numpy as np
import statsmodels.api as sm
from statsmodels.genmod.families import Gamma, Gaussian, InverseGaussian
from statsmodels.genmod.families.links import identity, inverse_power, inverse_squared, log as lg
from statsmodels.genmod.generalized_linear_model import GLM
from statsmodels.tools.sm_exceptions import ConvergenceWarning
from statsmodels.tsa.ar_model import AutoReg
from statsmodels.tsa.exponential_smoothing.ets import ETSModel

Expand Down Expand Up @@ -255,22 +257,33 @@ def __init__(self, params: OperationParameters):
else:
self.seasonal_periods = None

def _init_model(self, endog: np.ndarray):
self.model = ETSModel(endog=endog,
error=self.params.get('error'),
trend=self.params.get('trend'),
seasonal=self.params.get('seasonal'),
damped_trend=self.params.get('damped_trend') if self.params.get('trend') else None,
seasonal_periods=self.seasonal_periods)

def fit(self, input_data):
endog = input_data.features.astype('float64')

# check ets params according to statsmodels restrictions
if self._check_and_correct_params(endog):
self.log.info(f'Changed the following ETSModel parameters: {self.params.changed_parameters}')

self.model = ETSModel(
endog=endog,
error=self.params.get('error'),
trend=self.params.get('trend'),
seasonal=self.params.get('seasonal'),
damped_trend=self.params.get('damped_trend') if self.params.get('trend') else None,
seasonal_periods=self.seasonal_periods
)
self.model = self.model.fit(disp=False)
try:
with warnings.catch_warnings():
warnings.filterwarnings("error", category=ConvergenceWarning)
self._init_model(endog)
# if convergence warning is caught, switch to default ETSModel
self.model.fit(disp=False)
except ConvergenceWarning as e:
self.params.update(**{'error': 'add', 'trend': None, 'seasonal': None})
self._init_model(endog)
self.log.info(f'Switched to default ETSModel due to a convergence warning: {e}')
finally:
self.model = self.model.fit(disp=False)
return self.model

def predict(self, input_data):
Expand Down Expand Up @@ -326,13 +339,13 @@ def _check_and_correct_params(self, endog: np.ndarray) -> bool:
if np.any(endog <= 0):
for component in ets_components:
if self.params.get(component) == 'mul':
self.params.update(**{f'{component}': 'add'})
self.params.update(**{component: 'add'})
params_changed = True

if self.params.get('trend') == 'mul' \
and self.params.get('damped_trend') \
and not self.params.get('seasonal'):
self.params.update(trend='add')
self.params.update(**{'trend': 'add'})
params_changed = True

return params_changed

0 comments on commit 8382e1d

Please sign in to comment.