Logical autosharding for SPMD
What you need to know for this tutorial
Basic knowledge of Jax and NumPy
Familiarity with the concepts of tensor sharding and parallelism
Initialize MultiMesh for Jax
The first step in a MultiMesh for Jax program is to initialize the environment. Before any other imports, call jax_plugins.multimesh.init with the arguments that configure MultiMesh. This violates the common Python style rule that top-level imports come before all other code, but the MultiMesh parameters must be configured before any Jax functions are imported. MultiMesh can also be configured ahead of time with environment variables, but the programmatic method shown here is generally cleaner and easier.
from jax_plugins.multimesh import init
init(cpus=8)
Import Jax packages
Once MultiMesh for Jax has been initialized, the standard set of Jax imports can follow.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.debug import visualize_array_sharding
from multimesh.jax import with_sharding_constraint, parallelize, task, TaskMesh
Regular Jax: create layers with explicit SPMD sharding annotations
In Jax, the usual workflow is:
1. Writing numpy-like code that implements the model
2. Initializing parameters and input batches with explicit shardings
3. jit-compiling the function with the explicit shardings
4. Executing the jit-compiled function with the initialized parameters
We start with step 1 and implement a basic model with two fully-connected layers:
def f(param, x):
    # One dense layer without bias: (batch, m) x (m, h) -> (batch, h)
    return jnp.einsum("bm,mh->bh", x, param)
def model(params, x):
    param_1, param_2 = params
    x = f(param_1, x)
    return f(param_2, x)
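The einsum string "bm,mh->bh" sums over the shared index m, so f is a bias-free dense layer, x @ param, and model computes x @ param_1 @ param_2. A quick check of that equivalence, written with NumPy (whose einsum accepts the same subscript notation) so that it runs without any devices configured:

```python
import numpy as np

x = np.arange(64).reshape(8, 8)   # input batch: (batch b, features m)
p1 = np.arange(64).reshape(8, 8)  # first layer: (m, h)
p2 = np.arange(64).reshape(8, 8)  # second layer: (h, h)

# "bm,mh->bh" contracts the shared index m, i.e. a matrix product.
layer_1 = np.einsum("bm,mh->bh", x, p1)
assert np.array_equal(layer_1, x @ p1)

# Applying the layer twice reproduces the two-layer model.
out = np.einsum("bm,mh->bh", layer_1, p2)
assert np.array_equal(out, x @ p1 @ p2)
print(out.shape)  # (8, 8)
```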
To run the model, we will need input batches and parameters. We use dummy functions for now:
def get_input_batch():
    return jnp.arange(64).reshape(8,8)
def init_params():
    param1 = jnp.arange(64).reshape(8,8)
    param2 = jnp.arange(64).reshape(8,8)
    return param1, param2
We use NamedSharding objects to define how the parameters and input batches are sharded. Both are split along their first dimension over the mesh axis "x": for the inputs that is the batch dimension, and for the parameters it is the input-feature dimension m. We then jit-compile the init functions with these output shardings and run them to generate the sharded inputs and parameters.
mesh = Mesh(np.array(jax.devices()), ("x",))
input_sharding = NamedSharding(mesh, P("x", None))  # batch dim split over "x"; model dim not sharded
param_sharding = NamedSharding(mesh, P("x", None))  # dim 0 split over "x"; dim 1 not sharded
get_sharded_batch = jax.jit(get_input_batch, out_shardings=input_sharding)
get_sharded_params = jax.jit(init_params, out_shardings=(param_sharding, param_sharding))
batch = get_sharded_batch()
params = get_sharded_params()
WARNING:2025-07-17 19:01:12,972:jax._src.xla_bridge:830: Platform 'multimesh' is experimental and not all JAX functionality may be correctly supported!
Inspecting the shardings shows a NamedSharding with PartitionSpec('x', None): each (8, 8) array is split along its first dimension into eight (1, 8) row blocks, one per device on mesh axis x.
def inspect(x):
    print(x.shape, x.sharding)
print("Input")
jax.tree.map(inspect, batch)
print("Params")
jax.tree.map(inspect, params)
visualize_array_sharding(batch)
Input
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x', None), memory_kind=unpinned_host)
Params
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x', None), memory_kind=unpinned_host)
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x', None), memory_kind=unpinned_host)
┌───────────────────────┐
│      MULTIMESH 0      │
├───────────────────────┤
│      MULTIMESH 1      │
├───────────────────────┤
│      MULTIMESH 2      │
├───────────────────────┤
│      MULTIMESH 3      │
├───────────────────────┤
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘
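Each row of the visualization is the block of the batch that one device holds. A minimal plain-Python model of that layout, assuming a 1-D mesh whose devices are numbered in mesh order and array dimensions that divide evenly by the mesh size; device_blocks is a hypothetical helper for illustration, not a JAX function. It returns half-open (start, stop) index ranges per array dimension, and it also covers the fully replicated spec PartitionSpec() and a column split such as (None, "x").

```python
def device_blocks(global_shape, spec, axis_name, num_devices):
    """Hypothetical helper: per-device index ranges on a 1-D mesh.

    Only models a single mesh axis and evenly divisible dimensions.
    """
    # Missing trailing spec entries mean "not sharded" (like None).
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    blocks = []
    for d in range(num_devices):
        block = []
        for size, name in zip(global_shape, spec):
            if name == axis_name:
                if size % num_devices:
                    raise ValueError("uneven sharding is not modeled here")
                step = size // num_devices
                block.append((d * step, (d + 1) * step))  # this device's slice
            else:
                block.append((0, size))  # unsharded dim: every device has all of it
        blocks.append(tuple(block))
    return blocks


row_split = device_blocks((8, 8), ("x", None), "x", 8)
print(row_split[0])  # ((0, 1), (0, 8)): device 0 holds row 0, all columns
print(row_split[7])  # ((7, 8), (0, 8))

replicated = device_blocks((8, 8), (), "x", 8)
print(replicated[3])  # ((0, 8), (0, 8)): a full copy on every device
```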
We can now jit-compile a sharded model and pass it sharded inputs and parameters.
sharded_model = jax.jit(model, in_shardings=((param_sharding, param_sharding), input_sharding))
output = sharded_model(params, batch)
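Why does splitting the batch dimension parallelize this model? Each output row depends only on the matching input row, so devices can compute their own rows independently, provided each can read the whole parameter. Because param_sharding also splits the parameters by rows, some cross-device communication is required for that; how XLA actually schedules it is not shown here. A NumPy simulation of the row-parallel computation with eight simulated devices, assuming the full parameter is already available on each:

```python
import numpy as np

num_devices = 8
x = np.arange(64).reshape(8, 8)      # input batch, split by rows below
param = np.arange(64).reshape(8, 8)  # full parameter, as if already gathered

# Each simulated device owns one contiguous row block of the batch,
# matching PartitionSpec("x", None) on an 8-device mesh.
x_blocks = np.split(x, num_devices, axis=0)

# Row-parallel layer: each device multiplies its rows by the whole parameter.
out_blocks = [block @ param for block in x_blocks]

# Concatenating the per-device results reproduces the unsharded computation.
out = np.concatenate(out_blocks, axis=0)
assert np.array_equal(out, x @ param)
```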
MultiMesh for Jax follows a model of logical autosharding, which slightly changes the usual Jax workflow. The steps are now:
1. Writing numpy-like code that implements the model
2. Adding logical sharding annotations to the parameters (and activations) in the model
3. Compiling the model ahead of time, which computes the physical shardings
4. Using the computed physical shardings to initialize parameters and input batches
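We now modify the model code above to include logical sharding constraints. The axis names "batch" and "model" are logical names: they describe the role of a dimension and are not yet bound to any physical mesh axis.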
def f(param, x):
    return jnp.einsum("bm,mh->bh", x, param)
def model(params, x):
    x = with_sharding_constraint(x, P("batch", None))
    param_1, param_2 = params
    param_1 = with_sharding_constraint(param_1, P("model", None))
    param_2 = with_sharding_constraint(param_2, P(None, "model"))
    x = f(param_1, x)
    x = f(param_2, x)
    return with_sharding_constraint(x, P("batch", None))
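MultiMesh for Jax provides a helper function to simplify this workflow. The multimesh.jax.parallelize function takes the model together with the init_params and get_input_batch functions. It runs an ahead-of-time compilation with autosharding and returns three functions, which produce arrays with the derived shardings: the compiled model, a parameter initializer, and a batch generator (sharded_model, init_sharded_params and get_sharded_batch below).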
sharded_model, init_sharded_params, get_sharded_batch = parallelize(
    model, init_params=init_params, get_input_batch=get_input_batch
)
2025-07-17 19:01:15.646895: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!
We can now generate an initial input batch and initial parameters and inspect the sharding.
batch = get_sharded_batch()
params = init_sharded_params()
print("Input")
jax.tree.map(inspect, batch)
print("Params")
jax.tree.map(inspect, params)
visualize_array_sharding(batch)
Input
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(), memory_kind=unpinned_host)
Params
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(), memory_kind=unpinned_host)
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(), memory_kind=unpinned_host)
┌─────────────────────────┐
│                         │
│                         │
│                         │
│                         │
│MULTIMESH 0,1,2,3,4,5,6,7│
│                         │
│                         │
│                         │
│                         │
└─────────────────────────┘
Both the debug output and the visualization show a replicated sharding: PartitionSpec() places a full copy of each array on all eight CPUs. In the previous compilation there was no logical task context defining how the logical axis names map to physical mesh axes, so the arrays ended up replicated. multimesh.jax provides the task transformation to create this context. We define a 1-D mesh of eight devices with axis x and map both the logical batch and model axes to the device axis x.
task_mesh = TaskMesh(devices=range(0,8), axis_names=("x",), axis_sizes=(8,))
mm_model = task(model, mesh=task_mesh,
                logical_axes=(
                    ("batch", "x"),
                    ("model", "x"),
                ))
sharded_model, init_sharded_params, get_sharded_batch = parallelize(
    mm_model, init_params=init_params, get_input_batch=get_input_batch
)
2025-07-17 19:01:16.787629: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!
We can again generate example input batches and parameters and inspect their shardings. This time, logical autosharding produces 1-D shardings over the mesh axis x.
batch = get_sharded_batch()
params = init_sharded_params()
print("Input")
jax.tree.map(inspect, batch)
print("Params")
jax.tree.map(inspect, params)
visualize_array_sharding(batch)
Input
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x',), memory_kind=unpinned_host)
Params
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x',), memory_kind=unpinned_host)
(8, 8) NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(None, 'x'), memory_kind=unpinned_host)
┌───────────────────────┐
│      MULTIMESH 0      │
├───────────────────────┤
│      MULTIMESH 1      │
├───────────────────────┤
│      MULTIMESH 2      │
├───────────────────────┤
│      MULTIMESH 3      │
├───────────────────────┤
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘
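The printed specs line up with the logical annotations in model once each logical name is replaced by its mesh axis: P("batch", None) and P("model", None) become ('x',), and P(None, "model") becomes (None, 'x'). A sketch of that substitution in plain Python; to_physical is a hypothetical helper, not part of multimesh.jax, and it models only the name lookup, not the compiler's autosharding pass.

```python
def to_physical(logical_spec, rules):
    """Hypothetical helper: replace logical axis names with mesh axis names.

    `rules` is a sequence of (logical_name, mesh_axis) pairs, in the same
    form as the logical_axes argument passed to task above.
    """
    mapping = dict(rules)
    physical = [None if name is None else mapping[name] for name in logical_spec]
    # Trailing None entries mean "unsharded"; the printed form drops them.
    while physical and physical[-1] is None:
        physical.pop()
    return tuple(physical)


rules = (("batch", "x"), ("model", "x"))
print(to_physical(("batch", None), rules))  # ('x',)
print(to_physical(("model", None), rules))  # ('x',)
print(to_physical((None, "model"), rules))  # (None, 'x')
```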
The example here shows a single SPMD task. Further examples show how this can be extended to MPMD tasks across different submeshes.
Appendix
Using an example batch instead of a batch generator function
Dataset loaders may produce input batches that are already sharded. In this case, the example above needs only a small change: instead of a batch generator function, pass an example batch to parallelize through the initial_batch argument. parallelize then returns only two functions: the sharded model and the parameter initializer.
input_sharding = NamedSharding(mesh, P("x", None)) # no sharding over model dim
data_loader = jax.jit(lambda: jnp.arange(64).reshape(8,8), out_shardings=input_sharding)
example_batch = data_loader()
sharded_model, init_sharded_params = parallelize(
    mm_model, init_params=init_params, initial_batch=example_batch
)
2025-05-01 21:20:21.346157: W ./xla/service/hlo_module_config.h:193] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!
Implementation details of parallelize
Jax provides several utilities for abstract evaluation of functions, which compute output shapes and dtypes without running the computation. parallelize uses this, for example, to determine the PyTree of parameter shapes:
param_shapes = jax.eval_shape(init_params)  # PyTree of shapes/dtypes; nothing is computed
abstract_params = jax.tree.map(
    lambda x: jax.core.ShapedArray(x.shape, x.dtype), param_shapes
)
Instead of explicit input/output shardings, the shardings are set to AUTO:
# AUTO is JAX's placeholder sharding that asks the compiler to choose the sharding.
param_shardings = jax.tree.map(lambda x: AUTO(mesh), param_shapes)
Once the abstract shapes are known, the function is compiled in a multimesh.jax autosharding context:
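import multimesh.jax

# batch_sharding, out_shardings and initial_batch are set up elsewhere
# inside parallelize (not shown in this excerpt).
with multimesh.jax.autoshard(True):
    jit_f = jax.jit(model,
                    in_shardings=(param_shardings, batch_sharding),
                    out_shardings=out_shardings)
    compiled = jit_f.lower(abstract_params, initial_batch).compile()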
The derived input and output shardings are then available from the compiled object, e.g. compiled.input_shardings and compiled.output_shardings.
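The same ahead-of-time pattern can be reproduced in plain JAX, which shows where compiled.input_shardings and compiled.output_shardings come from. A minimal sketch, assuming a one-device mesh so that it runs on any machine and using explicit NamedSharding objects in place of AUTO, so no autosharding takes place; model and init_params mirror the tutorial's. jax.eval_shape returns jax.ShapeDtypeStruct leaves, which lower accepts directly.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_params():
    return jnp.arange(64).reshape(8, 8), jnp.arange(64).reshape(8, 8)


def model(params, x):
    param_1, param_2 = params
    x = jnp.einsum("bm,mh->bh", x, param_1)
    return jnp.einsum("bm,mh->bh", x, param_2)


# One-device mesh so the sketch runs anywhere; axis name as in the tutorial.
mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
sharding = NamedSharding(mesh, P("x", None))

# Abstract evaluation: shapes and dtypes only, no computation runs.
abstract_params = jax.eval_shape(init_params)
abstract_batch = jax.ShapeDtypeStruct((8, 8), abstract_params[0].dtype)

# Explicit shardings stand in for AUTO in this sketch.
jit_model = jax.jit(
    model,
    in_shardings=((sharding, sharding), sharding),
    out_shardings=sharding,
)
compiled = jit_model.lower(abstract_params, abstract_batch).compile()

# The shardings come back as pytrees; jax.tree.leaves flattens them.
print(jax.tree.leaves(compiled.input_shardings))
print(jax.tree.leaves(compiled.output_shardings))
```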