From 49066bfffe8824272c44538e05c033133c62822a Mon Sep 17 00:00:00 2001 From: Ujjwal Bhatnagar <52956427+Ujjwalb12@users.noreply.github.com> Date: Fri, 20 Aug 2021 08:49:58 +0000 Subject: [PATCH] forecasting.ipynb - initial code, final output in coming weeks. --- .replit | 2 + gs_quant/timeseries/__init__.py | 1 + gs_quant/timeseries/analysis.py | 8 ++- gs_quant/timeseries/forecasting.ipynb | 99 +++++++++++++++++++++++++++ setup.py | 11 ++- 5 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 .replit create mode 100644 gs_quant/timeseries/forecasting.ipynb diff --git a/.replit b/.replit new file mode 100644 index 00000000..1b15c054 --- /dev/null +++ b/.replit @@ -0,0 +1,2 @@ +language = "python3" +run = "" \ No newline at end of file diff --git a/gs_quant/timeseries/__init__.py b/gs_quant/timeseries/__init__.py index deac6027..0c805276 100644 --- a/gs_quant/timeseries/__init__.py +++ b/gs_quant/timeseries/__init__.py @@ -31,5 +31,6 @@ from .measures_xccy import * from .measures_fx_vol import * from .helper import * +from .kats import * __name__ = 'timeseries' diff --git a/gs_quant/timeseries/analysis.py b/gs_quant/timeseries/analysis.py index 38451e4a..536216ba 100644 --- a/gs_quant/timeseries/analysis.py +++ b/gs_quant/timeseries/analysis.py @@ -19,7 +19,6 @@ from gs_quant.datetime import relative_date_add from gs_quant.timeseries.datetime import * from .helper import plot_function - """ Timeseries analysis library contains functions used to analyze properties of timeseries, including laging, differencing, autocorrelation, co-integration and other operations @@ -196,7 +195,9 @@ class LagMode(Enum): @plot_function -def lag(x: pd.Series, obs: Union[Window, int, str] = 1, mode: LagMode = LagMode.EXTEND) -> pd.Series: +def lag(x: pd.Series, + obs: Union[Window, int, str] = 1, + mode: LagMode = LagMode.EXTEND) -> pd.Series: """ Lag timeseries by a number of observations or a relative date. 
@@ -248,7 +249,8 @@ def lag(x: pd.Series, obs: Union[Window, int, str] = 1, mode: LagMode = LagMode. # Determine how we want to handle observations prior to start date if mode == LagMode.EXTEND: if x.index.resolution != 'day': - raise MqValueError(f'unable to extend index with resolution {x.index.resolution}') + raise MqValueError( + f'unable to extend index with resolution {x.index.resolution}') kwargs = {'periods': abs(obs) + 1, 'freq': 'D'} if obs > 0: kwargs['start'] = x.index[-1] diff --git a/gs_quant/timeseries/forecasting.ipynb b/gs_quant/timeseries/forecasting.ipynb new file mode 100644 index 00000000..94302555 --- /dev/null +++ b/gs_quant/timeseries/forecasting.ipynb @@ -0,0 +1,99 @@ +# Copyright 2018 Goldman Sachs. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# +# Marquee Plot Service will attempt to make public functions (not prefixed with _) from this module available. +# Such functions should be fully documented: docstrings should describe parameters and the return value, and provide +# a 1-line description. Type annotations should be provided for parameters. +# +# Kats (Kits to Analyze Time Series) is a light-weight, easy-to-use, extensible, and generalizable framework to perform time series analysis in Python. Time series analysis is an essential component of data science and engineering work.
+ + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +from kats.consts import TimeSeriesData +from gs_quant.datetime import relative_date_add +from gs_quant.timeseries.datetime import * +from .helper import plot_function + + + + +# Note: If the column holding the time values is not called time, you will want to specify the name of this column. +gs_quant.columns = ["time", "value"] +ts = TimeSeriesData(gs_quant) + +from kats.models.sarima import SARIMAModel, SARIMAParams + + +# create SARIMA param instance +params = SARIMAParams( + p = 2, + d=1, + q=1, + trend = 'cti', + seasonal_order=(1,0,1,12) + ) + +# instantiate SARIMA model +m = SARIMAModel(data=ts, params=params) + +# fit SARIMA model +m.fit() + +# generate forecast values +fcst = m.predict( + steps=30, + freq="MS" + ) + +# make plot to visualize +m.plot() + + +# import the param and model classes for Prophet model +from kats.models.prophet import ProphetModel, ProphetParams + +# create a model param instance +params = ProphetParams(seasonality_mode='multiplicative') # additive mode gives worse results + +# create a prophet model instance +m = ProphetModel(ts, params) + +# fit model simply by calling m.fit() +m.fit() + +# make prediction for next 30 months +fcst = m.predict(steps=30, freq="MS") + +# plot to visualize +m.plot() + +from kats.models.holtwinters import HoltWintersParams, HoltWintersModel + + +params = HoltWintersParams( + trend="add", + #damped=False, + seasonal="mul", + seasonal_periods=12, + ) +m = HoltWintersModel( + data=ts, + params=params) + +m.fit() + +fcst = m.predict(steps=20, alpha = 0.3) +m.plot() \ No newline at end of file diff --git a/setup.py b/setup.py index 98501554..4836b3ef 100644 --- a/setup.py +++ b/setup.py @@ -85,9 +85,14 @@ "internal": ["gs_quant_internal>=1.1.30", "requests_kerberos"], "turbo": ["quant-extensions"], "notebook": ["jupyter", "matplotlib~=3.1.0", "seaborn", "treelib"], - "test": ["pytest", "pytest-cov", "pytest-mock", "testfixtures",
"nbconvert", "nbformat", "jupyter_client"], - "develop": ["wheel", "sphinx", "sphinx_rtd_theme", "sphinx_autodoc_typehints", "pytest", "pytest-cov", - "pytest-mock", "testfixtures"] + "test": [ + "pytest", "pytest-cov", "pytest-mock", "testfixtures", "nbconvert", + "nbformat", "jupyter_client" + ], + "develop": [ + "wheel", "sphinx", "sphinx_rtd_theme", "sphinx_autodoc_typehints", + "pytest", "pytest-cov", "pytest-mock", "testfixtures" + ] }, classifiers=[ "Programming Language :: Python :: 3",