Skip to content

Commit

Permalink
Sum Layer CDF
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 7, 2024
1 parent f580539 commit 79a13ae
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 10 deletions.
52 changes: 48 additions & 4 deletions src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,33 @@ def log_likelihood_of_nodes_single(self, x: jnp.array) -> jnp.array:
"""
Calculate the log-likelihood of the distribution.
.. Note::
The shape of the log likelihood depends on the number of samples and nodes.
The shape of the result is (#samples, #nodes).
:param x: The input vector.
:return: The log-likelihood of every node in the layer for x.
"""
raise NotImplementedError

def log_likelihood_of_nodes(self, x: jnp.array) -> jnp.array:
"""
Vectorized version of :meth:`log_likelihood_of_nodes_single`
"""
return jax.vmap(self.log_likelihood_of_nodes_single)(x)


def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
"""
Calculate the cumulative distribution function of the distribution if applicable.
:param x: The input vector.
:return: The cumulative distribution function of every node in the layer for x.
"""
raise NotImplementedError

def cdf_of_nodes(self, x: jnp.array) -> jnp.array:
"""
Vectorized version of :meth:`cdf_of_nodes_single`
"""
return jax.vmap(self.cdf_of_nodes_single)(x)

def validate(self):
"""
Validate the parameters and their layouts.
Expand Down Expand Up @@ -139,6 +157,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Probabilis
def partition(self) -> Tuple[Any, Any]:
"""
Partition the layer into the parameters and the static structure.
:return: A tuple containing the parameters and the static structure as pytrees.
"""
return eqx.partition(self, eqx.is_inexact_array)
Expand Down Expand Up @@ -320,7 +339,7 @@ def normalized_weights(self):
return result

def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array:
result = jnp.zeros(self.number_of_nodes)
result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)

for log_weights, child_layer in self.log_weighted_child_layers:
# get the log likelihoods of the child nodes
Expand All @@ -341,6 +360,31 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array:

return jnp.where(result > 0, jnp.log(result) - self.log_normalization_constants, -jnp.inf)

def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)

for log_weights, child_layer in self.log_weighted_child_layers:
# get the cdf of the child nodes
child_layer_cdf = child_layer.cdf_of_nodes_single(x)

# weight the cdf of the child nodes by the weight for each node of this layer
cloned_log_weights = copy_bcoo(log_weights) # clone the weights

# multiply the weights with the child layer cdf
cloned_log_weights.data = jnp.exp(cloned_log_weights.data) # exponent weights
cloned_log_weights.data *= child_layer_cdf[cloned_log_weights.indices[:, 1]]

# sum the weights for each node
ll = cloned_log_weights.sum(1).todense()

# sum the child layer result
result += ll

# normalize the result
normalization_constants = jnp.exp(self.log_normalization_constants)
return result / normalization_constants



def sample_from_frequencies(self, frequencies: jax.Array, key: jax.random.PRNGKey) -> BCOO:
# calculate the probabilities for the latent variable interpretation of this layer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ class ContinuousLayer(InputLayer, ABC):
Abstract base class for continuous univariate input units.
"""

def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
raise NotImplementedError

def cdf_of_nodes(self, x: jnp.array) -> jnp.array:
return jax.vmap(self.cdf_of_nodes_single)(x)


class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
Expand Down
17 changes: 16 additions & 1 deletion test/test_jax/test_sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from jax.experimental.sparse import BCOO
from random_events.variable import Continuous
import jax.numpy as jnp
from triton.language import dtype

from probabilistic_model.probabilistic_circuit.jax import in_bound_elements_from_sparse_slice
from probabilistic_model.probabilistic_circuit.jax.input_layer import DiracDeltaLayer
Expand Down Expand Up @@ -69,4 +70,18 @@ def test_sampling(self):
_, sample_row = in_bound_elements_from_sparse_slice(sample_row)
self.assertEqual(len(sample_row), frequencies[index])
likelihood = self.sum_layer.log_likelihood_of_nodes(sample_row)
self.assertTrue(all(likelihood[:, index] > -jnp.inf))
self.assertTrue(all(likelihood[:, index] > -jnp.inf))

def test_cdf(self):
data = jnp.arange(7, dtype=jnp.float32).reshape(-1, 1) - 0.5
cdf = self.sum_layer.cdf_of_nodes(data)
self.assertEqual(cdf.shape, (7, 2))
result = jnp.array([[0, 0], # -0.5
[0, 0.4], # 0.5
[0.1, 0.4], # 1.5
[0.3, 0.7], # 2.5
[0.6, 0.7], # 3.5
[0.6, 0.8], # 4.5
[1, 1], # 5.5
], dtype=jnp.float32)
self.assertTrue(jnp.allclose(cdf, result))

0 comments on commit 79a13ae

Please sign in to comment.