diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4b26bb51c8..fec0d7c145 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -246,7 +246,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: else: varnames = ", ".join( [ - get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name + get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[arg-type, misc] for v in s.vars ] ) diff --git a/pymc/util.py b/pymc/util.py index 6d4b1bf4de..42835915f6 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -14,9 +14,9 @@ import warnings -from collections.abc import Sequence +from collections.abc import Callable, Iterable, Sequence from copy import deepcopy -from typing import NewType, cast +from typing import Any, NewType, TypeVar, cast import arviz import cloudpickle @@ -32,8 +32,10 @@ from pymc.exceptions import BlockModelAccessError +T = TypeVar("T") -def __getattr__(name): + +def __getattr__(name: str) -> Callable: if name == "dataset_to_point_list": warnings.warn( f"{name} has been moved to backends.arviz. Importing from util will fail in a future release.", @@ -160,8 +162,8 @@ def tree_contains(self, item): return dict.__contains__(self, item) -def get_transformed_name(name, transform): - r""" +def get_transformed_name(name: str, transform) -> str: + """ Consistent way of transforming names. Parameters @@ -179,8 +181,8 @@ def get_transformed_name(name, transform): return f"{name}_{transform.name}__" -def is_transformed_name(name): - r""" +def is_transformed_name(name: str) -> bool: + """ Quickly check if a name was transformed with `get_transformed_name`. Parameters @@ -196,8 +198,8 @@ def is_transformed_name(name): return name.endswith("__") and name.count("_") >= 3 -def get_untransformed_name(name): - r""" +def get_untransformed_name(name: str) -> str: + """ Undo transformation in `get_transformed_name`. Throws ValueError if name wasn't transformed. Parameters @@ -215,8 +217,13 @@ def get_untransformed_name(name): return "_".join(name.split("_")[:-3]) -def get_default_varnames(var_iterator, include_transformed): - r"""Extract default varnames from a trace. +VarOrVarName = TypeVar("VarOrVarName", Variable, str) + + +def get_default_varnames( + var_iterator: Iterable[VarOrVarName], include_transformed: bool +) -> list[VarOrVarName]: + """Extract default varnames from a trace. Parameters ---------- @@ -236,7 +243,7 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_var_name(var) -> VarName: +def get_var_name(var: VarOrVarName) -> VarName: """Get an appropriate, plain variable name for a variable.""" return VarName(str(getattr(var, "name", var))) @@ -280,7 +287,7 @@ def chains_and_samples(data: xarray.Dataset | arviz.InferenceData) -> tuple[int, return nchains, nsamples -def hashable(a=None) -> int: +def hashable(a: Any = None) -> int: """ Hash many kinds of objects, including some that are unhashable through the builtin `hash` function. @@ -516,8 +523,8 @@ def _add_future_warning_tag(var) -> None: var.tag = new_tag -def makeiter(a): - if isinstance(a, tuple | list): +def makeiter(a: Sequence[T] | T) -> Sequence[T]: + if isinstance(a, Sequence): return a else: return [a]