diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d6deacdb6e..f02afdc623 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -65,6 +65,7 @@ jobs: - | tests/distributions/test_continuous.py tests/distributions/test_multivariate.py + tests/distributions/moments/test_means.py - | tests/distributions/test_censored.py diff --git a/pymc/distributions/moments/__init__.py b/pymc/distributions/moments/__init__.py new file mode 100644 index 0000000000..8aafdb37a2 --- /dev/null +++ b/pymc/distributions/moments/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Moments dispatchers for pymc random variables.""" + +from pymc.distributions.moments.means import mean + +__all__ = ["mean"] diff --git a/pymc/distributions/moments/means.py b/pymc/distributions/moments/means.py new file mode 100644 index 0000000000..f11ca6a8df --- /dev/null +++ b/pymc/distributions/moments/means.py @@ -0,0 +1,468 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mean dispatcher for pymc random variables.""" + +from functools import singledispatch + +import numpy as np + +from pytensor import tensor as pt +from pytensor.tensor.math import tanh +from pytensor.tensor.random.basic import ( + BernoulliRV, + BetaBinomialRV, + BetaRV, + BinomialRV, + CategoricalRV, + CauchyRV, + DirichletRV, + ExponentialRV, + GammaRV, + GeometricRV, + GumbelRV, + HalfCauchyRV, + HalfNormalRV, + HyperGeometricRV, + InvGammaRV, + LaplaceRV, + LogisticRV, + LogNormalRV, + MultinomialRV, + MvNormalRV, + NegBinomialRV, + NormalRV, + ParetoRV, + PoissonRV, + StudentTRV, + TriangularRV, + UniformRV, + VonMisesRV, +) +from pytensor.tensor.var import TensorVariable + +from pymc.distributions.continuous import ( + AsymmetricLaplaceRV, + BetaClippedRV, + ExGaussianRV, + FlatRV, + HalfFlatRV, + HalfStudentTRV, + KumaraswamyRV, + LogitNormalRV, + MoyalRV, + PolyaGammaRV, + RiceRV, + SkewNormalRV, + SkewStudentTRV, + WaldRV, + WeibullBetaRV, +) +from pymc.distributions.discrete import DiscreteUniformRV +from pymc.distributions.distribution import DiracDeltaRV +from pymc.distributions.mixture import MarginalMixtureRV +from pymc.distributions.moments.utils import UndefinedMomentException +from pymc.distributions.multivariate import ( + CARRV, + DirichletMultinomialRV, + KroneckerNormalRV, + LKJCorrRV, + MatrixNormalRV, + MvStudentTRV, + StickBreakingWeightsRV, + _LKJCholeskyCovRV, +) +from pymc.distributions.shape_utils import rv_size_is_none + +__all__ = ["mean"] + + +@singledispatch +def _mean(op, rv, *rv_inputs) -> TensorVariable: + raise NotImplementedError(f"Variable {rv} of type {op} has no mean implementation.") + + +def mean(rv: TensorVariable) -> TensorVariable: + """Compute the expected value of a random variable. + + The only parameter to this function is the RandomVariable + for which the value is to be derived. + """ + return _mean(rv.owner.op, rv, *rv.owner.inputs) + + +def maybe_resize(a: TensorVariable, size) -> TensorVariable: + if not rv_size_is_none(size): + a = pt.full(size, a) + return a + + +@_mean.register(BernoulliRV) +def bernoulli_mean(op, rv, rng, size, p): + return maybe_resize(p, size) + + +@_mean.register(BetaBinomialRV) +def betabinomial_mean(op, rv, rng, size, n, alpha, beta): + return maybe_resize((n * alpha) / (alpha + beta), size) + + +@_mean.register(BetaClippedRV) +def beta_clipped_mean(op, rv, rng, size, alpha, beta): + return maybe_resize(alpha / (alpha + beta), size) + + +@_mean.register(BetaRV) +def beta_mean(op, rv, rng, size, alpha, beta): + return maybe_resize(alpha / (alpha + beta), size) + + +@_mean.register(BinomialRV) +def binomial_mean(op, rv, rng, size, n, p): + return maybe_resize(n * p, size) + + +@_mean.register(CauchyRV) +def cauchy_mean(op, rv, rng, size, alpha, beta): + raise UndefinedMomentException("The mean of the Cauchy distribution is undefined") + + +@_mean.register(DirichletRV) +def dirichlet_mean(op, rv, rng, size, a): + norm_constant = pt.sum(a, axis=-1)[..., None] + mean = a / norm_constant + if not rv_size_is_none(size): + mean = pt.full(pt.concatenate([size, [a.shape[-1]]]), mean) + return mean + + +@_mean.register(ExponentialRV) +def exponential_mean(op, rv, rng, size, mu): + return maybe_resize(mu, size) + + +@_mean.register(FlatRV) +def flat_mean(op, rv, *args): + raise UndefinedMomentException("The mean of the Flat distribution is undefined") + + +@_mean.register(GammaRV) +def gamma_mean(op, rv, rng, size, alpha, inv_beta): + # The pytensor `GammaRV` `Op` inverts the `beta` parameter itself + return maybe_resize(alpha * inv_beta, size) + + +@_mean.register(GeometricRV) +def geometric_mean(op, rv, rng, size, p): + return maybe_resize(1.0 / p, size) + + +@_mean.register(GumbelRV) +def gumbel_mean(op, rv, rng, size, mu, beta): + return maybe_resize(mu + beta * np.euler_gamma, size) + + +@_mean.register(HalfCauchyRV) +def halfcauchy_mean(op, rv, rng, size, loc, beta): + raise UndefinedMomentException("The mean of the HalfCauchy distribution is undefined") + + +@_mean.register(HalfFlatRV) +def halfflat_mean(op, rv, *args): + raise UndefinedMomentException("The mean of the HalfFlat distribution is undefined") + + +@_mean.register(HalfNormalRV) +def halfnormal_mean(op, rv, rng, size, loc, sigma): + _, sigma = pt.broadcast_arrays(loc, sigma) + return maybe_resize(sigma * pt.sqrt(2 / np.pi), size) + + +@_mean.register(HyperGeometricRV) +def hypergeometric_mean(op, rv, rng, size, good, bad, n): + N, k = good + bad, good + return maybe_resize(n * k / N, size) + + +@_mean.register(InvGammaRV) +def invgamma_mean(op, rv, rng, size, alpha, beta): + return maybe_resize(pt.switch(alpha > 1, beta / (alpha - 1.0), np.nan), size) + + +@_mean.register(LaplaceRV) +def laplace_mean(op, rv, rng, size, mu, b): + return maybe_resize(pt.broadcast_arrays(mu, b)[0], size) + + +@_mean.register(LogisticRV) +def logistic_mean(op, rv, rng, size, mu, s): + return maybe_resize(pt.broadcast_arrays(mu, s)[0], size) + + +@_mean.register(LogitNormalRV) +def logitnormal_mean(op, rv, rng, size, mu, sigma): + raise UndefinedMomentException("The mean of the LogitNormal distribution is undefined") + + +@_mean.register(LogNormalRV) +def lognormal_mean(op, rv, rng, size, mu, sigma): + return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size) + + +@_mean.register(MultinomialRV) +def multinomial_mean(op, rv, rng, size, n, p): + n = pt.shape_padright(n) + mean = n * p + if not rv_size_is_none(size): + output_size = pt.concatenate([size, [p.shape[-1]]]) + mean = pt.full(output_size, mean) + return mean + + +@_mean.register(MvNormalRV) +def mvnormal_mean(op, rv, rng, size, mu, cov): + mean = mu + if not rv_size_is_none(size): + mean_size = pt.concatenate([size, [mu.shape[-1]]]) + mean = pt.full(mean_size, mu) + return mean + + +@_mean.register(NegBinomialRV) +def negative_binomial_mean(op, rv, rng, size, n, p): + return maybe_resize(n * (1 - p) / p, size) + + +@_mean.register(NormalRV) +def normal_mean(op, rv, rng, size, mu, sigma): + return maybe_resize(pt.broadcast_arrays(mu, sigma)[0], size) + + +@_mean.register(ParetoRV) +def pareto_mean(op, rv, rng, size, alpha, m): + return maybe_resize(pt.switch(alpha > 1, alpha * m / (alpha - 1), np.nan), size) + + +@_mean.register(PoissonRV) +def poisson_mean(op, rv, rng, size, mu): + return maybe_resize(mu, size) + + +@_mean.register(TriangularRV) +def triangular_mean(op, rv, rng, size, lower, c, upper): + return maybe_resize((lower + upper + c) / 3, size) + + +@_mean.register(UniformRV) +def uniform_mean(op, rv, rng, size, lower, upper): + return maybe_resize((lower + upper) / 2, size) + + +@_mean.register(VonMisesRV) +def vonmisses_mean(op, rv, rng, size, mu, kappa): + return maybe_resize(pt.broadcast_arrays(mu, kappa)[0], size) + + +@_mean.register(KumaraswamyRV) +def kumaraswamy_mean(op, rv, rng, size, a, b): + return maybe_resize( + pt.exp(pt.log(b) + pt.gammaln(1 + 1 / a) + pt.gammaln(b) - pt.gammaln(1 + 1 / a + b)), + size, + ) + + +@_mean.register(WaldRV) +def wald_mean(op, rv, rng, size, mu, lam, alpha): + return maybe_resize(pt.broadcast_arrays(mu, lam, alpha)[0], size) + + +@_mean.register(WeibullBetaRV) +def weibull_mean(op, rv, rng, size, alpha, beta): + return maybe_resize(beta * pt.gamma(1 + 1 / alpha), size) + + +@_mean.register(AsymmetricLaplaceRV) +def asymmetric_laplace_mean(op, rv, rng, size, b, kappa, mu): + return maybe_resize(mu - (kappa - 1 / kappa) / b, size) + + +@_mean.register(StudentTRV) +def studentt_mean(op, rv, rng, size, nu, mu, sigma): + return maybe_resize(pt.broadcast_arrays(mu, nu, sigma)[0], size) + + +@_mean.register(HalfStudentTRV) +def half_studentt_mean(op, rv, rng, size, nu, sigma): + return maybe_resize( + pt.switch( + nu > 1, + 2 + * sigma + * pt.sqrt(nu / np.pi) + * pt.exp(pt.gammaln(0.5 * (nu + 1)) - pt.gammaln(nu / 2) - pt.log(nu - 1)), + np.nan, + ), + size, + ) + + +@_mean.register(ExGaussianRV) +def exgaussian_mean(op, rv, rng, size, mu, nu, sigma): + mu, nu, _ = pt.broadcast_arrays(mu, nu, sigma) + return maybe_resize(mu + nu, size) + + +@_mean.register(SkewNormalRV) +def skew_normal_mean(op, rv, rng, size, mu, sigma, alpha): + return maybe_resize(mu + sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha**2) ** 0.5, size) + + +@_mean.register(SkewStudentTRV) +def skew_studentt_mean(op, rv, rng, size, a, b, mu, sigma): + a, b, mu, _ = pt.broadcast_arrays(a, b, mu, sigma) + Et = mu + (a - b) * pt.sqrt(a + b) * pt.gamma(a - 0.5) * pt.gamma(b - 0.5) / ( + 2 * pt.gamma(a) * pt.gamma(b) + ) + if not rv_size_is_none(size): + Et = pt.full(size, Et) + return Et + + +@_mean.register(RiceRV) +def rice_mean(op, rv, rng, size, nu, sigma): + nu_sigma_ratio = -(nu**2) / (2 * sigma**2) + return maybe_resize( + sigma + * np.sqrt(np.pi / 2) + * pt.exp(nu_sigma_ratio / 2) + * ( + (1 - nu_sigma_ratio) * pt.i0(-nu_sigma_ratio / 2) + - nu_sigma_ratio * pt.i1(-nu_sigma_ratio / 2) + ), + size, + ) + + +@_mean.register(MoyalRV) +def moyal_mean(op, rv, rng, size, mu, sigma): + return maybe_resize(mu + sigma * (np.euler_gamma + pt.log(2)), size) + + +@_mean.register(PolyaGammaRV) +def polya_gamma_mean(op, rv, rng, size, h, z): + return maybe_resize(pt.switch(pt.eq(z, 0), h / 4, tanh(z / 2) * (h / (2 * z))), size) + + +@_mean.register(CategoricalRV) +def categorical_mean(op, rv, *args): + raise UndefinedMomentException("The mean of the Categorical distribution is undefined") + + +@_mean.register(DiscreteUniformRV) +def discrete_uniform_mean(op, rv, rng, size, lower, upper): + return maybe_resize((upper + lower) / 2.0, size) + + +@_mean.register(DiracDeltaRV) +def dirac_delta_mean(op, rv, size, c): + return maybe_resize(c, size) + + +@_mean.register(MarginalMixtureRV) +def marginal_mixture_mean(op, rv, rng, weights, *components): + ndim_supp = components[0].owner.op.ndim_supp + weights = pt.shape_padright(weights, ndim_supp) + mix_axis = -ndim_supp - 1 + + if len(components) == 1: + mean_components = mean(components[0]) + + else: + mean_components = pt.stack( + [mean(component) for component in components], + axis=mix_axis, + ) + + return pt.sum(weights * mean_components, axis=mix_axis) + + +@_mean.register(MvStudentTRV) +def mvstudentt_mean(op, rv, rng, size, nu, mu, scale): + mean = mu + if not rv_size_is_none(size): + mean_size = pt.concatenate([size, [mu.shape[-1]]]) + mean = pt.full(mean_size, mean) + return mean + + +@_mean.register(DirichletMultinomialRV) +def dirichlet_multinomial_mean(op, rv, rng, size, n, a): + mean = pt.shape_padright(n) * a / pt.sum(a, axis=-1, keepdims=True) + if not rv_size_is_none(size): + output_size = pt.concatenate([size, [a.shape[-1]]]) + # We can't use pt.full because output_size is symbolic + mean, _ = pt.broadcast_arrays(mean, pt.zeros(size)[..., None]) + return mean + + +@_mean.register(_LKJCholeskyCovRV) +def lkj_cholesky_cov_mean(op, rv, rng, n, eta, sd_dist): + diag_idxs = (pt.cumsum(pt.arange(1, n + 1)) - 1).astype("int32") + mean = pt.zeros_like(rv) + mean = pt.set_subtensor(mean[..., diag_idxs], 1) + return mean + + +@_mean.register(LKJCorrRV) +def lkj_corr_mean(op, rv, rng, size, *args): + return pt.full_like(rv, pt.eye(rv.shape[-1])) + + +@_mean.register(MatrixNormalRV) +def matrix_normal_mean(op, rv, rng, size, mu, rowchol, colchol): + return pt.full_like(rv, mu) + + +@_mean.register(KroneckerNormalRV) +def kronecker_normal_mean(op, rv, rng, size, mu, covs, chols, evds): + mean = mu + if not rv_size_is_none(size): + mean_size = pt.concatenate([size, mu.shape]) + mean = pt.full(mean_size, mu) + return mean + + +@_mean.register(CARRV) +def car_mean(op, rv, rng, size, mu, W, alpha, tau, W_is_valid): + return pt.full_like(rv, mu) + + +@_mean.register(StickBreakingWeightsRV) +def stick_breaking_mean(op, rv, rng, size, alpha, K): + K = K.squeeze() + alpha = alpha[..., np.newaxis] + mean = (alpha / (1 + alpha)) ** pt.arange(K) + mean *= 1 / (1 + alpha) + mean = pt.concatenate([mean, (alpha / (1 + alpha)) ** K], axis=-1) + if not rv_size_is_none(size): + mean_size = pt.concatenate( + [ + size, + [ + K + 1, + ], + ] + ) + mean = pt.full(mean_size, mean) + return mean diff --git a/pymc/distributions/moments/utils.py b/pymc/distributions/moments/utils.py new file mode 100644 index 0000000000..7df5f45f19 --- /dev/null +++ b/pymc/distributions/moments/utils.py @@ -0,0 +1,15 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +class UndefinedMomentException(Exception): + pass diff --git a/tests/distributions/moments/__init__.py b/tests/distributions/moments/__init__.py new file mode 100644 index 0000000000..ae0da7db23 --- /dev/null +++ b/tests/distributions/moments/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/distributions/moments/test_means.py b/tests/distributions/moments/test_means.py new file mode 100644 index 0000000000..87edc090ab --- /dev/null +++ b/tests/distributions/moments/test_means.py @@ -0,0 +1,273 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from scipy.stats import ( + bernoulli, + beta, + betabinom, + binom, + chi2, + dirichlet, + expon, + exponnorm, + gamma, + geom, + gumbel_r, + halfnorm, + hypergeom, + invgamma, + invgauss, + jf_skew_t, + laplace, + laplace_asymmetric, + logistic, + lognorm, + matrix_normal, + moyal, + multinomial, + multivariate_normal, + multivariate_t, + nbinom, + norm, + pareto, + poisson, + rice, + skewnorm, + t, + triang, + uniform, + vonmises, + weibull_min, +) + +from pymc import ( + CAR, + AsymmetricLaplace, + Bernoulli, + Beta, + BetaBinomial, + Binomial, + Categorical, + Cauchy, + ChiSquared, + DiracDelta, + Dirichlet, + DirichletMultinomial, + DiscreteUniform, + ExGaussian, + Exponential, + Flat, + Gamma, + Geometric, + Gumbel, + HalfCauchy, + HalfFlat, + HalfNormal, + HalfStudentT, + HyperGeometric, + InverseGamma, + KroneckerNormal, + Kumaraswamy, + Laplace, + LKJCholeskyCov, + LKJCorr, + Logistic, + LogitNormal, + LogNormal, + MatrixNormal, + Moyal, + Multinomial, + MvNormal, + MvStudentT, + NegativeBinomial, + Normal, + Pareto, + Poisson, + PolyaGamma, + Rice, + SkewNormal, + SkewStudentT, + StickBreakingWeights, + StudentT, + Triangular, + Uniform, + VonMises, + Wald, + Weibull, + ZeroInflatedBinomial, + ZeroInflatedNegativeBinomial, + ZeroInflatedPoisson, +) +from pymc.distributions.moments.means import mean +from pymc.distributions.moments.utils import UndefinedMomentException + + +@pytest.mark.parametrize( + ["dist", "scipy_equiv", "dist_params", "scipy_params"], + [ + [ + AsymmetricLaplace, + laplace_asymmetric, + {"kappa": 2, "mu": 0.2, "b": 1 / 1.2}, + {"kappa": 2, "loc": 0.2, "scale": 1.2}, + ], + [Bernoulli, bernoulli, {"p": 0.6}, {"p": 0.6}], + [Beta, beta, {"alpha": 3, "beta": 2}, {"a": 3, "b": 2}], + [BetaBinomial, betabinom, {"alpha": 3, "beta": 2, "n": 5}, {"a": 3, "b": 2, "n": 5}], + [Binomial, binom, {"p": 0.6, "n": 5}, {"p": 0.6, "n": 5}], + [ChiSquared, chi2, {"nu": 6}, {"df": 6}], + [Dirichlet, dirichlet, {"a": np.ones(4)}, {"alpha": np.ones(4)}], + [ExGaussian, exponnorm, {"mu": 0, "sigma": 1, "nu": 1}, {"loc": 0, "scale": 1, "K": 1}], + [Exponential, expon, {"lam": 1}, {"scale": 1}], + [Gamma, gamma, {"alpha": 4, "beta": 3}, {"a": 4, "scale": 1 / 3}], + [Geometric, geom, {"p": 0.1}, {"p": 0.1}], + [Gumbel, gumbel_r, {"mu": 2, "beta": 1}, {"loc": 2, "scale": 1}], + [HalfNormal, halfnorm, {"sigma": 1}, {"scale": 1}], + [HyperGeometric, hypergeom, {"N": 10, "k": 2, "n": 4}, {"M": 10, "n": 2, "N": 4}], + [InverseGamma, invgamma, {"alpha": 2, "beta": 2}, {"a": 2, "scale": 2}], + [Laplace, laplace, {"mu": 2, "b": 2}, {"loc": 2, "scale": 2}], + [Logistic, logistic, {"mu": 2, "s": 1}, {"loc": 2, "scale": 1}], + [LogNormal, lognorm, {"mu": 0.3, "sigma": 0.6}, {"scale": np.exp(0.3), "s": 0.6}], + [ + MatrixNormal, + matrix_normal, + {"mu": np.eye(3), "rowcov": np.eye(3), "colcov": np.eye(3)}, + {"mean": np.eye(3), "rowcov": np.eye(3), "colcov": np.eye(3)}, + ], + [Moyal, moyal, {"mu": 2, "sigma": 2}, {"loc": 2, "scale": 2}], + [Multinomial, multinomial, {"n": 20, "p": np.ones(6) / 6}, {"n": 20, "p": np.ones(6) / 6}], + [ + MvNormal, + multivariate_normal, + {"mu": np.ones(3), "cov": np.eye(3)}, + {"mean": np.ones(3), "cov": np.eye(3)}, + ], + [ + MvStudentT, + multivariate_t, + {"mu": np.ones(3), "cov": np.eye(3), "nu": 4}, + {"loc": np.ones(3), "shape": np.eye(3), "df": 4}, + ], + [NegativeBinomial, nbinom, {"n": 10, "p": 0.5}, {"n": 10, "p": 0.5}], + [Normal, norm, {"mu": 2, "sigma": 2}, {"loc": 2, "scale": 2}], + [Pareto, pareto, {"alpha": 5, "m": 2}, {"b": 5, "scale": 2}], + [Poisson, poisson, {"mu": 20}, {"mu": 20}], + pytest.param( + Rice, rice, {"b": 2, "sigma": 2}, {"b": 2, "scale": 2}, marks=pytest.mark.xfail + ), # Something is wrong with the Rice mean, maybe a Bessel function in pytensor? + [SkewNormal, skewnorm, {"mu": 2, "sigma": 2, "alpha": 2}, {"loc": 2, "scale": 2, "a": 2}], + [ + SkewStudentT, + jf_skew_t, + {"mu": 2, "sigma": 2, "a": 3, "b": 3}, + {"loc": 2, "scale": 2, "a": 3, "b": 3}, + ], + [StudentT, t, {"mu": 2, "sigma": 2, "nu": 6}, {"loc": 2, "scale": 2, "df": 6}], + [ + Triangular, + triang, + {"lower": -3, "upper": 2, "c": 1}, + {"loc": -3, "scale": 5, "c": 4 / 5}, + ], + [Uniform, uniform, {"lower": -3, "upper": 2}, {"loc": -3, "scale": 5}], + [VonMises, vonmises, {"mu": 2, "kappa": 2}, {"loc": 2, "kappa": 2}], + [Wald, invgauss, {"mu": 2, "lam": 1}, {"mu": 2, "scale": 1}], + [Weibull, weibull_min, {"alpha": 2, "beta": 2}, {"c": 2, "scale": 2}], + ], +) +def test_mean_equal_to_scipy(dist, scipy_equiv, dist_params, scipy_params): + rv = dist.dist(**dist_params) + pymc_mean = mean(rv).eval() + scipy_rv = scipy_equiv(**scipy_params) + try: + scipy_mean = scipy_rv.mean() + except TypeError: + # Happens for multivariate_normal + scipy_mean = scipy_rv.mean + except AttributeError: + # Happens for multivariate_t + scipy_mean = scipy_rv.loc + assert np.asarray(pymc_mean).shape == np.asarray(scipy_mean).shape + np.testing.assert_almost_equal(pymc_mean, scipy_mean) + pymc_mean_tiled = mean(dist.dist(shape=(3, *pymc_mean.shape), **dist_params)).eval() + np.testing.assert_almost_equal( + pymc_mean_tiled, np.tile(pymc_mean, (3,) + (1,) * pymc_mean.ndim) + ) + + +@pytest.mark.parametrize( + ["dist", "dist_params", "expected"], + [ + [CAR, {"mu": np.ones(3), "W": np.eye(3), "alpha": 0.5, "tau": 1}, np.ones(3)], + [DiracDelta, {"c": 4.0}, 4.0], + [DirichletMultinomial, {"n": 5, "a": np.ones(5)}, np.ones(5)], + [DiscreteUniform, {"lower": 3, "upper": 5}, 4.0], + [HalfStudentT, {"nu": 2, "sigma": np.sqrt(2)}, 2.0], + [ + KroneckerNormal, + { + "mu": np.ones(6), + "covs": [ + np.array([[1.0, 0.5], [0.5, 2]]), + np.array([[1.0, 0.4, 0.2], [0.4, 2, 0.3], [0.2, 0.3, 1]]), + ], + }, + np.ones(6), + ], + [Kumaraswamy, {"a": 1, "b": 1}, 0.5], + [ + LKJCholeskyCov, + {"eta": 1, "n": 3, "sd_dist": DiracDelta.dist(1), "compute_corr": False}, + np.eye(3)[np.tril_indices(3)], + ], + [LKJCorr, {"eta": 1, "n": 3}, np.eye(3)], + [PolyaGamma, {"h": 1, "z": 1}, 0.23105858], + [ + StickBreakingWeights, + {"alpha": 1, "K": 5}, + np.concatenate([0.5 ** np.arange(1, 6), [0.5**5]]), + ], + [ZeroInflatedBinomial, {"n": 10, "p": 0.5, "psi": 0.8}, 4.0], + [ZeroInflatedNegativeBinomial, {"n": 10, "p": 0.5, "psi": 0.8}, 8.0], + [ZeroInflatedPoisson, {"mu": 5, "psi": 0.8}, 4.0], + ], +) +def test_mean_equal_expected(dist, dist_params, expected): + expected = np.asarray(expected) + rv = dist.dist(**dist_params) + pymc_mean = mean(rv).eval() + np.testing.assert_almost_equal(pymc_mean, expected) + pymc_mean_tiled = mean(dist.dist(shape=(3, *pymc_mean.shape), **dist_params)).eval() + np.testing.assert_almost_equal( + pymc_mean_tiled, np.tile(pymc_mean, (3,) + (1,) * pymc_mean.ndim) + ) + + +@pytest.mark.parametrize( + ["dist", "dist_params"], + [ + [Cauchy, {"alpha": 1, "beta": 1}], + [HalfCauchy, {"beta": 1.0}], + [LogitNormal, {"mu": 2, "sigma": 1}], + [Flat, {}], + [HalfFlat, {}], + [Categorical, {"p": [0.1, 0.9]}], + ], +) +def test_no_mean(dist, dist_params): + with pytest.raises(UndefinedMomentException): + mean(dist.dist(**dist_params))