diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py index a11bac1..3b2fc09 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py @@ -285,7 +285,7 @@ class ProductLayer(InnerLayer): units with the same scope. """ - edges: Int[BCOO, "len(child_layers), number_of_nodes"] + edges: Int[BCOO, "len(child_layers), number_of_nodes"] = eqx.field(static=True) """ The edges consist of a sparse matrix containing integers. The first dimension describes the edges for each child layer. diff --git a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py index d287f56..a50d2b4 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py @@ -7,6 +7,7 @@ from .inner_layer import InputLayer, NXConverterLayer from ..nx.distributions import DiracDeltaDistribution +import equinox as eqx class ContinuousLayer(InputLayer, ABC): @@ -20,7 +21,7 @@ class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC): Abstract class for continuous univariate input units with finite support. """ - interval: jax.Array + interval: jax.Array = eqx.field(static=True) """ The interval of the distribution as a array of shape (num_nodes, 2). The first column contains the lower bounds and the second column the upper bounds. diff --git a/test/test_jax/test_probabilistic_circuit.py b/test/test_jax/test_probabilistic_circuit.py index 5ba4282..07c9c8d 100644 --- a/test/test_jax/test_probabilistic_circuit.py +++ b/test/test_jax/test_probabilistic_circuit.py @@ -1,30 +1,21 @@ import unittest -from probabilistic_model.probabilistic_circuit.jax import SumLayer -from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import (SumUnit, ProductUnit, - ProbabilisticCircuit as NXProbabilisticCircuit) -from probabilistic_model.probabilistic_circuit.nx.distributions.distributions import DiracDeltaDistribution -from random_events.variable import Continuous - -from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit import jax.numpy as jnp -import plotly.graph_objects as go -from random_events.product_algebra import VariableMap, SimpleEvent +import numpy as np +from random_events.variable import Continuous from probabilistic_model.learning.jpt.jpt import JPT from probabilistic_model.learning.jpt.variables import infer_variables_from_dataframe +from probabilistic_model.probabilistic_circuit.jax import SumLayer from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit -import numpy as np -import torch -import jax.numpy as jnp -import tqdm - -from probabilistic_model.probabilistic_model import ProbabilisticModel -from probabilistic_model.utils import timeit +from probabilistic_model.probabilistic_circuit.nx.distributions.distributions import DiracDeltaDistribution +from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import (SumUnit, ProductUnit, + ProbabilisticCircuit as NXProbabilisticCircuit) np.random.seed(69) import pandas as pd + class SmallCircuitIntegrationTestCase(unittest.TestCase): x = Continuous("x") y = Continuous("y") @@ -74,8 +65,6 @@ def test_creation(self): sum_layer2 = product_layer.child_layers[1] self.assertEqual(sum_layer1.number_of_nodes, 2) self.assertEqual(sum_layer2.number_of_nodes, 2) - self.assertTrue(jnp.allclose(sum_layer1.variables, jnp.array([1]))) - self.assertTrue(jnp.allclose(sum_layer2.variables, jnp.array([0]))) def test_ll(self): samples = self.nx_model.sample(1000) @@ -86,13 +75,12 @@ def test_ll(self): class JPTIntegrationTestCase(unittest.TestCase): number_of_variables = 2 - number_of_samples= 10000 + number_of_samples = 10000 jpt: NXProbabilisticCircuit @classmethod def setUpClass(cls): - mean = np.full(cls.number_of_variables, 0) cov = np.random.uniform(0, 1, (cls.number_of_variables, cls.number_of_variables)) cov = np.dot(cov, cov.T) @@ -110,7 +98,5 @@ def test_from_jpt(self): self.assertTrue((jax_ll > -jnp.inf).all()) - - if __name__ == '__main__': unittest.main()