Use "unobserved" for imputed variable suffixes instead of "missing"
ricardoV94 committed Jun 30, 2023
1 parent e1fd175 commit d077ee2
Showing 5 changed files with 41 additions and 33 deletions.
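
For context, a minimal usage sketch (not part of the commit; the model below is purely illustrative) of how the renamed suffix surfaces to users: passing observed data that contains NaNs triggers PyMC's automatic imputation, and the imputed component is now registered under the "_unobserved" suffix instead of "_missing".

import numpy as np
import pymc as pm

data = np.array([1.0, np.nan, 3.0])  # NaN marks the entry to be imputed

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    # Observed data with missing entries triggers automatic imputation
    # (PyMC emits an ImputationWarning here).
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=data)

# After this commit the imputed free RV uses the "_unobserved" suffix
assert "y_unobserved" in model.named_vars   # previously "y_missing"
assert "y_observed" in model.named_vars     # observed component is unchanged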
2 changes: 1 addition & 1 deletion pymc/model.py
@@ -1423,7 +1423,7 @@ def make_obs_var(
self.observed_RVs.append(observed_rv)

# Register FreeRV corresponding to unobserved components
-self.register_rv(unobserved_rv, f"{name}_missing", transform=transform)
+self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)

# Register Deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
4 changes: 2 additions & 2 deletions tests/backends/test_arviz.py
@@ -338,10 +338,10 @@ def test_missing_data_model(self):
)

# make sure that data is really missing
-assert "y_missing" in model.named_vars
+assert "y_unobserved" in model.named_vars

test_dict = {
-"posterior": ["x", "y_missing"],
+"posterior": ["x", "y_unobserved"],
"observed_data": ["y_observed"],
"log_likelihood": ["y_observed"],
}
56 changes: 32 additions & 24 deletions tests/test_model.py
@@ -357,13 +357,13 @@ def test_missing_data(self):
gf = m.logp_dlogp_function()
gf._extra_are_set = True

-assert m["x2_missing"].type == gf._extra_vars_shared["x2_missing"].type
+assert m["x2_unobserved"].type == gf._extra_vars_shared["x2_unobserved"].type

# The dtype of the merged observed/missing deterministic should match the RV dtype
assert m.deterministics[0].type.dtype == x2.type.dtype

point = m.initial_point(random_seed=None).copy()
-del point["x2_missing"]
+del point["x2_unobserved"]

res = [gf(DictToArrayBijection.map(Point(point, model=m))) for i in range(5)]

@@ -1221,7 +1221,7 @@ def test_missing_basic(self, missing_data):
with pytest.warns(ImputationWarning):
_ = pm.Normal("y", x, 1, observed=missing_data)

-assert "y_missing" in model.named_vars
+assert "y_unobserved" in model.named_vars

test_point = model.initial_point()
assert not np.isnan(model.compile_logp()(test_point))
@@ -1238,7 +1238,7 @@ def test_missing_with_predictors(self):
with pytest.warns(ImputationWarning):
y = pm.Normal("y", x * predictors, 1, observed=data)

-assert "y_missing" in model.named_vars
+assert "y_unobserved" in model.named_vars

test_point = model.initial_point()
assert not np.isnan(model.compile_logp()(test_point))
@@ -1278,17 +1278,19 @@ def test_interval_missing_observations(self):
with pytest.warns(ImputationWarning):
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2)

-assert isinstance(model.rvs_to_transforms[model["theta1_missing"]], IntervalTransform)
+assert isinstance(
+model.rvs_to_transforms[model["theta1_unobserved"]], IntervalTransform
+)
assert model.rvs_to_transforms[model["theta1_observed"]] is None

prior_trace = pm.sample_prior_predictive(random_seed=rng, return_inferencedata=False)
assert set(prior_trace.keys()) == {
"theta1",
"theta1_observed",
-"theta1_missing",
+"theta1_unobserved",
"theta2",
"theta2_observed",
-"theta2_missing",
+"theta2_unobserved",
}

# Make sure the observed + missing combined deterministics have the
@@ -1303,14 +1305,16 @@ def test_interval_missing_observations(self):
# Make sure the missing parts of the combined deterministic matches the
# sampled missing and observed variable values
assert (
-np.mean(prior_trace["theta1"][:, obs1.mask] - prior_trace["theta1_missing"]) == 0.0
+np.mean(prior_trace["theta1"][:, obs1.mask] - prior_trace["theta1_unobserved"])
+== 0.0
)
assert (
np.mean(prior_trace["theta1"][:, ~obs1.mask] - prior_trace["theta1_observed"])
== 0.0
)
assert (
-np.mean(prior_trace["theta2"][:, obs2.mask] - prior_trace["theta2_missing"]) == 0.0
+np.mean(prior_trace["theta2"][:, obs2.mask] - prior_trace["theta2_unobserved"])
+== 0.0
)
assert (
np.mean(prior_trace["theta2"][:, ~obs2.mask] - prior_trace["theta2_observed"])
@@ -1326,18 +1330,22 @@ def test_interval_missing_observations(self):
)
assert set(trace.varnames) == {
"theta1",
-"theta1_missing",
-"theta1_missing_interval__",
+"theta1_unobserved",
+"theta1_unobserved_interval__",
"theta2",
-"theta2_missing",
+"theta2_unobserved",
}

# Make sure that the missing values are newly generated samples and that
# the observed and deterministic match
-assert np.all(0 < trace["theta1_missing"].mean(0))
-assert np.all(0 < trace["theta2_missing"].mean(0))
-assert np.isclose(np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_missing"]), 0)
-assert np.isclose(np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_missing"]), 0)
+assert np.all(0 < trace["theta1_unobserved"].mean(0))
+assert np.all(0 < trace["theta2_unobserved"].mean(0))
+assert np.isclose(
+np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_unobserved"]), 0
+)
+assert np.isclose(
+np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_unobserved"]), 0
+)

# Make sure that the observed values are unchanged
assert np.allclose(np.var(trace["theta1"][:, ~obs1.mask], 0), 0.0)
@@ -1394,7 +1402,7 @@ def test_missing_logp2(self):
"theta2", mu=theta1, observed=np.array([np.nan, np.nan, 2, np.nan, 4])
)
m_missing_logp = m_missing.compile_logp()(
-{"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]}
+{"theta1_unobserved": [2, 4], "theta2_unobserved": [0, 1, 3]}
)

assert m_logp == m_missing_logp
@@ -1407,15 +1415,15 @@ def test_missing_multivariate_separable(self):
a=[1, 2, 3],
observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]),
)
-assert (m_miss["x_missing"].owner.op, pm.Dirichlet)
+assert (m_miss["x_unobserved"].owner.op, pm.Dirichlet)
assert (m_miss["x_observed"].owner.op, pm.Dirichlet)

with pm.Model() as m_unobs:
x = pm.Dirichlet("x", a=[1, 2, 3], shape=(1, 3))

inp_vals = simplex.forward(np.array([[0.3, 0.3, 0.4]])).eval()
np.testing.assert_allclose(
-m_miss.compile_logp(jacobian=False)({"x_missing_simplex__": inp_vals}),
+m_miss.compile_logp(jacobian=False)({"x_unobserved_simplex__": inp_vals}),
m_unobs.compile_logp(jacobian=False)({"x_simplex__": inp_vals}) * 2,
)

@@ -1428,12 +1436,12 @@ def test_missing_multivariate_unseparable(self):
observed=np.array([[0.3, 0.3, np.nan], [np.nan, np.nan, 0.4]]),
)

-assert isinstance(m_miss["x_missing"].owner.op, PartialObservedRV)
+assert isinstance(m_miss["x_unobserved"].owner.op, PartialObservedRV)
assert isinstance(m_miss["x_observed"].owner.op, PartialObservedRV)

inp_values = np.array([0.3, 0.3, 0.4])
np.testing.assert_allclose(
-m_miss.compile_logp()({"x_missing": [0.4, 0.3, 0.3]}),
+m_miss.compile_logp()({"x_unobserved": [0.4, 0.3, 0.3]}),
st.dirichlet.logpdf(inp_values, [1, 2, 3]) * 2,
)

@@ -1451,7 +1459,7 @@ def test_missing_vector_parameter(self):
assert np.all(x_draws[:, 0] < 0)
assert np.all(x_draws[:, 1] > 0)
assert np.isclose(
-m.compile_logp()({"x_missing": np.array([-10, 10, -10, 10])}),
+m.compile_logp()({"x_unobserved": np.array([-10, 10, -10, 10])}),
st.norm(scale=0.1).logpdf(0) * 6,
)

@@ -1470,7 +1478,7 @@ def test_missing_symmetric(self):
x_obs_rv = m["x_observed"]
x_obs_vv = m.rvs_to_values[x_obs_rv]

-x_unobs_rv = m["x_missing"]
+x_unobs_rv = m["x_unobserved"]
x_unobs_vv = m.rvs_to_values[x_unobs_rv]

logp = transformed_conditional_logp(
@@ -1506,7 +1514,7 @@ def test_symbolic_random_variable(self):
observed=data,
)
np.testing.assert_almost_equal(
-model.compile_logp()({"x_missing": [0] * 3}),
+model.compile_logp()({"x_unobserved": [0] * 3}),
st.norm.logcdf(0) * 10,
)

6 changes: 3 additions & 3 deletions tests/test_model_graph.py
@@ -145,13 +145,13 @@ def model_with_imputations():

compute_graph = {
"a": set(),
-"L_missing": {"a"},
+"L_unobserved": {"a"},
"L_observed": {"a"},
-"L": {"L_missing", "L_observed"},
+"L": {"L_unobserved", "L_observed"},
}
plates = {
"": {"a"},
-"2": {"L_missing"},
+"2": {"L_unobserved"},
"10": {"L_observed"},
"12": {"L"},
}
6 changes: 3 additions & 3 deletions tests/tuning/test_starting.py
@@ -146,9 +146,9 @@ def test_find_MAP_issue_4488():
y = pm.Deterministic("y", x + 1)
map_estimate = find_MAP()

-assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
-np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-4, atol=1e-4)
-np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
+assert not set.difference({"x_unobserved", "x_unobserved_log__", "y"}, set(map_estimate.keys()))
+np.testing.assert_allclose(map_estimate["x_unobserved"], 0.2, rtol=1e-4, atol=1e-4)
+np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_unobserved"][0] + 1])


def test_find_MAP_warning_non_free_RVs():
