MLP pipeline example with task transforms#
What you need to know for this tutorial#
A basic knowledge of Jax and numpy
Familiarity with the concepts of tensor sharding and parallelism
Familiarity with basic neural network concepts like dense layers, activation functions, and backprop
The previous tutorials on SPMD/MPMD autosharding with MultiMesh for Jax
Optional, but helpful, to know#
Flax package for building composable, trainable models in Jax
Optax package with different optimizers
Initialize MultiMesh for Jax#
As in the previous examples, the first step in a MultiMesh for Jax program is to initialize the environment.
from jax_plugins.multimesh import init
init(cpus=8)
Import Jax packages#
Once MultiMesh for Jax has been initialized, the standard set of Jax imports can be done. In this case, we are also importing:
Flax, which is used to define modules with parameters
Optax, which provides optimizers for updating parameters
For more details on Flax and Optax, consult the documentation for those packages; for this tutorial, their usage should be relatively clear. Note: functions that were previously imported from multimesh.jax must now be imported from multimesh.jax.flax. Jax requires pure functions, while Flax uses an "object-oriented" paradigm of modules with state, so Flax often requires special wrappers around standard Jax functions that interoperate with the Flax-to-Jax functionalization process.
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.sharding import Mesh, PartitionSpec as P
from multimesh.jax.flax import (
    parallelize_step,
    shard_axes,
    task,
    with_sharding_constraint,
)
from multimesh.jax import TaskMesh
from flax import linen as nn
Create single layer with sharding annotations#
We can now proceed to define a model called Pipeline. The building block for this model will be a basic MLP layer: an nn.Dense followed by nn.relu. The dense layer implicitly defines a parameter tensor. Two key observations:
A sharding constraint is defined for the input activations x that names the tensor axes. The input activations are 1D, with the axis named batch.
Sharding is also defined for the 2D tensor in the dense layer. We give these axes the names data and model.
class MLPBlock(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        # Name the 1D activation axis "batch" via a sharding constraint.
        x = with_sharding_constraint(x, P("batch"))
        # Give the dense kernel's two axes the logical names "data" and "model".
        x = nn.Dense(features=self.features, kernel_init=shard_axes("data", "model"))(x)
        x = nn.relu(x)
        return x
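For readers coming from stock Flax: shard_axes plays a role loosely analogous to flax.linen.with_partitioning, which attaches logical axis names to a parameter initializer. The comparison below is plain Flax, shown only for orientation; it is not MultiMesh code, and the exact semantics of shard_axes may differ.

# Plain-Flax comparison (illustrative only; not the MultiMesh API).
# with_partitioning tags the kernel initializer's output with the logical
# axis names ("data", "model") so a partitioner can later bind them to a mesh.
from flax import linen as nn

kernel_init = nn.with_partitioning(
    nn.initializers.lecun_normal(), ("data", "model")
)
dense = nn.Dense(features=4, kernel_init=kernel_init)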
Create full model with pipelined layers#
Using the MLP building block, we can build up a pipeline-parallel model. Each layer of the pipeline is assigned to a different submesh, which is computed in the submesh method: for each new layer, the submesh advances to the next block of devices in the grid. An important abstraction here is the axis_map used in the Pipeline.
MultiMesh for Jax emphasizes logical autosharding - using descriptive names rather than positional or hardware-specific names. This follows the philosophy used in Levanter outlined in this article - and also followed by maxtext and T5x and their use of logical sharding rules. Flax code is written with logical names (batch, data, model), which must then be bound to physical device axes (x and y) to fully define the parallelism. This creates a separation between model specification (logical code written with named tensor axes) and model mapping (physical sharding of axes across a device mesh).
In this particular case, we map the batch and data logical dimensions to the physical x axis and give each submesh a 4x1 shape; a short sketch after the Pipeline definition below makes this mapping concrete. MultiMesh for Jax therefore provides the following capabilities:
The ability to assign computations to a submesh within a larger computation. (Standard Jax enforces an SPMD requirement where tensors must be sharded across all devices and every device must participate.)
Logical sharding rules within a submesh. This provides similar functionality to T5x and Levanter, but allows arbitrary resharding.
class Pipeline(nn.Module):
    n_layers: int = 2
    shape: tuple[int, ...] = (4, 1)
    axes = ("x", "y")
    axis_map = (
        ("batch", "x"),
        ("data", "x"),
        ("model", "y"),
    )

    def submesh(self, layer: int) -> Mesh:
        # Each layer is placed on the next contiguous block of devices.
        submesh_size = np.prod(self.shape)
        off = layer * submesh_size
        return TaskMesh(
            devices=range(off, off + submesh_size),
            axis_names=self.axes,
            axis_sizes=self.shape,
        )

    @nn.compact
    def __call__(self, x):
        for i in range(self.n_layers):
            # Wrap each MLP block as a task bound to its own submesh.
            layer = task(
                MLPBlock,
                mesh=self.submesh(i),
                name=f"layer_{i}",
                logical_axes=self.axis_map,
            )()
            x = layer(x)
        return (x * x).sum()
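To make the axis mapping and device assignment concrete, here is a small sketch in plain Python. It reimplements the offset arithmetic and the logical-to-physical lookup by hand purely for illustration; it does not call into MultiMesh internals, and the library may resolve these differently.

# Illustrative only: reproduce the per-layer device offsets and the
# logical-to-physical axis lookup by hand (not MultiMesh internals).
import numpy as np
from jax.sharding import PartitionSpec as P

shape = (4, 1)
submesh_size = int(np.prod(shape))
for layer in range(2):
    off = layer * submesh_size
    print(f"layer_{layer}: devices {list(range(off, off + submesh_size))}")
# layer_0: devices [0, 1, 2, 3]
# layer_1: devices [4, 5, 6, 7]

axis_map = {"batch": "x", "data": "x", "model": "y"}
# A kernel annotated with logical axes ("data", "model") lands on the physical spec:
print(P(*(axis_map[a] for a in ("data", "model"))))  # PartitionSpec('x', 'y')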
Initialize parameters and create parallel training function#
So far the code looks almost the same as standard Flax code with a few extra annotations. The full power of MultiMesh for Jax comes from its auto-sharding compiler, which reads all the submesh annotations and transforms the global computation into a series of submesh tasks. Managing Jax training state and shardings can be difficult when directly calling jax.jit. The MultiMesh Flax wrappers provide parallelize_step for training, which allows users to generate MPMD auto-sharded models for Flax modules, similar to the previous MPMD example. The user simply passes in the model, a standard optax optimizer, and an example input batch; MultiMesh for Jax does all the work to derive the shardings, compile the model, and initialize the sharded training state. parallelize_step returns the four objects necessary to run a training loop:
A step_fn for computing gradients and updating parameters
The randomly initialized sharded parameters
The mesh context to use for training steps
A function for preprocessing input batches located in host memory into device arrays
This setup is slightly different from the multimesh.jax.parallelize function, but is more idiomatic for Flax and training.
x = jnp.arange(16)
opt = optax.sgd(learning_rate=0.02)
mesh = Mesh(np.array(jax.devices()).reshape(8, 1), ("x", "y"))
step_fn, sharded_train_state, mesh, prepare_batch = parallelize_step(
    model=Pipeline(), optimizer=opt, mesh=mesh, local_batch=x)
2025-07-17 18:56:20.978883: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!
Run training steps#
Now that we have our initial training state (parameters + optimizer state), we can run training steps to compute the loss and update the parameters.
with mesh:
    batch = prepare_batch(x)
    loss, sharded_train_state = step_fn(sharded_train_state, batch)
    print("Loss = ", loss)
Loss = 64.85133
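The single step above extends naturally to a loop. A minimal sketch, re-feeding the same host batch on every iteration purely for illustration (a real run would draw fresh batches from a data pipeline):

# Minimal training-loop sketch: reuses the same batch on every step.
with mesh:
    for step in range(5):
        batch = prepare_batch(x)
        loss, sharded_train_state = step_fn(sharded_train_state, batch)
        print(f"step {step}: loss = {loss}")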
Deep dive: MPMD sharding specs#
To get a glimpse of the underlying details, we can inspect the sharding specs of the training state. There are 5 tensors:
A single scalar parameter that is not partitioned (empty PartitionSpec()); this is the learning rate (the optimizer state for SGD)
A (4,) tensor replicated across devices 0-3, corresponding to the first layer bias
A (16,4) tensor fully sharded across devices 0-3, corresponding to the fully connected weights in the first layer
A (4,) tensor replicated across devices 4-7, corresponding to the second layer bias
A (4,4) tensor fully sharded across devices 4-7, corresponding to the fully connected weights in the second layer
def inspect(x):
    print(x.shape)
    if len(x.shape) > 0:
        # Non-scalar tensors: show which devices hold which slice.
        jax.debug.visualize_array_sharding(x, use_color=False)
    else:
        # Scalars cannot be visualized; print the sharding object directly.
        print("REPLICATED", x.sharding)

jax.tree.map(inspect, sharded_train_state)
pass  # suppress printing the returned pytree
()
REPLICATED NamedSharding(mesh=Mesh('x': 8, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=unpinned_host)
(4,)
(16, 4)
(4,)
(4, 4)
┌─────────────────┐
│MULTIMESH 0,1,2,3│
└─────────────────┘
┌───────────┐
│MULTIMESH 0│
├───────────┤
│MULTIMESH 1│
├───────────┤
│MULTIMESH 2│
├───────────┤
│MULTIMESH 3│
└───────────┘
┌─────────────────┐
│MULTIMESH 4,5,6,7│
└─────────────────┘
┌───────────────────────┐
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘
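If you prefer raw PartitionSpecs over the ASCII visualization, standard Jax also exposes each array's sharding directly. A minimal sketch, assuming every leaf of the training state carries a NamedSharding with a .spec attribute (as the replicated scalar above does):

# Pull the PartitionSpec out of every leaf in the training state
# (assumes each leaf's sharding is a NamedSharding exposing .spec).
specs = jax.tree.map(lambda t: t.sharding.spec, sharded_train_state)
print(specs)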