Logical autosharding for MPMD#
What you need to know for this tutorial#
Basic knowledge of Jax and numpy
Familiarity with the concepts of tensor sharding and parallelism
The “logical autosharding for SPMD” tutorial
Initialize MultiMesh for Jax#
As before, the first thing we need to do is initialize MultiMesh for Jax.
from jax_plugins.multimesh import init
init(cpus=8)
Import Jax packages#
Once MultiMesh for Jax has been initialized, a standard set of Jax imports can be done.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.debug import visualize_array_sharding
from multimesh.jax import with_sharding_constraint, parallelize, task, TaskMesh
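As a quick sanity check, the init(cpus=8) call above should make eight MultiMesh devices visible through the standard jax.devices() call; the rest of this tutorial indexes them as devices 0-7.
print(len(jax.devices()))  # expect 8 MultiMesh devices (indices 0-7)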
Assigning task contexts to submeshes in MultiMesh for Jax#
MultiMesh for Jax follows a model of logical autosharding, which slightly changes the usual Jax flow.
Previously we solved an SPMD problem, which used a single uniform mesh across the whole model. We now add task
annotations that assign different subcomputations to different submeshes. The order of steps is now:
Write numpy-like code that implements the model
Add logical sharding annotations to parameters in the model
Assign different subcomputations in the model to different submeshes
Ahead-of-time compile the model, which computes the physical shardings
Use the computed physical shardings to initialize parameters
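As a condensed sketch, these steps map onto the APIs used in the rest of this tutorial as follows (model, init_params, and get_input_batch are the functions defined below):
# Steps 1-3: write the model with sharding constraints and task annotations
# (see the definitions of f, layer1, layer2, and model below).
# Step 4: ahead-of-time compile, which computes the physical shardings.
sharded_model, init_sharded_params, get_sharded_batch = parallelize(
    model, init_params=init_params, get_input_batch=get_input_batch
)
# Step 5: use the computed shardings to materialize parameters and inputs.
params = init_sharded_params()
batch = get_sharded_batch()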
def f(param, x):
    # A single dense layer: contract x (batch, model) with param (model, hidden).
    return jnp.einsum("bm,mh->bh", x, param)

def model(params, x):
    # Logical sharding annotations on the input and parameters.
    x = with_sharding_constraint(x, P("batch", None))
    param_1, param_2 = params
    param_1 = with_sharding_constraint(param_1, P("model", None))
    param_2 = with_sharding_constraint(param_2, P(None, "model"))
    x = f(param_1, x)
    x = f(param_2, x)
    return with_sharding_constraint(x, P("batch", None))
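As a quick check of the shapes involved (plain jax.numpy, no sharding yet), f contracts a (batch, model) input with a (model, hidden) parameter to produce a (batch, hidden) output; the names x_check and p_check below are only for illustration.
x_check = jnp.ones((8, 8))        # (batch, model)
p_check = jnp.ones((8, 8))        # (model, hidden)
print(f(p_check, x_check).shape)  # (8, 8) -> (batch, hidden)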
We now modify the model to assign the two calls to f to different submeshes.
devices = np.array(jax.devices())

# A 4x1 task mesh over devices 0-3.
task_mesh = TaskMesh(devices=(0, 1, 2, 3), axis_names=("x", "y"), axis_sizes=(4, 1))

# layer1 runs f on devices 0-3.
layer1 = task(f,
    mesh=task_mesh,
    logical_axes=(
        ("batch", "x"),
        ("model", "x"),
    ))

# layer2 runs f on devices 4-7, using the same mesh shape placed on a different device set.
layer2 = task(f,
    mesh=task_mesh.place((4, 5, 6, 7)),
    logical_axes=(
        ("batch", "x"),
        ("model", "x"),
    ))
def model(params, x):
    x = with_sharding_constraint(x, P("batch", None))
    param_1, param_2 = params
    param_1 = with_sharding_constraint(param_1, P("model", None))
    param_2 = with_sharding_constraint(param_2, P(None, "model"))
    x = layer1(param_1, x)
    x = layer2(param_2, x)
    return with_sharding_constraint(x, P("batch", None))
WARNING:2025-07-17 19:00:31,283:jax._src.xla_bridge:830: Platform 'multimesh' is experimental and not all JAX functionality may be correctly supported!
As before, to run the model, we will need input batches and parameters. We use dummy functions for now:
def get_input_batch():
    return jnp.arange(64).reshape(8,8)

def init_params():
    param1 = jnp.arange(64).reshape(8,8)
    param2 = jnp.arange(64).reshape(8,8)
    return param1, param2
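In a real model you would usually initialize the parameters randomly rather than with arange values; a minimal sketch using standard jax.random (the name init_params_random and the fixed seed are ours):
def init_params_random(seed=0):
    # Draw two 8x8 parameter matrices from a normal distribution.
    k1, k2 = jax.random.split(jax.random.PRNGKey(seed))
    param1 = jax.random.normal(k1, (8, 8))
    param2 = jax.random.normal(k2, (8, 8))
    return param1, param2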
We can now create a sharded model using multimesh.jax.parallelize:
sharded_model, init_sharded_params, get_sharded_batch = parallelize(
    model, init_params=init_params, get_input_batch=get_input_batch
)
2025-07-17 19:00:32.877750: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!
We can then inspect the generated shardings for the parameters.
param_1, param_2 = init_sharded_params()
print("Param 1")
visualize_array_sharding(param_1, use_color=False)
print("Param 2")
visualize_array_sharding(param_2, use_color=False)
Param 1
┌───────────────────────┐
│      MULTIMESH 0      │
├───────────────────────┤
│      MULTIMESH 1      │
├───────────────────────┤
│      MULTIMESH 2      │
├───────────────────────┤
│      MULTIMESH 3      │
└───────────────────────┘
Param 2
┌───────────┬───────────┬───────────┬───────────┐
│           │           │           │           │
│           │           │           │           │
│           │           │           │           │
│           │           │           │           │
│MULTIMESH 4│MULTIMESH 5│MULTIMESH 6│MULTIMESH 7│
│           │           │           │           │
│           │           │           │           │
│           │           │           │           │
│           │           │           │           │
└───────────┴───────────┴───────────┴───────────┘
Param 1 is sharded row-wise across devices 0-3. Param 2 is sharded column-wise across a different set of devices (4-7), showing that layers 1 and 2 will run on separate devices. We can also inspect the sharding of the input batch:
batch = get_sharded_batch()
visualize_array_sharding(batch, use_color=False)
┌───────────────────────┐
│      MULTIMESH 0      │
├───────────────────────┤
│      MULTIMESH 1      │
├───────────────────────┤
│      MULTIMESH 2      │
├───────────────────────┤
│      MULTIMESH 3      │
└───────────────────────┘
Here the input batch is sharded only over devices 0-3, which matches the sharding for the first layer. This makes sense, since the input batch is passed to layer1 before its output is pipelined to layer2.
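The same information is available programmatically. Assuming the values returned by init_sharded_params and get_sharded_batch are ordinary jax.Arrays, each carries a sharding attribute and a devices() method:
# A sketch of programmatic inspection; the exact printed form may vary.
print(param_1.sharding)   # sharded over devices 0-3
print(param_2.sharding)   # sharded over devices 4-7
print(batch.devices())    # the set of devices holding the input batch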
Appendix: input batches sharded over all devices#
As in the SPMD example, a dataset loader may generate sharded input batches, e.g. one shard per device. In this case, the input batch could be sharded over all devices 0-7. The first layer (layer1) expects the input batch to be sharded only across devices 0-3, though. MultiMesh for Jax automatically handles this resharding.
mesh = Mesh(np.array(jax.devices()), ("x",))
input_sharding = NamedSharding(mesh, P("x", None)) # no sharding over model dim
data_loader = jax.jit(lambda: jnp.arange(64).reshape(8,8), out_shardings=input_sharding)
example_batch = data_loader()
visualize_array_sharding(example_batch, use_color=False)
┌───────────────────────┐
│      MULTIMESH 0      │
├───────────────────────┤
│      MULTIMESH 1      │
├───────────────────────┤
│      MULTIMESH 2      │
├───────────────────────┤
│      MULTIMESH 3      │
├───────────────────────┤
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘
We can now compile an MPMD-sharded model that accepts inputs sharded across all devices:
sharded_model, init_sharded_params = parallelize(
    model, init_params=init_params, initial_batch=example_batch
)
2025-07-17 19:00:34.531269: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!
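To close the loop, here is a sketch of how these pieces could be combined, under the assumption that the compiled sharded_model can be called like a jit-compiled function on the loader's output; num_steps is a placeholder.
params = init_sharded_params()
for step in range(num_steps):
    batch = data_loader()               # batch sharded over all devices 0-7
    out = sharded_model(params, batch)  # MultiMesh reshards the batch to devices 0-3 for layer1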