Skip to content

Commit

Permalink
Expression: test evaluate method (#534)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause authored Oct 1, 2024
1 parent 660c45c commit 72d2fd9
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mesmer/mesmer_x/train_utils_mesmerx.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def evaluate(self, coefficients_values, inputs_values, forced_shape=None):
# - use broadcasting
# - can we avoid using exec & eval?
# - only parse the values once? (to avoid doing it repeatedly)
# - don't allow list of coefficients_values
# - require list of coefficients_values (similar to minimize)?
# - convert dataset to numpy arrays?

# Check 1: are all the coefficients provided?
if isinstance(coefficients_values, dict | xr.Dataset):
Expand All @@ -351,7 +352,7 @@ def evaluate(self, coefficients_values, inputs_values, forced_shape=None):
# Check 3: do the inputs have the same shape
shapes = {inputs_values[i].shape for i in self.inputs_list}
if len(shapes) > 1:
raise ValueError("Different shapes of inputs detected.")
raise ValueError("shapes of inputs must be equal")

# Evaluation 1: coefficients
for c in coefficients_values:
Expand Down
135 changes: 135 additions & 0 deletions tests/unit/test_mesmer_x_expression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest
import scipy as sp
import xarray as xr

import mesmer
from mesmer.mesmer_x import Expression

inf = float("inf")
Expand Down Expand Up @@ -213,3 +216,135 @@ def test_expression_covariate_substring():

coeffs_per_param = {"loc": ["c1", "c2", "c3"], "scale": ["c4"]}
assert expr.coefficients_dict == coeffs_per_param


def test_evaluate_missing_coefficient_dict():
    """Check that a dict lacking a coefficient makes ``evaluate`` raise ValueError."""
    expression = Expression("norm(loc=c1, scale=c2)", expr_name="name")

    # each case: (coefficients passed as a dict, expected error message)
    cases = (
        ({}, "Missing information for the coefficient: 'c1'"),
        ({"c1": 1}, "Missing information for the coefficient: 'c2'"),
    )

    for coefficients, message in cases:
        with pytest.raises(ValueError, match=message):
            expression.evaluate(coefficients, {})


def test_evaluate_missing_coefficient_dataset():
    """Check that a Dataset lacking a coefficient makes ``evaluate`` raise ValueError."""
    expression = Expression("norm(loc=c1, scale=c2)", expr_name="name")

    # each case: (keyword arguments for the coefficient Dataset, expected message)
    cases = (
        ({}, "Missing information for the coefficient: 'c1'"),
        ({"data_vars": {"c1": 1}}, "Missing information for the coefficient: 'c2'"),
    )

    for dataset_kwargs, message in cases:
        with pytest.raises(ValueError, match=message):
            expression.evaluate(xr.Dataset(**dataset_kwargs), {})


def test_evaluate_missing_coefficient_list():
    """Check that a coefficient list that is too short raises ValueError."""
    expression = Expression("norm(loc=c1, scale=c2)", expr_name="name")
    message = "Inconsistent information for the coefficients_values"

    # the expression has two coefficients (c1, c2); pass zero and one value
    for too_few in ([], [1]):
        with pytest.raises(ValueError, match=message):
            expression.evaluate(too_few, {})


def test_evaluate_missing_covariates_dict():
    """Check that a dict lacking an input (covariate) raises ValueError."""
    expression = Expression(
        "norm(loc=c1 * __T__, scale=c2 * __F__)", expr_name="name"
    )

    # each case: (inputs passed as a dict, expected error message)
    cases = (
        ({}, "Missing information for the input: 'T'"),
        ({"T": 1}, "Missing information for the input: 'F'"),
    )

    for inputs, message in cases:
        with pytest.raises(ValueError, match=message):
            expression.evaluate([1, 1], inputs)


def test_evaluate_missing_covariates_ds():
    """Check that a Dataset lacking an input (covariate) raises ValueError."""
    expression = Expression(
        "norm(loc=c1 * __T__, scale=c2 * __F__)", expr_name="name"
    )

    # each case: (keyword arguments for the inputs Dataset, expected message)
    cases = (
        ({}, "Missing information for the input: 'T'"),
        ({"data_vars": {"T": 1}}, "Missing information for the input: 'F'"),
    )

    for dataset_kwargs, message in cases:
        with pytest.raises(ValueError, match=message):
            expression.evaluate([1, 1], xr.Dataset(**dataset_kwargs))


def test_evaluate_covariates_wrong_shape():
    """Check that inputs of differing shapes raise ValueError (dict and Dataset)."""
    expression = Expression(
        "norm(loc=c1 * __T__, scale=c2 * __F__)", expr_name="name"
    )

    # T holds one element, F holds two -> shapes (1,) and (2,)
    mismatched = {"T": np.array([1]), "F": np.array([1, 1])}
    message = "shapes of inputs must be equal"

    # inputs given as a plain dict
    with pytest.raises(ValueError, match=message):
        expression.evaluate([1, 1], mismatched)

    # the same inputs wrapped in a Dataset
    with pytest.raises(ValueError, match=message):
        expression.evaluate([1, 1], xr.Dataset(data_vars=mismatched))


def test_evaluate_norm():
    """Evaluate a normal-distribution expression twice and check its parameters."""
    expression = Expression("norm(loc=c1 * __T__, scale=c2)", expr_name="name")

    def check_params(frozen, expected_params):
        # assert frozen params are equal
        mesmer.testing.assert_dict_allclose(frozen.kwds, expected_params)

        # NOTE: will write own function to return param values
        mesmer.testing.assert_dict_allclose(
            frozen.kwds, expression.parameters_values
        )

    first = expression.evaluate([1, 2], {"T": np.array([1, 2])})
    assert isinstance(first.dist, type(sp.stats.norm))
    check_params(
        first, {"loc": np.array([1, 2]), "scale": np.array([2.0, 2.0])}
    )

    # a second set of values
    second = expression.evaluate([2, 1], {"T": np.array([2, 5])})
    check_params(
        second, {"loc": np.array([4, 10]), "scale": np.array([1.0, 1.0])}
    )


def test_evaluate_norm_dataset():
    """Evaluate a normal-distribution expression from Dataset coefficients and inputs."""
    # NOTE: not sure if passing DataArray to scipy distribution is a good idea

    expression = Expression("norm(loc=c1 * __T__, scale=c2)", expr_name="name")

    coeffs = xr.Dataset(data_vars={"c1": 1, "c2": 2})
    covariates = xr.Dataset(data_vars={"T": ("x", np.array([1, 2]))})

    frozen = expression.evaluate(coeffs, covariates)

    assert isinstance(frozen.dist, type(sp.stats.norm))

    expected = {
        "loc": xr.DataArray([1, 2], dims="x"),
        "scale": xr.DataArray([2, 2], dims="x"),
    }

    # assert frozen params are equal
    mesmer.testing.assert_dict_allclose(frozen.kwds, expected)

    # NOTE: will write own function to return param values
    mesmer.testing.assert_dict_allclose(frozen.kwds, expression.parameters_values)

0 comments on commit 72d2fd9

Please sign in to comment.