Skip to content

Commit

Permalink
simplified model in TestModelCopy
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekermanjian committed Oct 2, 2024
1 parent fb00f85 commit d057a9d
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,16 +1768,15 @@ class TestModelCopy:
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_copy_model(self, copy_method) -> None:
with pm.Model() as simple_model:
error = pm.HalfNormal("error", 0.5)
alpha = pm.Normal("alpha", 0, 1)
pm.Normal("y", alpha, error)
pm.Normal("y")

copy_simple_model = copy_method(simple_model)

with simple_model:
simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42)

with copy_simple_model:
z = pm.Deterministic("z", copy_simple_model["y"] + 1)
copy_simple_model_prior_predictive = pm.sample_prior_predictive(
samples=1, random_seed=42
)
Expand All @@ -1787,17 +1786,11 @@ def test_copy_model(self, copy_method) -> None:
== copy_simple_model_prior_predictive["prior"]["y"].values
)

with copy_simple_model:
z = pm.Deterministic("z", copy_simple_model["alpha"] + 1)
copy_simple_model_prior_predictive = pm.sample_prior_predictive(
samples=1, random_seed=42
)

assert "z" in copy_simple_model.named_vars
assert "z" not in simple_model.named_vars
assert (
copy_simple_model_prior_predictive["prior"]["z"].values
== 1 + copy_simple_model_prior_predictive["prior"]["alpha"].values
== 1 + simple_model_prior_predictive["prior"]["y"].values
)

@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
Expand Down

0 comments on commit d057a9d

Please sign in to comment.