Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Model __new__ and metaclass #7473

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

thomasaarholt
Copy link
Contributor

@thomasaarholt thomasaarholt commented Aug 23, 2024

Description

This PR is a refactor that improves the code that creates models and keeps track of the model contexts that are active at a given moment.

I have previously done some work (#6809) on type hinting the Model class. It was a bit difficult to understand what was going on, and the code was lacking documentation that would help further development. In particular, the code creating the Model class and instance, and keeping track of active contexts was rather hacky.

I am on parental leave, which (maybe ironically) gives me the time to take a look behind the scenes and hopefully make some improvements that give back to the community. I do recognize that this PR comes a bit "unannounced", and since it doesn't fix any major issues or introduce hot features, I don't have any expectations for it to be reviewed soon (or even at all).

Each commit is a logical change with a reasonably descriptive text.

This passes all tests in tests/model.

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7473.org.readthedocs.build/en/7473/

@thomasaarholt
Copy link
Contributor Author

thomasaarholt commented Aug 23, 2024

The mypy failure is because my code change allows mypy to infer the type of model in pymc/data.py:446 as Model rather than Unknown, which in turns reveals that values isn't of type Unknown on line 450. So I've just exposed something that was wasn't typed correctly.

All test failures except one of the Ubuntu ones can be attributed to this recent change in PyTensor. I am not sure what caused the remaining one, but I'm pretty confident it is unrelated.

Edit: Rebased on main branch and added one more set of type hints.

Copy link

codecov bot commented Aug 28, 2024

Codecov Report

Attention: Patch coverage is 97.50000% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.48%. Comparing base (cdcdb58) to head (74665c8).
Report is 32 commits behind head on main.

Files with missing lines Patch % Lines
pymc/data.py 88.88% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7473      +/-   ##
==========================================
+ Coverage   92.44%   92.48%   +0.04%     
==========================================
  Files         103      103              
  Lines       17119    17087      -32     
==========================================
- Hits        15825    15803      -22     
+ Misses       1294     1284      -10     
Files with missing lines Coverage Δ
pymc/model/core.py 92.79% <100.00%> (+0.96%) ⬆️
pymc/data.py 89.22% <88.88%> (+0.13%) ⬆️

@twiecki
Copy link
Member

twiecki commented Sep 6, 2024

Thanks @thomasaarholt, there are still mypy failures, do you have a handle on them or would you like some help?

@thomasaarholt
Copy link
Contributor Author

@twiecki Thanks! The remaining one is one that I'm not sure on how best to handle:
From the mypy github action:

pymc/data.py:439: error: Incompatible types in assignment (expression has type "Sequence[str | None]", variable has type "Sequence[str] | None")

This stems from the determine_coords function having a return signature for dim that is Sequence[str | None]. But the dim type hint from higher up in the def Data constructor only allows a sequence of Sequence[str]. The sequence can be of None inner type here.

Help much appreciated!

@thomasaarholt
Copy link
Contributor Author

There! @twiecki wanna take a look? I went through the logic related to the above failing mypy error and think I found a good resolution. See commit description of the last two commits.

The failing test is unrelated.

@twiecki
Copy link
Member

twiecki commented Oct 4, 2024

Any idea about that failing test?

================================== FAILURES ===================================
_________________ TestMarginalVsLatent.testLatentMultioutput __________________

self = <tests.gp.test_gp.TestMarginalVsLatent object at 0x0000029EDDB84820>

    def testLatentMultioutput(self):
        n_outputs = 2
        X = np.random.randn(20, 3)
        y = np.random.randn(n_outputs, 20)
        Xnew = np.random.randn(30, 3)
        pnew = np.random.randn(n_outputs, 30)
    
        with pm.Model() as latent_model:
            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
            mean_func = pm.gp.mean.Constant(0.5)
            latent_gp = pm.gp.Latent(mean_func=mean_func, cov_func=cov_func)
            latent_f = latent_gp.prior("f", X, n_outputs=n_outputs, reparameterize=True)
            latent_p = latent_gp.conditional("p", Xnew)
    
        with pm.Model() as marginal_model:
            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
            mean_func = pm.gp.mean.Constant(0.5)
            marginal_gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
            marginal_f = marginal_gp.marginal_likelihood("f", X, y, sigma=0.0)
            marginal_p = marginal_gp.conditional("p", Xnew)
    
        assert tuple(latent_f.shape.eval()) == tuple(marginal_f.shape.eval()) == y.shape
        assert tuple(latent_p.shape.eval()) == tuple(marginal_p.shape.eval()) == pnew.shape
    
        chol = np.linalg.cholesky(cov_func(X).eval())
        v = np.linalg.solve(chol, (y - 0.5).T)
        A = np.linalg.solve(chol, cov_func(X, Xnew).eval()).T
        mu_cond = mean_func(Xnew).eval() + (A @ v).T
        cov_cond = cov_func(Xnew, Xnew).eval() - A @ A.T
    
        with pm.Model() as numpy_model:
            numpy_p = pm.MvNormal.dist(mu=pt.as_tensor(mu_cond), cov=pt.as_tensor(cov_cond))
    
        latent_rv_logp = pm.logp(latent_p, pnew)
        marginal_rv_logp = pm.logp(marginal_p, pnew)
        numpy_rv_logp = pm.logp(numpy_p, pnew)
    
        assert (
            latent_rv_logp.shape.eval()
            == marginal_rv_logp.shape.eval()
            == numpy_rv_logp.shape.eval()
        )
    
>       npt.assert_allclose(latent_rv_logp.eval(), marginal_rv_logp.eval(), atol=5)

tests\gp\test_gp.py:412: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0x0000029EE75B9AB0>, array([-41.92778875, -45.52605[201](https://github.com/pymc-devs/pymc/actions/runs/11177316239/job/31072575657?pr=7473#step:7:202)]), array([-40.07459855, -65.84149075]))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=5', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=5
E           
E           Mismatched elements: 1 / 2 (50%)
E           Max absolute difference: 20.31543874
E           Max relative difference: 0.30855071
E            x: array([-41.927789, -45.526052])
E            y: array([-40.074599, -65.841491])

C:\Miniconda3\envs\pymc-test\lib\contextlib.py:79: AssertionError

@thomasaarholt
Copy link
Contributor Author

I assume that if we reran it (maybe you can trigger that?) it would pass. Seeing as it fails only on windows and not Linux/mac, I assumed it was a result of variance due to random sampling.

@ricardoV94
Copy link
Member

I restarted it

@thomasaarholt
Copy link
Contributor Author

Ran passing everything. Ready for review!

pymc/model/core.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

@thomasaarholt the PR looks good, but since it touches on such as fundamental functionality I've asked for further reviews.

You have some conflicts that need to be resolved as well.

Thanks for the work so far!

get_context returns an instance of a Model, not a ContextMeta object
We don't need the typevar, since we don't use it for anything special
All of these are supported on python>=3.9.
We create a global instance of it within this module, which is similar
to how it worked before, where a `context_class` attribute was attached
to the Model class.

We inherit from threading.local to ensure thread safety when working
with models on multiple threads. See pymc-devs#1552 for the reasoning. This is
already tested in `test_thread_safety`.
UNSET is the instance of the _UnsetType type.
We should be typing the latter here.
We use the new ModelManager.parent_context property to reliably set any
parent context, or else set it to None.
We set this directly on the class as a classmethod, which is clearer
than going via the metaclass.
The original function does not behave as I expected.
In the following example I expected that it would return only the final
model, not root.

This method is not used anywhere in the pymc codebase, so I have dropped
it from the codebase. I originally included the following code to replace
it, but since it is not used anyway, it is better to remove it.

```python`
@classmethod
def get_contexts(cls) -> list[Model]:
    """Return a list of the currently active model contexts."""
    return MODEL_MANAGER.active_contexts
```

Example for testing behaviour in current main branch:
```python
import pymc as pm

with pm.Model(name="root") as root:
    print([c.name for c in pm.Model.get_contexts()])
    with pm.Model(name="first") as first:
        print([c.name for c in pm.Model.get_contexts()])
    with pm.Model(name="m_with_model_None", model=None) as m_with_model_None:
        # This one doesn't make much sense:
        print([c.name for c in pm.Model.get_contexts()])
```
We only keep the __call__ method, which is necessary to keep the
model context itself active during that model's __init__.
In pymc/distributions/distribution.py, this change allows the type
checker to infer that `rv_out` can only be a TensorVariable.

Thanks to @ricardoV94 for type hint on rv_var.
I originally tried numpy's ArrayLike, replacing Sequence entirely, but then I realized
that ArrayLike also allows non-sequences like integers and floats.

I am not certain if `values="a string"` should be legal. With the type hint sequence, it is.
Might be more accurate, but verbose to use `list | tuple | set | np.ndarray | None`.
…unction

We don't want to allow the user to pass a `dims=[None, None]` to our function, but current behaviour
set `dims=[None] * N` at the end of `determine_coords`.

To handle this, I created a `new_dims` with a larger type scope which matches
the return type of `dims` in `determine_coords`.

Then I did the same within def Data to support this new type hint.
The only case where dims=[None, ...] is when the user has passed dims=None. Since the user passed dims=None,
they shouldn't be expecting any coords to match that dimension. Thus we don't need to try to add any
more coords to the model.
Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work @thomasaarholt. The original code was very hard to read, and yours is much better now. I think that your refactor could be extended to a few other places. I left comments in the relevant parts.

@@ -218,9 +218,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
def determine_coords(
model,
value: pd.DataFrame | pd.Series | xr.DataArray,
dims: Sequence[str | None] | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this ever change? I thought that some variables could have dims=[None, "some_dim_name"]. If this is still the case, then the type hint needs to left as it was before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right. I guess someone could pass a data cube with some "batch" or otherwise uninteresting dimension? I don't really see why you would ever do that, but I'm not gonna stop them. I'll change it back.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, do you mind giving an example? I observe that in main branch, both pm.Data and pm.ConstantData take dims: Sequence[str] | None. Are these incorrect and Sequence[str] should be Sequence[str | None]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we allow mixed None / defined dims anymore

coords: dict[str, Sequence | np.ndarray] | None = None,
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. If variables can still have dims=[None, "some_dim_name"] then the old type hint is correct and the new one is wrong.

Comment on lines +498 to +508
@classmethod
def get_context(
cls, error_if_none: bool = True, allow_block_model_access: bool = False
) -> Model | None:
model = MODEL_MANAGER.current_context
if isinstance(model, BlockModelAccess) and not allow_block_model_access:
raise BlockModelAccessError(model.error_msg_on_access)
if model is None and error_if_none:
raise TypeError("No model on context stack")
return model

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit at a loss here. Why do we even need this class method? Can't it be replaced by a simple call to modelcontext(None)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on what you are suggesting?

In main branch, modelcontext always returns a model (or raises an error). I can modify modelcontext to take error_if_none: bool = True, allow_block_model_access: bool = False as arguments and include the logic you highlight above, as well as change the return type from Model to Model | None.

Then I'll need to add @overloads for those two arguments to maintain correct typing, since there are over 50 locations modelcontext is used. I'm happy to do this, but I want to be sure this is what you're intending.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are relevant instances/test where get_context are currently used:

assert pm.Model.get_context(error_if_none=False) is None

model = Model.get_context(error_if_none=False, allow_block_model_access=True)

pymc/model/core.py Show resolved Hide resolved
@thomasaarholt
Copy link
Contributor Author

@lucianopaz you tagged the wrong person ;)

Please see my reply about the get_context classmethod. Happy to get rid of more stuff (there are a few things more that are unused in model/core.py), but I'll prefer to do that in an other PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants