Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _InverseTransform #875

Merged
merged 2 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def codomain(self):
def __call__(self, x):
return self.bijector.forward(x)

def inv(self, y):
def _inverse(self, y):
return self.bijector.inverse(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down
73 changes: 57 additions & 16 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ def input_event_dim(self):
def output_event_dim(self):
return self.event_dim

@property
def inv(self):
return _InverseTransform(self)

def __call__(self, x):
return NotImplementedError

def inv(self, y):
def _inverse(self, y):
raise NotImplementedError

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -75,6 +79,43 @@ def call_with_intermediates(self, x):
return self(x), None


class _InverseTransform(Transform):
def __init__(self, transform):
super().__init__()
self._inv = transform

@property
def domain(self):
return self._inv.codomain

@property
def codomain(self):
return self._inv.domain

@property
def input_event_dim(self):
return self._inv.output_event_dim

@property
def output_event_dim(self):
return self._inv.input_event_dim

@property
def event_dim(self):
return self._inv.event_dim

@property
def inv(self):
return self._inv

def __call__(self, x):
return self._inv._inverse(x)

def log_abs_det_jacobian(self, x, y, intermediates=None):
# NB: we don't use intermediates for inverse transform
return -self._inv.log_abs_det_jacobian(y, x, None)


class AbsTransform(Transform):
domain = constraints.real
codomain = constraints.positive
Expand All @@ -85,7 +126,7 @@ def __eq__(self, other):
def __call__(self, x):
return jnp.abs(x)

def inv(self, y):
def _inverse(self, y):
return y


Expand Down Expand Up @@ -134,7 +175,7 @@ def event_dim(self):
def __call__(self, x):
return self.loc + self.scale * x

def inv(self, y):
def _inverse(self, y):
return (y - self.loc) / self.scale

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down Expand Up @@ -176,7 +217,7 @@ def __call__(self, x):
x = part(x)
return x

def inv(self, y):
def _inverse(self, y):
for part in self.parts[::-1]:
y = part.inv(y)
return y
Expand Down Expand Up @@ -255,7 +296,7 @@ def __call__(self, x):
t = jnp.tanh(x)
return signed_stick_breaking_tril(t)

def inv(self, y):
def _inverse(self, y):
# inverse stick-breaking
z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
pad_width = [(0, 0)] * y.ndim
Expand Down Expand Up @@ -306,7 +347,7 @@ def __call__(self, x):
# XXX consider to clamp from below for stability if necessary
return jnp.exp(x)

def inv(self, y):
def _inverse(self, y):
return jnp.log(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -321,7 +362,7 @@ def __init__(self, event_dim=0):
def __call__(self, x):
return x

def inv(self, y):
def _inverse(self, y):
return y

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down Expand Up @@ -349,7 +390,7 @@ def codomain(self):
def __call__(self, x):
return jnp.matmul(x, jnp.swapaxes(x, -2, -1))

def inv(self, y):
def _inverse(self, y):
return jnp.linalg.cholesky(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down Expand Up @@ -387,7 +428,7 @@ def __init__(self, loc, scale_tril):
def __call__(self, x):
return self.loc + jnp.squeeze(jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1)

def inv(self, y):
def _inverse(self, y):
y = y - self.loc
original_shape = jnp.shape(y)
yt = jnp.reshape(y, (-1, original_shape[-1])).T
Expand Down Expand Up @@ -415,7 +456,7 @@ def __call__(self, x):
diag = jnp.exp(x[..., -n:])
return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)

def inv(self, y):
def _inverse(self, y):
z = matrix_to_tril_vec(y, diagonal=-1)
return jnp.concatenate([z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1)

Expand All @@ -442,7 +483,7 @@ def __call__(self, x):
z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1)
return jnp.cumsum(z, axis=-1)

def inv(self, y):
def _inverse(self, y):
x = jnp.log(y[..., 1:] - y[..., :-1])
return jnp.concatenate([y[..., :1], x], axis=-1)

Expand All @@ -461,7 +502,7 @@ def __init__(self, permutation):
def __call__(self, x):
return x[..., self.permutation]

def inv(self, y):
def _inverse(self, y):
size = self.permutation.size
permutation_inv = ops.index_update(jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)),
self.permutation,
Expand All @@ -482,7 +523,7 @@ def __init__(self, exponent):
def __call__(self, x):
return jnp.power(x, self.exponent)

def inv(self, y):
def _inverse(self, y):
return jnp.power(y, 1 / self.exponent)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -495,7 +536,7 @@ class SigmoidTransform(Transform):
def __call__(self, x):
return _clipped_expit(x)

def inv(self, y):
def _inverse(self, y):
return logit(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -522,7 +563,7 @@ def __call__(self, x):
z1m_cumprod_shifted = jnp.pad(z1m_cumprod, pad_width, mode="constant", constant_values=1.)
return z_padded * z1m_cumprod_shifted

def inv(self, y):
def _inverse(self, y):
y_crop = y[..., :-1]
z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), a_min=jnp.finfo(y.dtype).tiny)
# hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
Expand Down Expand Up @@ -559,7 +600,7 @@ def __call__(self, x):
else:
return self.unpack_fn(x)

def inv(self, y):
def _inverse(self, y):
leading_dims = [v.shape[0] if jnp.ndim(v) > 0 else 0
for v in tree_flatten(y)[0]]
d0 = leading_dims[0]
Expand Down
4 changes: 4 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,12 +1031,16 @@ def test_bijective_transforms(transform, event_shape, batch_shape):
# test inv
z = transform.inv(y)
assert_allclose(x, z, atol=1e-6, rtol=1e-6)
assert transform.inv.inv is transform
assert transform.domain is transform.inv.codomain
assert transform.codomain is transform.inv.domain

# test domain
assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

# test log_abs_det_jacobian
actual = transform.log_abs_det_jacobian(x, y)
assert_allclose(actual, -transform.inv.log_abs_det_jacobian(y, x))
assert jnp.shape(actual) == batch_shape
if len(shape) == transform.event_dim:
if len(event_shape) == 1:
Expand Down