Skip to content

Commit

Permalink
Correct sampling for sum layer but its super slow
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 7, 2024
1 parent 24a4ad2 commit 0357422
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
6 changes: 1 addition & 5 deletions src/probabilistic_model/probabilistic_circuit/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def sample_from_sparse_probabilities(log_probabilities: BCOO, amount: jax.Array,
:param key: The random key.
:return: The samples that are drawn for each state in the probabilities indicies.
"""

all_samples = []

for probability_row, row_amount in zip(log_probabilities, amount):
Expand All @@ -83,11 +82,8 @@ def sample_from_sparse_probabilities(log_probabilities: BCOO, amount: jax.Array,

samples = jax.random.categorical(key, probability_row.data, shape=(row_amount.item(), ))
frequencies = jnp.zeros((probability_row.data.shape[0],), dtype=jnp.int32)
unique, counts = jnp.unique(samples, return_counts=True)

frequencies = frequencies.at[unique].set(counts)
frequencies = frequencies.at[samples].add(1)
all_samples.append(frequencies)

return BCOO((jnp.concatenate(all_samples), log_probabilities.indices), shape=log_probabilities.shape,
indices_sorted=True, unique_indices=True)

1 change: 1 addition & 0 deletions test/test_jax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_sample_from_sparse_probabilities(self):
[0.4, 0., 0.6, 0.]]))
probs.data = jnp.log(probs.data)
amount = jnp.array([2, 3])

samples = sample_from_sparse_probabilities(probs,amount, jax.random.PRNGKey(69))
amounts = samples.sum(axis=1).todense()
self.assertTrue(jnp.all(amounts == amount))
Expand Down

0 comments on commit 0357422

Please sign in to comment.