How to make pmap return a collection instead of a concatenation #23137

Answered by jakevdp
sokol11 asked this question in Q&A

JAX's vectorization and parallelization transformations, including vmap, pmap, and shard_map, all work with a struct-of-arrays storage pattern rather than an array-of-structs pattern. This means that if you vmap or pmap a function that returns a dict, you get a single dict of batched arrays, not a list of dicts. There is no way to make these transformations return an array of structs directly, but you can take the output and convert it as a post-processing step.
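To see the struct-of-arrays behaviour concretely, here is a small sketch. It uses vmap only because vmap runs on a single device; per the explanation above, pmap and shard_map produce outputs with the same structure. The function `f` and its output shapes are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical per-example function that returns a nested dict (a pytree).
    return {'a': x, 'b': {'c': x * jnp.ones(3)}}

xs = jnp.arange(4.0)   # a batch of 4 scalar inputs
out = jax.vmap(f)(xs)

# The result is ONE dict whose leaves gained a leading batch axis of size 4,
# not a list of 4 dicts.
print(type(out))             # <class 'dict'>
print(out['a'].shape)        # (4,)
print(out['b']['c'].shape)   # (4, 3)
```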

For example, imagine your pmap created this nested dict of arrays with a leading batch dimension of size 4; you could use jax.tree.transpose to convert it to a sequence of dicts:

import jax
import jax.numpy as jnp

result = {'a': jnp.arange(4), 'b': {'c': jnp.ones((
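Because the snippet above is cut off, here is a minimal self-contained sketch of the post-processing step. It assumes a hypothetical result dict whose leaves all share a leading batch dimension of size 4 (the inner shape `(4, 3)` is illustrative, not taken from the original), and a JAX version recent enough to provide the `jax.tree` namespace, including `jax.tree.transpose`. A plain-indexing alternative is shown for comparison.

```python
import jax
import jax.numpy as jnp

# Hypothetical pmap output: every leaf has a leading batch axis of size 4.
# The shape (4, 3) for 'c' is illustrative only.
result = {'a': jnp.arange(4), 'b': {'c': jnp.ones((4, 3))}}
n = 4  # size of the leading batch axis

# Option 1: jax.tree.transpose.
# First split each leaf into a Python list of n per-index arrays, so the tree
# has the "outer dict structure" composed with an "inner list-of-n structure".
split = jax.tree.map(list, result)
outer = jax.tree.structure(result)    # the dict structure
inner = jax.tree.structure([0] * n)   # a list with n leaves
per_device = jax.tree.transpose(outer, inner, split)  # list of n dicts

# Option 2: plain indexing, one tree.map per batch index.
per_device_alt = [jax.tree.map(lambda x: x[i], result) for i in range(n)]

print(len(per_device))                # 4
print(int(per_device[2]['a']))        # 2
print(per_device[0]['b']['c'].shape)  # (3,)
```

Both options yield a list of dicts. The indexing version is arguably easier to read, while the transpose version states the restructuring explicitly in terms of tree structures. Either way, the list-of-dicts form holds many small arrays, so it is typically best applied only at the end, after any further JAX computation on the batched form is done.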
