Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prototype of class structure #109

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
93239ff
Sketch out prototype
znicholls Oct 21, 2021
9ae897f
Mark path which is never used to simplify things
znicholls Oct 24, 2021
35b5fed
Block out another branch
znicholls Oct 24, 2021
92f1e4e
Implement prototype method
znicholls Oct 24, 2021
47fa2c6
Tidy up
znicholls Oct 25, 2021
b408db0
Format
znicholls Oct 25, 2021
b93de58
Tidy up a bit more
znicholls Oct 25, 2021
de1d293
Change calibration pattern to simplify use of classes
znicholls Oct 25, 2021
3d97ec1
Try doing auto-regression implementation
znicholls Oct 26, 2021
14c429d
Clean up loops
znicholls Oct 27, 2021
4fd57aa
Remove pdb statement
znicholls Oct 27, 2021
1026161
Sketch out test for train lv
znicholls Nov 24, 2021
d4e4e7f
More notes about how to do train lv
znicholls Nov 25, 2021
3470ee4
Start working on geodesic functions
znicholls Nov 25, 2021
3b23a6d
Get legacy training running
znicholls Nov 26, 2021
88ec4ad
Finish reimplementing train_lv
znicholls Nov 26, 2021
592441f
Merge branch 'main' into prototype
mathause Sep 21, 2023
08a5134
linting
mathause Sep 21, 2023
fe1d1bc
fix: test_prototype_train_lv
mathause Sep 21, 2023
1b4b283
Merge branch 'main' into prototype
mathause Sep 21, 2023
c854fa4
clean train_lt.py
mathause Sep 21, 2023
c3bf5e0
remove prototype/utils.py after #298, #299, and #300
mathause Sep 21, 2023
c727a6e
Merge branch 'main' into prototype
mathause Sep 23, 2023
26f6761
allow selected ar order to be None
mathause Sep 25, 2023
a2fec9b
Merge branch 'main' into prototype
mathause Sep 25, 2023
7143934
Merge branch 'main' into prototype
mathause Dec 12, 2023
d3eb99d
fix for gaspari_cohn and geodist_exact
mathause Dec 12, 2023
8f19cd5
small refactor
mathause Dec 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added mesmer/prototype/__init__.py
Empty file.
142 changes: 142 additions & 0 deletions mesmer/prototype/calibrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import abc

import numpy as np
import sklearn.linear_model
import statsmodels.tsa.ar_model
import xarray as xr


class MesmerCalibrateBase(metaclass=abc.ABCMeta):
    """
    Common abstract base class from which all MESMER calibrators derive.
    """


class MesmerCalibrateTargetPredictor(MesmerCalibrateBase):
    """Interface for calibrators that fit a target against predictors."""

    @abc.abstractmethod
    def calibrate(self, target, predictor, **kwargs):
        """
        Fit the calibration model to ``target`` using ``predictor``.

        Concrete subclasses define the exact signature, behaviour and
        return value (see e.g. ``LinearRegression.calibrate``).
        """


class LinearRegression(MesmerCalibrateTargetPredictor):
    """
    Multiple linear regression of a target against a set of predictors,
    fitted independently at each point of a stacked (e.g. spatial) dimension.

    following

    https://github.com/MESMER-group/mesmer/blob/d73e8f521a2e1d081a48b775ba14dd764cb671e8/mesmer/calibrate_mesmer/train_lt.py#L165

    All the lines above and below that line are basically just data
    preparation, which makes it very hard to see what the model actually is
    """

    @staticmethod
    def _regress_single_group(target_point, predictor, weights=None):
        """
        Fit an ordinary least-squares regression for a single group.

        Parameters
        ----------
        target_point : np.ndarray
            1D array of target values for one group (e.g. one grid point).
        predictor : np.ndarray
            2D array of shape (n_predictors, n_samples).
        weights : np.ndarray, optional
            Per-sample weights, forwarded to ``sklearn``'s ``fit``.

        Returns
        -------
        np.ndarray
            1D array: the intercept followed by one coefficient per predictor.
        """
        # this is the method that actually does the regression;
        # sklearn expects X of shape (n_samples, n_features), hence the transpose
        args = [predictor.T, target_point.reshape(-1, 1)]
        if weights is not None:
            args.append(weights)
        reg = sklearn.linear_model.LinearRegression().fit(*args)
        # with the column-vector target above, ``intercept_`` has shape (1,)
        # and ``coef_`` shape (1, n_predictors), so unpacking ``coef_`` gives
        # a flat [intercept, coef_1, ..., coef_n] result
        out_array = np.concatenate([reg.intercept_, *reg.coef_])

        return out_array

    def calibrate(
        self,
        target_flattened,
        predictors_flattened,
        stack_coord_name,
        predictor_name="predictor",
        weights=None,
        predictor_temporary_name="__pred_store__",
        predictor_var_coord="variable",
    ):
        """
        Calibrate one linear regression per element of the non-stacked dims.

        Parameters
        ----------
        target_flattened : xr.DataArray
            Target data with the samples stacked along ``stack_coord_name``.
        predictors_flattened : xr.DataArray
            Predictor data with dims ``(predictor_name, stack_coord_name)``.
        stack_coord_name : str
            Name of the stacked sample dimension shared by target and
            predictors.
        predictor_name : str
            Name of the dimension of ``predictors_flattened`` that indexes
            the individual predictors.
        weights : np.ndarray, optional
            Per-sample weights forwarded to the underlying regression.
        predictor_temporary_name : str
            Scratch dimension name used while applying the regression; must
            not already be a dimension of ``predictors_flattened``.
        predictor_var_coord : str
            Name of the coordinate on ``predictors_flattened`` that holds the
            predictors' names (default ``"variable"``, the previous
            hard-coded behaviour).

        Returns
        -------
        xr.DataArray
            Regression parameters with an ``"intercept"`` entry followed by
            one entry per predictor along ``predictor_name``.
        """
        if predictor_name not in predictors_flattened.dims:
            raise AssertionError(f"{predictor_name} not in {predictors_flattened.dims}")

        if predictor_temporary_name in predictors_flattened.dims:
            raise AssertionError(
                f"{predictor_temporary_name} already in {predictors_flattened.dims}, choose a different temporary name"
            )

        # vectorize=True loops the 1D regression over all non-core dims
        res = xr.apply_ufunc(
            self._regress_single_group,
            target_flattened,
            predictors_flattened,
            input_core_dims=[[stack_coord_name], [predictor_name, stack_coord_name]],
            output_core_dims=((predictor_temporary_name,),),
            vectorize=True,
            kwargs=dict(weights=weights),
        )

        # label the output: intercept first, then the predictors in the order
        # given by the ``predictor_var_coord`` coordinate
        predictors_plus_intercept_order = ["intercept"] + list(
            predictors_flattened[predictor_var_coord].values
        )
        res = res.assign_coords(
            {predictor_temporary_name: predictors_plus_intercept_order}
        ).rename({predictor_temporary_name: predictor_name})

        return res


class MesmerCalibrateTarget(MesmerCalibrateBase):
    """Interface for calibrators that fit a model to the target alone."""

    @abc.abstractmethod
    def calibrate(self, target, **kwargs):
        """
        Fit the calibration model to ``target``.

        Concrete subclasses define the exact signature, behaviour and
        return value.
        """

    @staticmethod
    def _check_target_is_one_dimensional(target, return_numpy_values):
        # Guard: the calibrators in this hierarchy only handle targets with
        # a single dimension (e.g. time).
        if len(target.dims) > 1:
            raise AssertionError(f"More than one dimension, found {target.dims}")

        if return_numpy_values:
            # Drop missing values along the single dimension and hand back a
            # plain numpy array for use with downstream libraries.
            return target.dropna(dim=target.dims[0]).values

        return None


class AutoRegression1DOrderSelection(MesmerCalibrateTarget):
    """Select the lag order of an AR model for a one-dimensional target."""

    def calibrate(
        self,
        target,
        maxlag=12,
        ic="bic",
    ):
        """
        Choose AR lags for ``target`` via ``statsmodels``' order selection.

        Parameters
        ----------
        target : xarray-like
            One-dimensional series (missing values are dropped before fitting).
        maxlag : int
            Highest lag order considered.
        ic : str
            Information criterion used to rank candidate orders.

        Returns
        -------
        The selected lags as reported by ``statsmodels`` (may be ``None``
        when no lags are selected).
        """
        values = self._check_target_is_one_dimensional(
            target, return_numpy_values=True
        )

        selection = statsmodels.tsa.ar_model.ar_select_order(
            values, maxlag=maxlag, ic=ic, old_names=False
        )

        return selection.ar_lags


class AutoRegression1D(MesmerCalibrateTarget):
    """Fit an autoregressive model of given order to a one-dimensional target."""

    def calibrate(
        self,
        target,
        order,
    ):
        """
        Fit an AR model with the given lag order to ``target``.

        Parameters
        ----------
        target : xarray-like
            One-dimensional series (missing values are dropped before
            fitting).
        order : int, list of int or None
            Lag specification forwarded to ``statsmodels``' ``AutoReg``
            (e.g. the output of ``AutoRegression1DOrderSelection``).

        Returns
        -------
        dict
            ``"intercept"``: the fitted constant term;
            ``"lag_coefficients"``: array with one coefficient per lag;
            ``"standard_innovations"``: the standard deviation of the fitted
            innovations (white-noise residuals), i.e. ``sqrt(sigma2)``.
        """
        target_numpy = self._check_target_is_one_dimensional(
            target, return_numpy_values=True
        )

        calibrated = statsmodels.tsa.ar_model.AutoReg(
            target_numpy, lags=order, old_names=False
        ).fit()

        # ``params`` is [intercept, lag_1_coeff, ..., lag_p_coeff] because
        # AutoReg includes a constant term by default
        return {
            "intercept": calibrated.params[0],
            "lag_coefficients": calibrated.params[1:],
            # ``sigma2`` is the variance of the innovations; its square root
            # is the innovation standard deviation
            "standard_innovations": np.sqrt(calibrated.sigma2),
        }
Loading