Skip to content

Commit

Permalink
Conversion from nx to jax works
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Sep 26, 2024
1 parent d183593 commit 114f3a1
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 18 deletions.
3 changes: 3 additions & 0 deletions src/probabilistic_model/probabilistic_circuit/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .inner_layer import *
from .input_layer import *
from .uniform_layer import *
12 changes: 6 additions & 6 deletions src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def variables(self) -> jax.Array:
class SumLayer(InnerLayer):

log_weights: List[BCOO]
child_layers: List[Layer]
child_layers: Union[List[[ProductLayer]], List[InputLayer]]

def __init__(self, child_layers: List[Layer], log_weights: List[BCOO]):
super().__init__(child_layers)
Expand Down Expand Up @@ -241,12 +241,13 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit],
NXConverterLayer:

result_hash_remap = {hash(node): index for index, node in enumerate(nodes)}
variables = tuple(nodes[0].variables)
variables = jnp.array([nodes[0].probabilistic_circuit.variables.index(variable) for variable in nodes[0].variables])

number_of_nodes = len(nodes)

# filter the child layers to only contain layers with the same scope as this one
filtered_child_layers = [child_layer for child_layer in child_layers if tuple(child_layer.layer.variables) ==
variables]
filtered_child_layers = [child_layer for child_layer in child_layers if (child_layer.layer.variables ==
variables).all()]
log_weights = []

# for every possible child layer
Expand Down Expand Up @@ -364,8 +365,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Probabilis

# for every child layer
for child_layer_index, child_layer in enumerate(child_layers):

cl_variables = SortedSet(child_layer.layer.variables)
cl_variables = SortedSet([node.probabilistic_circuit.variables[index] for index in child_layer.layer.variables])

# for every subcircuit
for subcircuit_index, subcircuit in enumerate(node.subcircuits):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from .inner_layer import InputLayer, NXConverterLayer
from ..nx.distributions import DiracDeltaDistribution
from ..nx.probabilistic_circuit import ProbabilisticCircuitMixin


class ContinuousLayer(InputLayer, ABC):
Expand Down Expand Up @@ -101,7 +100,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[DiracDelta
progress_bar: bool = True) -> \
NXConverterLayer:
hash_remap = {hash(node): index for index, node in enumerate(nodes)}
locations = jnp.array([node.location for node in nodes], dtype=jnp.double)
density_caps = jnp.array([node.density_cap for node in nodes], dtype=jnp.double)
locations = jnp.array([node.location for node in nodes], dtype=jnp.float32)
density_caps = jnp.array([node.density_cap for node in nodes], dtype=jnp.float32)
result = cls(nodes[0].probabilistic_circuit.variables.index(nodes[0].variable), locations, density_caps)
return NXConverterLayer(result, nodes, hash_remap)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, variables: SortedSet, root: Layer):
self.variables = variables
self.root = root

def log_likelihood_(self, x: jax.Array) -> jax.Array:
def log_likelihood(self, x: jax.Array) -> jax.Array:
return self.root.log_likelihood_of_nodes(x)[:, 0]

@classmethod
Expand All @@ -50,11 +50,13 @@ def from_nx(cls, pc: NXProbabilisticCircuit, progress_bar: bool = False) -> Prob
# group nodes by depth
layer_to_nodes_map = {depth: [node for node, n_depth in node_to_depth_map.items() if depth == n_depth] for depth
in set(node_to_depth_map.values())}
reversed_layers_to_nodes_map = dict(reversed(layer_to_nodes_map.items()))

# create layers from nodes
child_layers: List[NXConverterLayer] = []
for layer_index, nodes in reversed(tqdm.tqdm(layer_to_nodes_map.items(), desc="Creating Layers") if progress_bar
else layer_to_nodes_map.items()):
for layer_index, nodes in (tqdm.tqdm(reversed_layers_to_nodes_map.items(), desc="Creating Layers") if progress_bar
else reversed_layers_to_nodes_map.items()):

child_layers = Layer.create_layers_from_nodes(nodes, child_layers, progress_bar)
root = child_layers[0].layer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .inner_layer import NXConverterLayer
from .input_layer import ContinuousLayerWithFiniteSupport
from ..nx.distributions import UniformDistribution
from ..nx.probabilistic_circuit import ProbabilisticCircuitMixin
from .utils import simple_interval_to_open_array
import tqdm

Expand Down Expand Up @@ -45,7 +44,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UniformDis

variable = nodes[0].variable

intervals = jnp.hstack([simple_interval_to_open_array(node.interval) for node in
intervals = jnp.vstack([simple_interval_to_open_array(node.interval) for node in
(tqdm.tqdm(nodes, desc=f"Creating uniform layer for variable {variable.name}")
if progress_bar else nodes)])

Expand Down
114 changes: 111 additions & 3 deletions test/test_jax/test_probabilistic_circuit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,116 @@
import unittest

class SMallCircuitIntegrationTestCase(unittest.TestCase):
def test_something(self):
self.assertEqual(True, False) # add assertion here
from probabilistic_model.probabilistic_circuit.jax import SumLayer
from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import (SumUnit, ProductUnit,
ProbabilisticCircuit as NXProbabilisticCircuit)
from probabilistic_model.probabilistic_circuit.nx.distributions.distributions import DiracDeltaDistribution
from random_events.variable import Continuous

from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit
import jax.numpy as jnp
import plotly.graph_objects as go
from random_events.product_algebra import VariableMap, SimpleEvent

from probabilistic_model.learning.jpt.jpt import JPT
from probabilistic_model.learning.jpt.variables import infer_variables_from_dataframe
from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit
import numpy as np
import torch
import jax.numpy as jnp
import tqdm

from probabilistic_model.probabilistic_model import ProbabilisticModel
from probabilistic_model.utils import timeit

np.random.seed(69)
import pandas as pd

class SmallCircuitIntegrationTestCase(unittest.TestCase):
x = Continuous("x")
y = Continuous("y")

nx_model = SumUnit()
jax_model: ProbabilisticCircuit
nx_model: NXProbabilisticCircuit

@classmethod
def setUpClass(cls):
sum1, sum2, sum3 = SumUnit(), SumUnit(), SumUnit()
sum4, sum5 = SumUnit(), SumUnit()
prod1, prod2 = ProductUnit(), ProductUnit()

sum1.add_subcircuit(prod1, 0.5)
sum1.add_subcircuit(prod2, 0.5)
prod1.add_subcircuit(sum2)
prod1.add_subcircuit(sum4)
prod2.add_subcircuit(sum3)
prod2.add_subcircuit(sum5)

d_x1, d_x2 = DiracDeltaDistribution(cls.x, 0, 1), DiracDeltaDistribution(cls.x, 1, 2)
d_y1, d_y2 = DiracDeltaDistribution(cls.y, 2, 3), DiracDeltaDistribution(cls.y, 3, 4)

sum2.add_subcircuit(d_x1, 0.8)
sum2.add_subcircuit(d_x2, 0.2)
sum3.add_subcircuit(d_x1, 0.7)
sum3.add_subcircuit(d_x2, 0.3)

sum4.add_subcircuit(d_y1, 0.5)
sum4.add_subcircuit(d_y2, 0.5)
sum5.add_subcircuit(d_y1, 0.1)
sum5.add_subcircuit(d_y2, 0.9)

cls.nx_model = sum1.probabilistic_circuit
cls.jax_model = ProbabilisticCircuit.from_nx(cls.nx_model)

def test_creation(self):
self.assertEqual(self.jax_model.variables, self.nx_model.variables)
self.assertIsInstance(self.jax_model.root, SumLayer)
self.assertEqual(self.jax_model.root.number_of_nodes, 1)
self.assertEqual(len(self.jax_model.root.child_layers), 1)
product_layer = self.jax_model.root.child_layers[0]
self.assertEqual(product_layer.number_of_nodes, 2)
self.assertEqual(len(product_layer.child_layers), 2)
sum_layer1 = product_layer.child_layers[0]
sum_layer2 = product_layer.child_layers[1]
self.assertEqual(sum_layer1.number_of_nodes, 2)
self.assertEqual(sum_layer2.number_of_nodes, 2)
self.assertTrue(jnp.allclose(sum_layer1.variables, jnp.array([1])))
self.assertTrue(jnp.allclose(sum_layer2.variables, jnp.array([0])))

def test_ll(self):
samples = self.nx_model.sample(1000)
nx_ll = self.nx_model.log_likelihood(samples)
jax_ll = self.jax_model.log_likelihood(samples)
self.assertTrue(jnp.allclose(nx_ll, jax_ll))


class JPTIntegrationTestCase(unittest.TestCase):
number_of_variables = 2
number_of_samples= 10000

jpt: NXProbabilisticCircuit

@classmethod
def setUpClass(cls):

mean = np.full(cls.number_of_variables, 0)
cov = np.random.uniform(0, 1, (cls.number_of_variables, cls.number_of_variables))
cov = np.dot(cov, cov.T)
samples = np.random.multivariate_normal(mean, cov, cls.number_of_samples)
df = pd.DataFrame(samples, columns=[f"x_{i}" for i in range(cls.number_of_variables)])
variables = infer_variables_from_dataframe(df, min_samples_per_quantile=100)
jpt = JPT(variables, min_samples_leaf=0.1)
jpt.fit(df)
cls.jpt = jpt.probabilistic_circuit

def test_from_jpt(self):
model = ProbabilisticCircuit.from_nx(self.jpt, False)
samples = jnp.array(self.jpt.sample(1000))
jax_ll = model.log_likelihood(samples)
self.assertTrue((jax_ll > -jnp.inf).all())




if __name__ == '__main__':
unittest.main()
11 changes: 10 additions & 1 deletion test/test_jax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from jax.experimental.sparse import BCOO
import jax.numpy as jnp

from probabilistic_model.probabilistic_circuit.jax.utils import copy_bcoo
from probabilistic_model.probabilistic_circuit.jax.utils import copy_bcoo, simple_interval_to_open_array
from random_events.interval import SimpleInterval

class BCOOTestCase(unittest.TestCase):

Expand All @@ -18,5 +19,13 @@ def test_copy(self):
self.assertFalse(jnp.allclose(x.todense(), y.todense()))


class IntervalConversionTestCase(unittest.TestCase):

def simple_interval_to_open_array(self):
simple_interval = SimpleInterval(0, 1)
array = simple_interval_to_open_array(simple_interval)
self.assertTrue(jnp.allclose(array, jnp.array([0, 1])))


if __name__ == '__main__':
unittest.main()

0 comments on commit 114f3a1

Please sign in to comment.