Skip to content

Commit

Permalink
Started to work on coupling circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Sep 30, 2024
1 parent 44e8ba2 commit 212926a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import abstractmethod

import jax
from jax.tree_util import tree_flatten, tree_unflatten

import equinox as eqx

from .probabilistic_circuit import Layer


class Conditioner:

@abstractmethod
def generate_parameters(self, x) -> jax.Array:
raise NotImplementedError

@property
@abstractmethod
def output_length(self):
"""
:return: The length number of parameters that the model outputs.
"""
raise NotImplementedError


class CouplingCircuit(eqx.Module):

conditioner: Conditioner
circuit: Layer

def __init__(self, conditioner: Conditioner, circuit: Layer):
self.conditioner = conditioner
self.circuit = circuit

def partition_circuit(self):
return eqx.partition(self.circuit, eqx.is_inexact_array)

def validate(self):
self.circuit.validate()
params, _ = self.partition_circuit()
flattened_params = tree_flatten(params)[0]
number_of_parameters = sum([len(p) for p in flattened_params])
assert number_of_parameters == self.conditioner.output_length

def conditional_log_likelihood(self, x):
tree_def, static = self.partition_circuit()
flat_model, treedef_model = jax.tree_util.tree_flatten(tree_def)
params = self.conditioner.generate_parameters(x)
params = tree_unflatten(treedef_model, [params[0]])
circuit = eqx.combine(params, static)
return circuit.log_likelihood_of_nodes(x)
38 changes: 38 additions & 0 deletions test/test_jax/test_coupling_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest
import numpy as np

from probabilistic_model.probabilistic_circuit.jax import UniformLayer, SumLayer
from probabilistic_model.probabilistic_circuit.jax.coupling_circuit import Conditioner, CouplingCircuit

import equinox as eqx
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

class TrivialConditioner(Conditioner):

def generate_parameters(self, x):
return jnp.log(jnp.array([[0.2, 0.8]]).repeat(x.shape[0], 0))

@property
def output_length(self):
return 2

class CouplingCircuitTestCase(unittest.TestCase):

data = np.vstack((np.random.uniform(0, 1, (100, 1)),
np.random.uniform(2, 3, (200, 1))))
uniform_layer = UniformLayer(0, jnp.array([[-0.01, 1.01],
[1.99, 3.01]]))
sum_layer = SumLayer([uniform_layer], [BCOO((jnp.array([0., 0.]),
jnp.array([[0, 0], [0, 1]])),
shape=(1, 2))])

cc = CouplingCircuit(TrivialConditioner(), sum_layer)
cc.validate()

def test_log_likelihood(self):
x = jnp.array(self.data)
ll = self.cc.conditional_log_likelihood(x)

if __name__ == '__main__':
unittest.main()

0 comments on commit 212926a

Please sign in to comment.