Skip to content

Commit

Permalink
Jax fully working for ll. Now investigating the training of a circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Sep 26, 2024
1 parent 114f3a1 commit f57b444
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ class ProductLayer(InnerLayer):
units with the same scope.
"""

edges: Int[BCOO, "len(child_layers), number_of_nodes"]
edges: Int[BCOO, "len(child_layers), number_of_nodes"] = eqx.field(static=True)
"""
The edges consist of a sparse matrix containing integers.
The first dimension describes the edges for each child layer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .inner_layer import InputLayer, NXConverterLayer
from ..nx.distributions import DiracDeltaDistribution
import equinox as eqx


class ContinuousLayer(InputLayer, ABC):
Expand All @@ -20,7 +21,7 @@ class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
Abstract class for continuous univariate input units with finite support.
"""

interval: jax.Array
interval: jax.Array = eqx.field(static=True)
"""
The interval of the distribution as a array of shape (num_nodes, 2).
The first column contains the lower bounds and the second column the upper bounds.
Expand Down
30 changes: 8 additions & 22 deletions test/test_jax/test_probabilistic_circuit.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,21 @@
import unittest

from probabilistic_model.probabilistic_circuit.jax import SumLayer
from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import (SumUnit, ProductUnit,
ProbabilisticCircuit as NXProbabilisticCircuit)
from probabilistic_model.probabilistic_circuit.nx.distributions.distributions import DiracDeltaDistribution
from random_events.variable import Continuous

from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit
import jax.numpy as jnp
import plotly.graph_objects as go
from random_events.product_algebra import VariableMap, SimpleEvent
import numpy as np
from random_events.variable import Continuous

from probabilistic_model.learning.jpt.jpt import JPT
from probabilistic_model.learning.jpt.variables import infer_variables_from_dataframe
from probabilistic_model.probabilistic_circuit.jax import SumLayer
from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit
import numpy as np
import torch
import jax.numpy as jnp
import tqdm

from probabilistic_model.probabilistic_model import ProbabilisticModel
from probabilistic_model.utils import timeit
from probabilistic_model.probabilistic_circuit.nx.distributions.distributions import DiracDeltaDistribution
from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import (SumUnit, ProductUnit,
ProbabilisticCircuit as NXProbabilisticCircuit)

np.random.seed(69)
import pandas as pd


class SmallCircuitIntegrationTestCase(unittest.TestCase):
x = Continuous("x")
y = Continuous("y")
Expand Down Expand Up @@ -74,8 +65,6 @@ def test_creation(self):
sum_layer2 = product_layer.child_layers[1]
self.assertEqual(sum_layer1.number_of_nodes, 2)
self.assertEqual(sum_layer2.number_of_nodes, 2)
self.assertTrue(jnp.allclose(sum_layer1.variables, jnp.array([1])))
self.assertTrue(jnp.allclose(sum_layer2.variables, jnp.array([0])))

def test_ll(self):
samples = self.nx_model.sample(1000)
Expand All @@ -86,13 +75,12 @@ def test_ll(self):

class JPTIntegrationTestCase(unittest.TestCase):
number_of_variables = 2
number_of_samples= 10000
number_of_samples = 10000

jpt: NXProbabilisticCircuit

@classmethod
def setUpClass(cls):

mean = np.full(cls.number_of_variables, 0)
cov = np.random.uniform(0, 1, (cls.number_of_variables, cls.number_of_variables))
cov = np.dot(cov, cov.T)
Expand All @@ -110,7 +98,5 @@ def test_from_jpt(self):
self.assertTrue((jax_ll > -jnp.inf).all())




if __name__ == '__main__':
unittest.main()

0 comments on commit f57b444

Please sign in to comment.