Skip to content

Commit

Permalink
Add _InverseTransform (#875)
Browse files Browse the repository at this point in the history
* add inverse transform

* also use _inverse in TFP
  • Loading branch information
fehiepsi authored Jan 15, 2021
1 parent 85289ef commit c8e9d67
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 17 deletions.
2 changes: 1 addition & 1 deletion numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def codomain(self):
def __call__(self, x):
return self.bijector.forward(x)

def inv(self, y):
def _inverse(self, y):
return self.bijector.inverse(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down
73 changes: 57 additions & 16 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ def input_event_dim(self):
def output_event_dim(self):
return self.event_dim

@property
def inv(self):
return _InverseTransform(self)

def __call__(self, x):
return NotImplementedError

def inv(self, y):
def _inverse(self, y):
raise NotImplementedError

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -75,6 +79,43 @@ def call_with_intermediates(self, x):
return self(x), None


class _InverseTransform(Transform):
def __init__(self, transform):
super().__init__()
self._inv = transform

@property
def domain(self):
return self._inv.codomain

@property
def codomain(self):
return self._inv.domain

@property
def input_event_dim(self):
return self._inv.output_event_dim

@property
def output_event_dim(self):
return self._inv.input_event_dim

@property
def event_dim(self):
return self._inv.event_dim

@property
def inv(self):
return self._inv

def __call__(self, x):
return self._inv._inverse(x)

def log_abs_det_jacobian(self, x, y, intermediates=None):
# NB: we don't use intermediates for inverse transform
return -self._inv.log_abs_det_jacobian(y, x, None)


class AbsTransform(Transform):
domain = constraints.real
codomain = constraints.positive
Expand All @@ -85,7 +126,7 @@ def __eq__(self, other):
def __call__(self, x):
return jnp.abs(x)

def inv(self, y):
def _inverse(self, y):
return y


Expand Down Expand Up @@ -134,7 +175,7 @@ def event_dim(self):
def __call__(self, x):
return self.loc + self.scale * x

def inv(self, y):
def _inverse(self, y):
return (y - self.loc) / self.scale

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down Expand Up @@ -176,7 +217,7 @@ def __call__(self, x):
x = part(x)
return x

def inv(self, y):
def _inverse(self, y):
for part in self.parts[::-1]:
y = part.inv(y)
return y
Expand Down Expand Up @@ -255,7 +296,7 @@ def __call__(self, x):
t = jnp.tanh(x)
return signed_stick_breaking_tril(t)

def inv(self, y):
def _inverse(self, y):
# inverse stick-breaking
z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
pad_width = [(0, 0)] * y.ndim
Expand Down Expand Up @@ -306,7 +347,7 @@ def __call__(self, x):
# XXX consider to clamp from below for stability if necessary
return jnp.exp(x)

def inv(self, y):
def _inverse(self, y):
return jnp.log(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -321,7 +362,7 @@ def __init__(self, event_dim=0):
def __call__(self, x):
return x

def inv(self, y):
def _inverse(self, y):
return y

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down Expand Up @@ -349,7 +390,7 @@ def codomain(self):
def __call__(self, x):
return jnp.matmul(x, jnp.swapaxes(x, -2, -1))

def inv(self, y):
def _inverse(self, y):
return jnp.linalg.cholesky(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand Down Expand Up @@ -387,7 +428,7 @@ def __init__(self, loc, scale_tril):
def __call__(self, x):
return self.loc + jnp.squeeze(jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1)

def inv(self, y):
def _inverse(self, y):
y = y - self.loc
original_shape = jnp.shape(y)
yt = jnp.reshape(y, (-1, original_shape[-1])).T
Expand Down Expand Up @@ -415,7 +456,7 @@ def __call__(self, x):
diag = jnp.exp(x[..., -n:])
return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)

def inv(self, y):
def _inverse(self, y):
z = matrix_to_tril_vec(y, diagonal=-1)
return jnp.concatenate([z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1)

Expand All @@ -442,7 +483,7 @@ def __call__(self, x):
z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1)
return jnp.cumsum(z, axis=-1)

def inv(self, y):
def _inverse(self, y):
x = jnp.log(y[..., 1:] - y[..., :-1])
return jnp.concatenate([y[..., :1], x], axis=-1)

Expand All @@ -461,7 +502,7 @@ def __init__(self, permutation):
def __call__(self, x):
return x[..., self.permutation]

def inv(self, y):
def _inverse(self, y):
size = self.permutation.size
permutation_inv = ops.index_update(jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)),
self.permutation,
Expand All @@ -482,7 +523,7 @@ def __init__(self, exponent):
def __call__(self, x):
return jnp.power(x, self.exponent)

def inv(self, y):
def _inverse(self, y):
return jnp.power(y, 1 / self.exponent)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -495,7 +536,7 @@ class SigmoidTransform(Transform):
def __call__(self, x):
return _clipped_expit(x)

def inv(self, y):
def _inverse(self, y):
return logit(y)

def log_abs_det_jacobian(self, x, y, intermediates=None):
Expand All @@ -522,7 +563,7 @@ def __call__(self, x):
z1m_cumprod_shifted = jnp.pad(z1m_cumprod, pad_width, mode="constant", constant_values=1.)
return z_padded * z1m_cumprod_shifted

def inv(self, y):
def _inverse(self, y):
y_crop = y[..., :-1]
z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), a_min=jnp.finfo(y.dtype).tiny)
# hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
Expand Down Expand Up @@ -559,7 +600,7 @@ def __call__(self, x):
else:
return self.unpack_fn(x)

def inv(self, y):
def _inverse(self, y):
leading_dims = [v.shape[0] if jnp.ndim(v) > 0 else 0
for v in tree_flatten(y)[0]]
d0 = leading_dims[0]
Expand Down
4 changes: 4 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,12 +1031,16 @@ def test_bijective_transforms(transform, event_shape, batch_shape):
# test inv
z = transform.inv(y)
assert_allclose(x, z, atol=1e-6, rtol=1e-6)
assert transform.inv.inv is transform
assert transform.domain is transform.inv.codomain
assert transform.codomain is transform.inv.domain

# test domain
assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

# test log_abs_det_jacobian
actual = transform.log_abs_det_jacobian(x, y)
assert_allclose(actual, -transform.inv.log_abs_det_jacobian(y, x))
assert jnp.shape(actual) == batch_shape
if len(shape) == transform.event_dim:
if len(event_shape) == 1:
Expand Down

0 comments on commit c8e9d67

Please sign in to comment.