Replies: 1 comment
-
After further investigation, I've managed to achieve the behavior I was looking for, using the following script:

```python
import jax
import numpy as np
import jax.numpy as jnp
import functools
from absl import app
from absl import flags
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

flags.DEFINE_string("server_addr", "", help="server ip addr")
flags.DEFINE_integer("num_hosts", 1, help="num of hosts")
flags.DEFINE_integer("host_idx", 0, help="index of current host")
FLAGS = flags.FLAGS


def f(x):
    return x


def main(argv):
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)
    devices = jax.devices()
    local_devices = jax.local_devices()
    print("host_idx:", FLAGS.host_idx)
    print("devices:", devices)
    print("local_devices:", local_devices)

    # A mesh over ALL devices (both hosts), not just the local ones.
    mesh = Mesh(np.array(devices), ("i",))
    sharding = NamedSharding(mesh, P("i"))
    replicated_sharding = NamedSharding(mesh, P())

    # Each host builds its local slice, then assembles the global array.
    x = 8 * FLAGS.host_idx + jnp.arange(8)
    global_array = multihost_utils.host_local_array_to_global_array(x, mesh, P("i"))
    x_s = shard_map(f, mesh, in_specs=P("i"), out_specs=P("i"))(global_array)
    print("x:", x)
    print(jax.debug.visualize_array_sharding(x))
    print("x_s", multihost_utils.process_allgather(x_s))
    print(jax.debug.visualize_array_sharding(x_s))

    @functools.partial(
        shard_map,
        mesh=mesh,
        in_specs=P("i"),
        out_specs=P("i"),
    )
    def psum_data(data):
        return jax.lax.psum(data, "i")

    p_sum_out = psum_data(global_array)
    print("devices buffers x_s", [shard.data for shard in x_s.addressable_shards])
    print(
        "devices buffers p_sum_out",
        [shard.data for shard in p_sum_out.addressable_shards],
    )
    print(
        "output after taking psum (global_array):",
        multihost_utils.process_allgather(p_sum_out),
    )
    print(
        "output after taking psum (local_array):",
        multihost_utils.global_array_to_host_local_array(p_sum_out, mesh, P("i")),
    )


if __name__ == "__main__":
    app.run(main)
```

The key changes that made this work:
While this approach works, I have some additional questions:
Any insights or best practices would be greatly appreciated. Thank you!
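For anyone who wants to try this pattern without a multi-host cluster, the cross-device `psum` under `shard_map` can be emulated on a single machine by asking XLA for extra CPU devices. This is my own sketch, not part of the script above; the `XLA_FLAGS` trick and the `arange(16)` data are assumptions for illustration:

```python
import os

# Emulate 8 devices on one CPU host (must be set before importing jax).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One mesh axis "i" spanning all 8 (emulated) devices.
mesh = Mesh(np.array(jax.devices()), ("i",))

# 16 elements, sharded 2-per-device along axis "i".
x = jnp.arange(16)
out = shard_map(
    lambda v: jax.lax.psum(v, "i"), mesh=mesh, in_specs=P("i"), out_specs=P("i")
)(x)
# After psum, every shard holds the element-wise sum over all devices: [56, 64].
```

On a real 2-host setup the same structure applies, except that the global array must first be assembled with `multihost_utils.host_local_array_to_global_array`, as in the script above.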
-
Hi everyone! I'm exploring the differences between the `shard_map` and `pmap` functionalities in JAX, particularly in a multi-host setting. I've encountered some behavior that I'd like to understand better and potentially find a solution for.

**Setup: Multi-Host Environment**

Consider a setup with 2 hosts, each having 4 devices.
**Example 1: Using `pmap`**

Here's a basic script using `pmap`:

With `pmap`, the final output of taking `psum` yields `[56 64]` across all shards, as expected.

**Example 2: Attempting to Use `shard_map`**

Now, I tried to achieve the same result using `shard_map`:

**Observed Behavior and Questions**
With `shard_map`, I'm getting different outputs for the shards on each host:

- `[12 16]`
- `[44 48]`
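These per-host numbers line up exactly with a `psum` that runs only over each host's four local devices. A quick NumPy check of that arithmetic (assuming, for illustration, global data equivalent to `arange(16)` laid out two elements per device, with each host owning four consecutive devices):

```python
import numpy as np

x = np.arange(16).reshape(8, 2)  # 8 devices, 2 elements each

global_psum = x.sum(axis=0)     # psum over all 8 devices  -> [56 64]
host0_psum = x[:4].sum(axis=0)  # psum over host 0's devices -> [12 16]
host1_psum = x[4:].sum(axis=0)  # psum over host 1's devices -> [44 48]
```

So the `[56 64]` result requires the reduction to span both hosts, while `[12 16]` / `[44 48]` indicate two independent local reductions.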
It seems like the shards are not aware of each other across hosts when using `shard_map`.

**Questions**
- Is `shard_map` compatible with the behavior I want, i.e., to perform operations across all devices on all hosts? Or is it better to stick to `pmap` for this functionality?
- … `shard_map`?

Any insights or suggestions would be greatly appreciated. Thanks in advance for your help!
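For reference, the `pmap` result described above can be reproduced on a single host by emulating 8 CPU devices. This is a sketch only; the `XLA_FLAGS` trick and the data layout are my assumptions, inferred from the reported `[56 64]` result rather than taken from the original script:

```python
import os

# Emulate 8 devices on one CPU host (must be set before importing jax).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

x = jnp.arange(16).reshape(8, 2)  # one row per (emulated) device

# psum over the mapped axis: every device ends up with the same totals.
out = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
# out[k] == [56, 64] for every device k
```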