Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sum of models and custom models #160

Merged
merged 31 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7232c43
Add _custom_model class attribute, exception in fit and p0 and bounds…
rhugonnet Aug 11, 2023
5247cd4
Add tests on custom model definition and fitting
rhugonnet Aug 11, 2023
7b3cd22
Add to DirectionalVariogram and rename _custom_model to _is_model_cus…
rhugonnet Aug 11, 2023
20749fc
Build sum of models from string name
rhugonnet Aug 21, 2023
c93ed9a
Add self._model_name attribute, and logic for sum of models in descri…
rhugonnet Aug 21, 2023
daf6eda
Add more tests and correct bugs
rhugonnet Aug 21, 2023
2df8592
Remove static typing, move to description
rhugonnet Aug 21, 2023
5d37015
Fix instantiation and fix tests
rhugonnet Aug 21, 2023
66aff87
Add tests for plotting
rhugonnet Aug 21, 2023
82fa559
Document fit_bounds and fit_p0, and add tests
rhugonnet Aug 22, 2023
4185c75
Update model argument description
rhugonnet Aug 22, 2023
bc4e170
Fix random test and inf bound check
rhugonnet Aug 22, 2023
aa2d7ed
Fix warnings in test_variogram
rhugonnet Aug 22, 2023
ce9aa85
Fix typo in description
rhugonnet Aug 22, 2023
de3a948
Fix floating precision error
rhugonnet Sep 21, 2023
2e31741
Add example in user guide
rhugonnet Sep 21, 2023
6985339
Make nugget consistent for a sum of models
rhugonnet Sep 21, 2023
fbff110
Finalize nugget for sum and add tests
rhugonnet Sep 21, 2023
f4d33f1
Rerun tests with scipy 1.11.3
rhugonnet Oct 2, 2023
34388c6
Force SciPy versions to before 1.11.1
rhugonnet Oct 5, 2023
d357c1a
Try with NumPy version before 1.25
rhugonnet Oct 5, 2023
56127e2
Try with 1.24
rhugonnet Oct 5, 2023
7d5d287
Try 1.24.1
rhugonnet Oct 5, 2023
591ae37
Try 1.24.1
rhugonnet Oct 5, 2023
d0bb82d
Try 1.24.2
rhugonnet Oct 5, 2023
9a76ae7
Try 1.24.3
rhugonnet Oct 5, 2023
35903f8
Try 1.25.0
rhugonnet Oct 5, 2023
9ada1f4
Remove NumPy version fixing
rhugonnet Oct 9, 2023
59781f1
Change stable entropy bin test precision to 0 decimals
rhugonnet Oct 9, 2023
5f67e4e
Merge branch 'main' into combine_models
mmaelicke Oct 13, 2023
9690d39
Linting
rhugonnet Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions skgstat/DirectionalVariogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def __init__(self,

# set the directional model
self._directional_model = None
self._is_model_custom = False
self.set_directional_model(model_name=directional_model)

# the binning settings
Expand Down
41 changes: 30 additions & 11 deletions skgstat/Variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Variogram class
"""
import copy
import inspect
import warnings
from typing import Iterable, Callable, Union, Tuple

Expand Down Expand Up @@ -335,6 +336,7 @@ def __init__(self,

# model can be a function or a string
self._model = None
self._is_model_custom = False
self.set_model(model_name=model)

# specify if the lag should be given absolute or relative to the maxlag
Expand Down Expand Up @@ -973,6 +975,7 @@ def set_model(self, model_name):
) % model_name
)
else: # pragma: no cover
self._is_model_custom = True
self._model = model_name

def _build_harmonized_model(self):
Expand Down Expand Up @@ -1454,7 +1457,7 @@ def preprocessing(self, force=False):
self._calc_diff(force=force)
self._calc_groups(force=force)

def fit(self, force=False, method=None, sigma=None, **kwargs):
def fit(self, force=False, method=None, sigma=None, bounds=None, p0=None, **kwargs):
rhugonnet marked this conversation as resolved.
Show resolved Hide resolved
"""Fit the variogram

The fit function will fit the theoretical variogram function to the
Expand Down Expand Up @@ -1563,18 +1566,34 @@ def fit(self, force=False, method=None, sigma=None, **kwargs):
self.cof = [r, s, n]
return

# Switch the method
# wrap the model to include or exclude the nugget
if self.use_nugget:
def wrapped(*args):
return self._model(*args)
# For a supported model, wrap the function depending on nugget and get logical bounds
if not self._is_model_custom:
# Switch the method
# wrap the model to include or exclude the nugget
if self.use_nugget:
def wrapped(*args):
return self._model(*args)
else:
def wrapped(*args):
return self._model(*args, 0)

# get p0
if bounds is None:
bounds = (0, self.__get_fit_bounds(x, y))
if p0 is None:
p0 = np.asarray(bounds[1])
# Else, inspect the function for the number of arguments
else:
def wrapped(*args):
return self._model(*args, 0)
# The number of arguments of argspec minus one is what we initialized
argspec = inspect.getfullargspec(self._model)
nb_args = len(argspec) - 1
if bounds is None:
bounds = (0, [np.maximum(np.nanmax(x), np.nanmax(y))] * nb_args)
if p0 is None:
p0 = np.asarray(bounds[1])

# get p0
bounds = (0, self.__get_fit_bounds(x, y))
p0 = np.asarray(bounds[1])
def wrapped(*args):
return self._model(*args)

# Trust Region Reflective
if self.fit_method == 'trf':
Expand Down
15 changes: 15 additions & 0 deletions skgstat/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@ def adder(l, a):
for r, c in zip(res, adder([1, 4, 8], 4)):
self.assertEqual(r, c)

def test_sum_spherical(self):
@variogram
def sum_spherical(h, r1, c1, r2, c2, b1=0, b2=0):
return spherical(h, r1, c1, b1) + spherical(h, r2, c2, b2)

# Parameters for the two spherical models
params = [1, 0.3, 10, 0.7]

# Values at which we'll evaluate the function and its expected result
vals = [0, 1, 100]
res = [0, 0.3 + spherical(1, 10, 0.7), 1]

for r, c in zip(res, sum_spherical(vals, *params)):
self.assertEqual(r, c)


if __name__=='__main__':
unittest.main()
15 changes: 14 additions & 1 deletion skgstat/tests/test_variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from skgstat import OrdinaryKriging
from skgstat import estimators
from skgstat import plotting
from skgstat.models import variogram, spherical, matern


class TestSpatiallyCorrelatedData(unittest.TestCase):
Expand Down Expand Up @@ -61,7 +62,7 @@ def test_sparse_maxlag_30(self):
self.assertAlmostEqual(x, y, places=3)


class TestVariogramInstatiation(unittest.TestCase):
class TestVariogramInstantiation(unittest.TestCase):
def setUp(self):
# set up default values, whenever c and v are not important
np.random.seed(42)
Expand Down Expand Up @@ -949,6 +950,18 @@ def test_implicit_nugget(self):

self.assertTrue(abs(V.parameters[-1] - 2.) < 1e-10)

def test_fit_custom_model(self):

# Define a custom variogram and run the fit
@variogram
def sum_spherical(h, r1, c1, r2, c2, b1, b2):
return spherical(h, r1, c1, b1) + spherical(h, r2, c2, b2)

V = Variogram(self.c, self.v, use_nugget=True, model=sum_spherical)

# Check that 6 parameters were found
assert len(V.cof) == 6


class TestVariogramQualityMeasures(unittest.TestCase):
def setUp(self):
Expand Down
Loading