Skip to content

Commit

Permalink
implemented inverse gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
mbi6245 committed Jul 22, 2024
1 parent 373d238 commit ca701e0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
11 changes: 5 additions & 6 deletions src/ensemble/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,21 @@ def _create_scipy_dist(self) -> None:

class InvGamma(Distribution):
def _create_scipy_dist(self) -> None:
raise NotImplementedError
res = scipy.optimize.minimize(
fun=self._shape_scale,
# why is this the initial guess? idk either
x0=[self.mean, self.mean * np.sqrt(self.variance)],
# a *good* friend told me that this is a good initial guess and it works so far???
x0=[3, self.mean * 2],
args=(self.mean, self.variance),
method="Nelder-Mead",
)
print("results from minimizer: ", res.x)
self._scipy_dist = scipy.stats.invgamma(a=res.x[0], scale=res.x[1])
shape, scale = np.abs(res.x)
self._scipy_dist = scipy.stats.invgamma(a=shape, scale=scale)

def _shape_scale(self, x, samp_mean, samp_var) -> None:
alpha = x[0]
beta = x[1]
return ((beta / (alpha - 1)) - samp_mean) ** 2 + (
beta**2 / ((alpha - 1) ** 2 * (alpha - 2)) - samp_var
(beta**2 / ((alpha - 1) ** 2 * (alpha - 2))) - samp_var
) ** 2


Expand Down
11 changes: 6 additions & 5 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
# @pytest.mark.parametrize("a, b, expected", [(1, 2, 3), (2, 3, 5)])
# def test_add(a, b, expected):
# assert add(a, b) == expected
MEAN = 0.5
VARIANCE = 1.1
MEAN = 5
VARIANCE = 6.1


def test_exp():
Expand All @@ -40,11 +40,12 @@ def test_gamma():


def test_invgamma():
raise NotImplementedError
# raise NotImplementedError
invgamma = InvGamma(MEAN, VARIANCE)
res = invgamma.stats(moments="mv")
print(res)
assert False
print("mean and var: ", res)
assert np.isclose(res[0], MEAN)
assert np.isclose(res[1], VARIANCE)


def test_fisk():
Expand Down

0 comments on commit ca701e0

Please sign in to comment.