Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some type hinting to help with migrating Distribution #7484

Merged
merged 8 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@
)
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorConstant
from pytensor.tensor.variable import TensorConstant, TensorVariable

from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import icdf
from pymc.logprob.basic import TensorLike, icdf
from pymc.pytensorf import normalize_rng_param

try:
Expand Down Expand Up @@ -152,7 +152,7 @@ class BoundedContinuous(Continuous):
"""Base class for bounded continuous distributions"""

# Indices of the arguments that define the lower and upper bounds of the distribution
bound_args_indices: list[int] | None = None
bound_args_indices: tuple[int | None, int | None] | None = None


@_default_transform.register(PositiveContinuous)
Expand Down Expand Up @@ -214,7 +214,9 @@ def assert_negative_support(var, label, distname, value=-1e-6):
return Assert(msg)(var, pt.all(pt.ge(var, 0.0)))


def get_tau_sigma(tau=None, sigma=None):
def get_tau_sigma(
tau: TensorLike | None = None, sigma: TensorLike | None = None
) -> tuple[TensorVariable, TensorVariable]:
r"""
Find precision and standard deviation. The link between the two
parameterizations is given by the inverse relationship:
Expand All @@ -241,13 +243,14 @@ def get_tau_sigma(tau=None, sigma=None):
sigma = pt.as_tensor_variable(1.0)
tau = pt.as_tensor_variable(1.0)
elif tau is None:
assert sigma is not None # Just for type checker
sigma = pt.as_tensor_variable(sigma)
# Keep tau negative, if sigma was negative, so that it will
# fail when used
tau = (sigma**-2.0) * pt.sign(sigma)
else:
tau = pt.as_tensor_variable(tau)
# Keep tau negative, if sigma was negative, so that it will
# Keep sigma negative, if tau was negative, so that it will
# fail when used
sigma = pt.abs(tau) ** -0.5 * pt.sign(tau)

Expand Down
10 changes: 7 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from abc import ABCMeta
from collections.abc import Callable, Sequence
from functools import singledispatch
from typing import TypeAlias
from typing import Any, TypeAlias

import numpy as np

Expand Down Expand Up @@ -423,8 +423,12 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->
class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

rv_op: [RandomVariable, SymbolicRandomVariable] = None
rv_type: MetaType = None
# rv_op and _type are set to None via the DistributionMeta.__new__
# if not specified as class attributes in subclasses of Distribution.
# rv_op can either be a class (see the Normal class) or a method
# (see the Censored class), both callable to return a TensorVariable.
rv_op: Any = None
rv_type: MetaType | None = None

def __new__(
cls,
Expand Down
Loading