Skip to content

Commit

Permalink
switch to blackjax run_inference_algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Oct 7, 2024
1 parent eeb99a8 commit b34b4cc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 91 deletions.
79 changes: 43 additions & 36 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,27 +278,22 @@ def map_fn(x):
assert draws % num_chunks == 0
nsteps = draws // num_chunks

# Run adaptation
adapt = blackjax.window_adaptation(
algorithm=algorithm,
logdensity_fn=logprob_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
progress_bar=progressbar,
**nuts_kwargs,
)

# Run adaptation for sampling parameters
@map_fn
def run_adaptation(seed, init_position):
return adapt.run(seed, init_position, num_steps=tune)

(last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points)

def _one_step(state, x, kernel):
del x
state, rng_key = state
key, _skey = jax.random.split(rng_key)
state, info = kernel(_skey, state)
return blackjax.window_adaptation(
algorithm=algorithm,
logdensity_fn=logprob_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
progress_bar=progressbar,
**nuts_kwargs,
).run(seed, init_position, num_steps=tune)

(adapt_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points)

# Filters output from each sampling step
def _transform_fn(state, info):
position = state.position
stats = {
"diverging": info.is_divergent,
Expand All @@ -308,42 +303,54 @@ def _one_step(state, x, kernel):
"acceptance_rate": info.acceptance_rate,
"lp": state.logdensity,
}
return (state, key), (position, stats)
return position, stats

# Performs sampling for each chunk
# random keys are carried with state
@map_fn
@partial(jax.jit, donate_argnums=0)
def _multi_step(state, imm, ss):
start_state, key = state
scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, progressbar)

kernel = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step

(last_state, key), (raw_samples, stats) = scan_fn(
partial(_one_step, kernel=kernel), (start_state, key), jnp.arange(nsteps)
state, key = state
key, _skey = jax.random.split(key)
last_state, (raw_samples, stats) = blackjax.util.run_inference_algorithm(
_skey,
algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss),
num_steps=nsteps,
initial_state=state,
progress_bar=progressbar,
transform=_transform_fn,
)
samples, log_likelihoods = postprocess_fn(raw_samples)
return (last_state, key), ((samples, log_likelihoods), stats)

sample_fn = partial(
chunk_sample_fn = partial(
_multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"]
)

if progressbar:
logger.info("Sampling chunk %d of %d:" % (1, num_chunks))
(last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed))

# Sample first chunk
last_state, sample_data = chunk_sample_fn((adapt_state, sample_seed))

# If single chunk sampling return results on device
if num_chunks == 1:
return samples[0], stats, samples[1], blackjax
((samples, log_likelihoods), stats) = sample_data
return samples, stats, log_likelihoods, blackjax

# Provision space for all samples on the cpu + save first chunk
output = _set_tree(

Check warning on line 342 in pymc/sampling/jax.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/jax.py#L342

Added line #L342 was not covered by tests
jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), (samples, stats)),
jax.device_put((samples, stats), jax.devices("cpu")[0]),
jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), sample_data),
jax.device_put(sample_data, jax.devices("cpu")[0]),
0,
)
del samples, stats
del sample_data

Check warning on line 347 in pymc/sampling/jax.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/jax.py#L347

Added line #L347 was not covered by tests

last_state, (all_samples, all_stats) = _do_chunked_sampling(
(last_state, seed), output, num_chunks, nsteps, sample_fn, progressbar
# Sample remaining chunks
_, ((samples, log_likelihoods), stats) = _do_chunked_sampling(

Check warning on line 350 in pymc/sampling/jax.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/jax.py#L350

Added line #L350 was not covered by tests
last_state, output, num_chunks, nsteps, chunk_sample_fn, progressbar
)
return all_samples[0], all_stats, all_samples[1], blackjax
return samples, stats, log_likelihoods, blackjax

Check warning on line 353 in pymc/sampling/jax.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/jax.py#L353

Added line #L353 was not covered by tests


def _numpyro_stats_to_dict(posterior):
Expand Down
57 changes: 2 additions & 55 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()


@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"])
def test_external_nuts_chunking(nuts_sampler):
def test_numpyro_external_nuts_chunking():
# chunked sampling should give exact same results as non-chunked
nuts_sampler = "numpyro"
pytest.importorskip(nuts_sampler)

with Model():
Expand All @@ -104,56 +104,3 @@ def test_external_nuts_chunking(nuts_sampler):
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L)
assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys()


def test_step_args():
with Model() as model:
a = Normal("a")
idata = sample(
nuts_sampler="numpyro",
target_accept=0.5,
nuts={"max_treedepth": 10},
random_seed=1410,
)

npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)


@pytest.mark.skipif(jax.default_backend() == "cpu", reason="need default backend that is not cpu")
@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"])
def test_postprocessing_backend(nuts_sampler):
pytest.importorskip(nuts_sampler)
default_backend = jax.default_backend()

with Model():
x = Normal("x", 100, 5)
y = Data("y", [1, 2, 3, 4])

Normal("L", mu=x, sigma=0.1, observed=y)

base_kwargs = dict(
nuts_sampler=nuts_sampler,
random_seed=123,
chains=4,
tune=200,
draws=200,
progressbar=False,
initvals={"x": 0.0},
idata_kwargs={"log_likelihood": True},
)

idata1 = sample(
**base_kwargs,
nuts_sampler_kwargs={
"postprocessing_backend": default_backend,
"chain_method": "vectorized",
},
)
idata2 = sample(
**base_kwargs,
nuts_sampler_kwargs={"postprocessing_backend": "cpu", "chain_method": "vectorized"},
)

np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L)
assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys()

0 comments on commit b34b4cc

Please sign in to comment.