MLP pipeline example with task transforms#

What you need to know for this tutorial#

  1. Basic knowledge of Jax and NumPy

  2. Familiarity with the concepts of tensor sharding and parallelism

  3. Familiarity with basic neural network concepts like dense layers, activation functions, and backprop

  4. The previous tutorials on SPMD/MPMD autosharding with MultiMesh for Jax

Optional, but helpful, to know#

  1. Flax package for building composable, trainable models in Jax

  2. Optax package, which provides a variety of optimizers

Initialize MultiMesh for Jax#

As in the previous examples, the first step in a MultiMesh for Jax program is to initialize the environment.

from jax_plugins.multimesh import init

init(cpus=8)

Import Jax packages#

Once MultiMesh for Jax has been initialized, the standard set of Jax imports can be made. In this case, we are also importing:

  1. Flax, which is used to define modules with parameters

  2. Optax, which provides optimizers for updating parameters

For more details on Flax and Optax, consult the documentation for those packages; their usage in this tutorial should be relatively clear from context. Note that functions previously imported from multimesh.jax must now be imported from multimesh.jax.flax. Jax requires pure functions, while Flax uses an “object-oriented” paradigm of stateful modules, so Flax often requires special wrappers around standard Jax functions that interoperate with the Flax-to-Jax functionalization process.

import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.sharding import Mesh, PartitionSpec as P
from multimesh.jax.flax import (
    parallelize_step, 
    shard_axes, 
    task, 
    with_sharding_constraint
)
from multimesh.jax import TaskMesh
from flax import linen as nn

Create single layer with sharding annotations#

We can now proceed to define a model called Pipeline. The building block for this model is a basic MLP layer: an nn.Dense followed by nn.relu. The dense layer implicitly defines its parameter tensors (a kernel and a bias). Two key observations:

  1. A sharding constraint on the input activations x names the tensor axes. The input activations are 1D, with the single axis named batch.

  2. Sharding is also defined for the 2D kernel tensor in the dense layer, whose axes we name data and model.

class MLPBlock(nn.Module):
    features: int = 4
    @nn.compact
    def __call__(self, x):
        x = with_sharding_constraint(x, P("batch"))
        x = nn.Dense(features=self.features, kernel_init=shard_axes("data", "model"))(x)
        x = nn.relu(x)
        return x
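
As a quick point of reference, the parameter shapes this block creates can be checked with plain Flax, stripping out the MultiMesh annotations (a minimal sketch, not part of the tutorial's pipeline). For the 16-element input used later in this tutorial, the dense layer creates a (16, 4) kernel and a (4,) bias, matching the tensors whose shardings are inspected in the deep dive at the end.

import jax
import jax.numpy as jnp
from flax import linen as nn

# Plain-Flax copy of MLPBlock with the MultiMesh sharding annotations removed,
# used here only to show the parameter shapes the dense layer creates.
class PlainMLPBlock(nn.Module):
    features: int = 4
    @nn.compact
    def __call__(self, x):
        return nn.relu(nn.Dense(features=self.features)(x))

params = PlainMLPBlock().init(jax.random.PRNGKey(0), jnp.zeros((16,)))
print(jax.tree.map(jnp.shape, params))
# -> kernel (16, 4), bias (4,)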

Create full model with pipelined layers#

Using the MLP building block, we can build up a pipeline-parallel model. Each layer of the pipeline is assigned to a different submesh, computed in the submesh method: the device offset advances by the submesh size for each new layer, so consecutive layers occupy consecutive slices of the device grid.

An important abstraction here is the axis_map used in the Pipeline. MultiMesh for Jax emphasizes logical autosharding: using descriptive names rather than positional or hardware-specific names. This follows the philosophy used in Levanter, outlined in this article, and also followed by maxtext and T5x through their logical sharding rules. Flax code is written with logical names (batch, data, model), which must then be bound to physical device axes (x and y) to fully define the parallelism. This creates a separation between model specification (logical code written with named tensor axes) and model mapping (physical sharding of those axes across a device mesh).

In this particular case, we map the batch and data logical dimensions to the physical x axis and give each submesh a 4x1 shape. MultiMesh for Jax therefore provides the following capabilities:

  1. The ability to assign computations to a submesh within a larger computation. Standard Jax, by contrast, enforces an SPMD requirement in which tensors must be sharded across all devices and every device must participate in every computation.

  2. Logical sharding rules within a submesh. This provides similar functionality to T5x and Levanter, but allows arbitrary resharding.

class Pipeline(nn.Module):
    n_layers: int = 2
    shape: tuple[int, ...] = (4, 1)
    axes = ("x", "y")
    axis_map = (
      ("batch", "x"),
      ("data", "x"),
      ("model", "y"),
    )

    def submesh(self, layer: int) -> Mesh:
        submesh_size = np.prod(self.shape)
        off = layer * submesh_size
        return TaskMesh(
            devices=range(off, off + submesh_size),
            axis_names=self.axes,
            axis_sizes=self.shape,
        )

    @nn.compact
    def __call__(self, x):
        for i in range(self.n_layers):
            layer = task(
                MLPBlock,
                mesh=self.submesh(i),
                name=f"layer_{i}",
                logical_axes=self.axis_map,
            )()
            x = layer(x)

        return (x*x).sum()
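
To make submesh and axis_map concrete, here is a small worked example for the default configuration (shape=(4, 1), n_layers=2). The second half resolves logical names to physical axes using standard Flax's logical-axis helper, shown purely as a point of comparison; it is not part of the MultiMesh for Jax API.

import numpy as np
from flax import linen as nn

# Device ranges produced by Pipeline.submesh(): 4 devices per layer, so
# layer 0 runs on devices 0-3 and layer 1 runs on devices 4-7.
shape = (4, 1)
submesh_size = int(np.prod(shape))
for layer in range(2):
    off = layer * submesh_size
    print(f"layer {layer}: devices {list(range(off, off + submesh_size))}")
# layer 0: devices [0, 1, 2, 3]
# layer 1: devices [4, 5, 6, 7]

# How the logical names bind to physical mesh axes under axis_map,
# illustrated with Flax's rule-based helper for comparison.
axis_map = (("batch", "x"), ("data", "x"), ("model", "y"))
print(nn.logical_to_mesh_axes(("batch",), axis_map))         # batch -> 'x'
print(nn.logical_to_mesh_axes(("data", "model"), axis_map))  # (data, model) -> ('x', 'y')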

Initialize parameters and create parallel training function#

So far, the code looks almost the same as standard Flax code with a few extra annotations. The full power of MultiMesh for Jax comes from its auto-sharding compiler, which reads all the submesh annotations and transforms the global computation into a series of submesh tasks.

Managing Jax training state and shardings can be difficult when calling jax.jit directly. The MultiMesh Flax wrappers provide parallelize_step, which generates MPMD auto-sharded training steps for Flax modules, similar to the previous MPMD example. The user simply passes in the model, a standard Optax optimizer, and an example input batch; MultiMesh for Jax does the work of deriving the shardings, compiling the model, and initializing the sharded training state. parallelize_step returns the four objects necessary to run a training loop:

  1. A step_fn for computing gradients and updating parameters

  2. The randomly initialized sharded parameters

  3. The mesh context to use for training steps

  4. A function that converts input batches located in host memory into device arrays

This setup is slightly different from the multimesh.jax.parallelize function, but is more idiomatic for Flax and training loops.

x = jnp.arange(16)
opt = optax.sgd(learning_rate=0.02)
mesh = Mesh(np.array(jax.devices()).reshape(8, 1), ("x", "y"))
step_fn, sharded_train_state, mesh, prepare_batch = parallelize_step(
    model=Pipeline(), optimizer=opt, mesh=mesh, local_batch=x)
2025-07-17 18:56:20.978883: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!

Run training steps#

Now that we have our initial training state (parameters + optimizer state), we can run training steps to compute the loss and update the parameters.

with mesh:
    batch = prepare_batch(x)
    loss, sharded_train_state = step_fn(sharded_train_state, batch)
    print("Loss = ", loss)
Loss =  64.85133
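
A full training run simply repeats this step. Below is a minimal sketch of such a loop, assuming some host-side iterable of input batches; data_iter is hypothetical and not part of the MultiMesh for Jax API.

# Hypothetical loop over host batches using the objects returned by
# parallelize_step; data_iter stands in for any iterable of host arrays.
with mesh:
    for step, host_batch in enumerate(data_iter):
        batch = prepare_batch(host_batch)   # host arrays -> sharded device arrays
        loss, sharded_train_state = step_fn(sharded_train_state, batch)
        if step % 10 == 0:
            print(f"step {step}: loss = {loss}")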

Deep dive: MPMD sharding specs#

To get a glimpse of the underlying details, we can inspect the sharding specs of the training state. There are 5 tensors:

  1. A single scalar parameter that is not partitioned (empty PartitionSpec()). This is the learning rate (optimizer state for SGD).

  2. A (4,) tensor replicated across devices 0-3 corresponding to the first layer bias

  3. A (16,4) tensor fully sharded across devices 0-3 corresponding to the fully connected weights in the first layer

  4. A (4,) tensor replicated across devices 4-7 corresponding to the second layer bias

  5. A (4,4) tensor fully sharded across devices 4-7 corresponding to the fully connected weights in the second layer

def inspect(x):
    print(x.shape)
    if len(x.shape) > 0:
        jax.debug.visualize_array_sharding(x, use_color=False)
    else:
        print("REPLICATED", x.sharding)
jax.tree.map(inspect, sharded_train_state)
pass
()
REPLICATED NamedSharding(mesh=Mesh('x': 8, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=unpinned_host)
(4,)
(16, 4)
(4,)
(4, 4)
┌─────────────────┐
│MULTIMESH 0,1,2,3│
└─────────────────┘
┌───────────┐
│MULTIMESH 0│
├───────────┤
│MULTIMESH 1│
├───────────┤
│MULTIMESH 2│
├───────────┤
│MULTIMESH 3│
└───────────┘
┌─────────────────┐
│MULTIMESH 4,5,6,7│
└─────────────────┘
┌───────────────────────┐
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘
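
For scripted checks, the same placement information can be read off the arrays without the ASCII visualization. The sketch below uses only standard Jax array attributes: every jax.Array carries a sharding, and its device_set lists the devices the array occupies.

# Print each training-state leaf's shape together with the ids of the
# devices its sharding spans (e.g. devices 0-3 for the first layer's tensors).
for leaf in jax.tree_util.tree_leaves(sharded_train_state):
    device_ids = sorted(d.id for d in leaf.sharding.device_set)
    print(leaf.shape, device_ids)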