Skip to content

Commit

Permalink
Simplify ContextMeta
Browse files Browse the repository at this point in the history
We only keep the __call__ method, which is necessary to keep the
model context itself active during that model's __init__.
  • Loading branch information
thomasaarholt committed Aug 23, 2024
1 parent 6c8202c commit d203e0f
Showing 1 changed file with 12 additions and 52 deletions.
64 changes: 12 additions & 52 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import warnings

from collections.abc import Iterable, Sequence
from sys import modules
from typing import (
Literal,
cast,
Expand Down Expand Up @@ -120,57 +119,6 @@ def parent_context(self) -> Model | None:
MODEL_MANAGER = ModelManager()


class ContextMeta(type):
"""Functionality for objects that put themselves in a context using
the `with` statement.
"""

# FIXME: is there a more elegant way to automatically add methods to the class that
# are instance methods instead of class methods?
def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs):
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
if context_class is not None:
cls._context_class = context_class
super().__init__(name, bases, nmspc)

# the following complex property accessor is necessary because the
# context_class may not have been created at the point it is
# specified, so the context_class may be a class *name* rather
# than a class.
@property
def context_class(cls) -> type:
def resolve_type(c: type | str) -> type:
if isinstance(c, str):
c = getattr(modules[cls.__module__], c)
if isinstance(c, type):
return c
raise ValueError(f"Cannot resolve context class {c}")

assert cls is not None
if isinstance(cls._context_class, str):
cls._context_class = resolve_type(cls._context_class)
if not isinstance(cls._context_class, str | type):
raise ValueError(
f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type"
)
return cls._context_class

# Inherit context class from parent
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.context_class = super().context_class

# Initialize object in its own context...
# Merged from InitContextMeta in the original.
def __call__(cls, *args, **kwargs):
# We type hint Model here so type checkers understand that Model is a context manager.
# This metaclass is only used for Model, so this is safe to do. See #6809 for more info.
instance: Model = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance


def modelcontext(model: Model | None) -> Model:
"""
Return the given model or, if none was supplied, try to find one in
Expand Down Expand Up @@ -343,6 +291,18 @@ def profile(self):
return self._pytensor_function.profile


class ContextMeta(type):
"""A metaclass in order to apply a model's context during `Model.__init__``."""

# We want the Model's context to be active during __init__. In order for this
# to apply to subclasses of Model as well, we need to use a metaclass.
def __call__(cls: type[Model], *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs) # type: ignore
with instance: # applies context
instance.__init__(*args, **kwargs)
return instance


class Model(WithMemoization, metaclass=ContextMeta):
"""Encapsulates the variables and likelihood factors of a model.
Expand Down

0 comments on commit d203e0f

Please sign in to comment.