diff --git a/src/probabilistic_model/probabilistic_circuit/jax/__init__.py b/src/probabilistic_model/probabilistic_circuit/jax/__init__.py index e69de29..530f648 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/__init__.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/__init__.py @@ -0,0 +1,3 @@ +from .inner_layer import * +from .input_layer import * +from .uniform_layer import * diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index ee517b8..a11bac1 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -161,7 +161,7 @@ def variables(self) -> jax.Array: class SumLayer(InnerLayer): log_weights: List[BCOO] - child_layers: List[Layer] + child_layers: Union[List[[ProductLayer]], List[InputLayer]] def __init__(self, child_layers: List[Layer], log_weights: List[BCOO]): super().__init__(child_layers) @@ -241,12 +241,13 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit], NXConverterLayer: result_hash_remap = {hash(node): index for index, node in enumerate(nodes)} - variables = tuple(nodes[0].variables) + variables = jnp.array([nodes[0].probabilistic_circuit.variables.index(variable) for variable in nodes[0].variables]) + number_of_nodes = len(nodes) # filter the child layers to only contain layers with the same scope as this one - filtered_child_layers = [child_layer for child_layer in child_layers if tuple(child_layer.layer.variables) == - variables] + filtered_child_layers = [child_layer for child_layer in child_layers if (child_layer.layer.variables == + variables).all()] log_weights = [] # for every possible child layer @@ -364,8 +365,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Probabilis # for every child layer for child_layer_index, child_layer in enumerate(child_layers): - - cl_variables = SortedSet(child_layer.layer.variables) + cl_variables = SortedSet([node.probabilistic_circuit.variables[index] for index in child_layer.layer.variables]) # for every subcircuit for subcircuit_index, subcircuit in enumerate(node.subcircuits): diff --git a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py index 9e19572..d287f56 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py @@ -7,7 +7,6 @@ from .inner_layer import InputLayer, NXConverterLayer from ..nx.distributions import DiracDeltaDistribution -from ..nx.probabilistic_circuit import ProbabilisticCircuitMixin class ContinuousLayer(InputLayer, ABC): @@ -101,7 +100,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[DiracDelta progress_bar: bool = True) -> \ NXConverterLayer: hash_remap = {hash(node): index for index, node in enumerate(nodes)} - locations = jnp.array([node.location for node in nodes], dtype=jnp.double) - density_caps = jnp.array([node.density_cap for node in nodes], dtype=jnp.double) + locations = jnp.array([node.location for node in nodes], dtype=jnp.float32) + density_caps = jnp.array([node.density_cap for node in nodes], dtype=jnp.float32) result = cls(nodes[0].probabilistic_circuit.variables.index(nodes[0].variable), locations, density_caps) return NXConverterLayer(result, nodes, hash_remap) \ No newline at end of file diff --git a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py index 1a17d7b..c1d24f9 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/probabilistic_circuit.py @@ -30,7 +30,7 @@ def __init__(self, variables: SortedSet, root: Layer): self.variables = variables self.root = root - def log_likelihood_(self, x: jax.Array) -> jax.Array: + def log_likelihood(self, x: jax.Array) -> jax.Array: return self.root.log_likelihood_of_nodes(x)[:, 0] @classmethod @@ -50,11 +50,13 @@ def from_nx(cls, pc: NXProbabilisticCircuit, progress_bar: bool = False) -> Prob # group nodes by depth layer_to_nodes_map = {depth: [node for node, n_depth in node_to_depth_map.items() if depth == n_depth] for depth in set(node_to_depth_map.values())} + reversed_layers_to_nodes_map = dict(reversed(layer_to_nodes_map.items())) # create layers from nodes child_layers: List[NXConverterLayer] = [] - for layer_index, nodes in reversed(tqdm.tqdm(layer_to_nodes_map.items(), desc="Creating Layers") if progress_bar - else layer_to_nodes_map.items()): + for layer_index, nodes in (tqdm.tqdm(reversed_layers_to_nodes_map.items(), desc="Creating Layers") if progress_bar + else reversed_layers_to_nodes_map.items()): + child_layers = Layer.create_layers_from_nodes(nodes, child_layers, progress_bar) root = child_layers[0].layer diff --git a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py index 08fb5cc..fb455f2 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/uniform_layer.py @@ -6,7 +6,6 @@ from .inner_layer import NXConverterLayer from .input_layer import ContinuousLayerWithFiniteSupport from ..nx.distributions import UniformDistribution -from ..nx.probabilistic_circuit import ProbabilisticCircuitMixin from .utils import simple_interval_to_open_array import tqdm @@ -45,7 +44,7 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UniformDis variable = nodes[0].variable - intervals = jnp.hstack([simple_interval_to_open_array(node.interval) for node in + intervals = jnp.vstack([simple_interval_to_open_array(node.interval) for node in (tqdm.tqdm(nodes, desc=f"Creating uniform layer for variable {variable.name}") if progress_bar else nodes)]) diff --git a/test/test_jax/test_probabilistic_circuit.py b/test/test_jax/test_probabilistic_circuit.py index 5280c0c..5ba4282 100644 --- a/test/test_jax/test_probabilistic_circuit.py +++ b/test/test_jax/test_probabilistic_circuit.py @@ -1,8 +1,116 @@ import unittest -class SMallCircuitIntegrationTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, False) # add assertion here +from probabilistic_model.probabilistic_circuit.jax import SumLayer +from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import (SumUnit, ProductUnit, + ProbabilisticCircuit as NXProbabilisticCircuit) +from probabilistic_model.probabilistic_circuit.nx.distributions.distributions import DiracDeltaDistribution +from random_events.variable import Continuous + +from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit +import jax.numpy as jnp +import plotly.graph_objects as go +from random_events.product_algebra import VariableMap, SimpleEvent + +from probabilistic_model.learning.jpt.jpt import JPT +from probabilistic_model.learning.jpt.variables import infer_variables_from_dataframe +from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit +import numpy as np +import torch +import jax.numpy as jnp +import tqdm + +from probabilistic_model.probabilistic_model import ProbabilisticModel +from probabilistic_model.utils import timeit + +np.random.seed(69) +import pandas as pd + +class SmallCircuitIntegrationTestCase(unittest.TestCase): + x = Continuous("x") + y = Continuous("y") + + nx_model = SumUnit() + jax_model: ProbabilisticCircuit + nx_model: NXProbabilisticCircuit + + @classmethod + def setUpClass(cls): + sum1, sum2, sum3 = SumUnit(), SumUnit(), SumUnit() + sum4, sum5 = SumUnit(), SumUnit() + prod1, prod2 = ProductUnit(), ProductUnit() + + sum1.add_subcircuit(prod1, 0.5) + sum1.add_subcircuit(prod2, 0.5) + prod1.add_subcircuit(sum2) + prod1.add_subcircuit(sum4) + prod2.add_subcircuit(sum3) + prod2.add_subcircuit(sum5) + + d_x1, d_x2 = DiracDeltaDistribution(cls.x, 0, 1), DiracDeltaDistribution(cls.x, 1, 2) + d_y1, d_y2 = DiracDeltaDistribution(cls.y, 2, 3), DiracDeltaDistribution(cls.y, 3, 4) + + sum2.add_subcircuit(d_x1, 0.8) + sum2.add_subcircuit(d_x2, 0.2) + sum3.add_subcircuit(d_x1, 0.7) + sum3.add_subcircuit(d_x2, 0.3) + + sum4.add_subcircuit(d_y1, 0.5) + sum4.add_subcircuit(d_y2, 0.5) + sum5.add_subcircuit(d_y1, 0.1) + sum5.add_subcircuit(d_y2, 0.9) + + cls.nx_model = sum1.probabilistic_circuit + cls.jax_model = ProbabilisticCircuit.from_nx(cls.nx_model) + + def test_creation(self): + self.assertEqual(self.jax_model.variables, self.nx_model.variables) + self.assertIsInstance(self.jax_model.root, SumLayer) + self.assertEqual(self.jax_model.root.number_of_nodes, 1) + self.assertEqual(len(self.jax_model.root.child_layers), 1) + product_layer = self.jax_model.root.child_layers[0] + self.assertEqual(product_layer.number_of_nodes, 2) + self.assertEqual(len(product_layer.child_layers), 2) + sum_layer1 = product_layer.child_layers[0] + sum_layer2 = product_layer.child_layers[1] + self.assertEqual(sum_layer1.number_of_nodes, 2) + self.assertEqual(sum_layer2.number_of_nodes, 2) + self.assertTrue(jnp.allclose(sum_layer1.variables, jnp.array([1]))) + self.assertTrue(jnp.allclose(sum_layer2.variables, jnp.array([0]))) + + def test_ll(self): + samples = self.nx_model.sample(1000) + nx_ll = self.nx_model.log_likelihood(samples) + jax_ll = self.jax_model.log_likelihood(samples) + self.assertTrue(jnp.allclose(nx_ll, jax_ll)) + + +class JPTIntegrationTestCase(unittest.TestCase): + number_of_variables = 2 + number_of_samples= 10000 + + jpt: NXProbabilisticCircuit + + @classmethod + def setUpClass(cls): + + mean = np.full(cls.number_of_variables, 0) + cov = np.random.uniform(0, 1, (cls.number_of_variables, cls.number_of_variables)) + cov = np.dot(cov, cov.T) + samples = np.random.multivariate_normal(mean, cov, cls.number_of_samples) + df = pd.DataFrame(samples, columns=[f"x_{i}" for i in range(cls.number_of_variables)]) + variables = infer_variables_from_dataframe(df, min_samples_per_quantile=100) + jpt = JPT(variables, min_samples_leaf=0.1) + jpt.fit(df) + cls.jpt = jpt.probabilistic_circuit + + def test_from_jpt(self): + model = ProbabilisticCircuit.from_nx(self.jpt, False) + samples = jnp.array(self.jpt.sample(1000)) + jax_ll = model.log_likelihood(samples) + self.assertTrue((jax_ll > -jnp.inf).all()) + + + if __name__ == '__main__': unittest.main() diff --git a/test/test_jax/test_utils.py b/test/test_jax/test_utils.py index 4d40f98..b93ac32 100644 --- a/test/test_jax/test_utils.py +++ b/test/test_jax/test_utils.py @@ -2,7 +2,8 @@ from jax.experimental.sparse import BCOO import jax.numpy as jnp -from probabilistic_model.probabilistic_circuit.jax.utils import copy_bcoo +from probabilistic_model.probabilistic_circuit.jax.utils import copy_bcoo, simple_interval_to_open_array +from random_events.interval import SimpleInterval class BCOOTestCase(unittest.TestCase): @@ -18,5 +19,13 @@ def test_copy(self): self.assertFalse(jnp.allclose(x.todense(), y.todense())) +class IntervalConversionTestCase(unittest.TestCase): + + def simple_interval_to_open_array(self): + simple_interval = SimpleInterval(0, 1) + array = simple_interval_to_open_array(simple_interval) + self.assertTrue(jnp.allclose(array, jnp.array([0, 1]))) + + if __name__ == '__main__': unittest.main()