MLP pipeline with external task registration

What you need to know for this tutorial

  1. Basic knowledge of Jax and NumPy

  2. Familiarity with the concepts of tensor sharding and parallelism

  3. Familiarity with basic neural network concepts like dense layers, activation functions, and backprop

  4. The previous tutorials on SPMD/MPMD autosharding with MultiMesh for Jax and Flax

Optional, but helpful, to know

  1. Flax package for building composable, trainable models in Jax

  2. Optax package with different optimizers

Initialize MultiMesh for Jax

As in the previous examples, the first step in a MultiMesh for Jax program is to initialize the environment.

from jax_plugins.multimesh import init

init(cpus=8)

Import Jax packages

Once MultiMesh for Jax has been initialized, the standard Jax imports can follow. In this case, we also import:

  1. Flax, which is used to define modules with parameters

  2. Optax, which provides optimizers for updating parameters

For more details on Flax and Optax, consult the documentation for these packages. Their usage in this tutorial should be relatively clear.

import jax
import jax.numpy as jnp
import numpy as np
import optax
from multimesh.jax import clear_tasks, register_task, Task, TaskMesh
from jax.sharding import Mesh, PartitionSpec as P
from jax.lax import with_sharding_constraint
from multimesh.jax.flax import parallelize_step
from flax import linen as nn
from functools import partial

Create single layer with sharding annotations

We can now define a model called Pipeline. Its building block is a basic MLP layer: an nn.Dense followed by nn.relu. The dense layer implicitly defines its parameters, a 2D kernel and a 1D bias. Two key observations:

  1. A sharding constraint names the axes of the input activations x. The input activations are 1D, with the single axis named batch.

  2. Sharding is also defined for the 2D kernel of the dense layer, whose axes are named data and model. The bias carries no partitioning annotation.

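# Xavier-normal initializer whose output is annotated with logical axis names
init_shard = partial(nn.with_partitioning, fn=nn.initializers.xavier_normal())

class MLPBlock(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        # Name the single activation axis "batch"
        x = with_sharding_constraint(x, P("batch"))
        # Kernel axes are named ("data", "model"); the bias is left unannotated
        x = nn.Dense(features=self.features, kernel_init=init_shard(names=("data", "model")))(x)
        x = nn.relu(x)
        return x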

Create full model with pipelined layers

In contrast to the previous Flax example, which used explicit task transformations, we use implicit task decomposition here to define an MPMD pipeline, so there are no explicit task definitions. In fact, neither MLPBlock nor Pipeline uses any multimesh.jax functions at all.

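class Pipeline(nn.Module):
    n_layers: int = 2

    @nn.compact
    def __call__(self, x):
        # Explicit, unique names (layer_0, layer_1, ...) are later used as task scopes
        for i in range(self.n_layers):
            layer = MLPBlock(name=f"layer_{i}")
            x = layer(x)

        # The sum of squares of the final activations serves as a scalar loss
        return (x*x).sum()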
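To see where the parameter shapes reported later in this tutorial come from, here is a small pure-Python sketch of the shapes Pipeline creates. It assumes Flax's nn.Dense convention of a kernel with shape (in_features, features) plus a bias of shape (features,), and the 16-element input used below; the helper name `pipeline_param_shapes` is hypothetical.

```python
def pipeline_param_shapes(in_features: int, features: int = 4, n_layers: int = 2):
    """Return {layer_name: {"kernel": shape, "bias": shape}} for the Pipeline model.

    Hypothetical helper mirroring nn.Dense: the kernel has shape
    (in_features, features) and the bias has shape (features,).
    """
    shapes = {}
    for i in range(n_layers):
        shapes[f"layer_{i}"] = {
            "kernel": (in_features, features),
            "bias": (features,),
        }
        # The output of one layer (length `features`) is the input of the next.
        in_features = features
    return shapes


shapes = pipeline_param_shapes(in_features=16)
for name, params in shapes.items():
    print(name, params)
# layer_0 {'kernel': (16, 4), 'bias': (4,)}
# layer_1 {'kernel': (4, 4), 'bias': (4,)}
```

These match the (16, 4) and (4, 4) kernels and the (4,) biases listed in the sharding inspection further down.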

Register implicit task blocks

The Flax modules are named layer_0, layer_1, and so on. If these Flax names are globally unique, we can use each name to define a task scope. Below, layer_0 is assigned to devices 0-3 and layer_1 to devices 4-7, each on a 4 x 1 task mesh.

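# A 4 x 1 task mesh (axes "x" and "y") over devices 0-3
task_mesh = TaskMesh(devices=(0,1,2,3), axis_names=("x","y"), axis_sizes=(4,1))

# First pipeline stage: layer_0 on devices 0-3
register_task(
    "layer_0",
    task=Task(
        mesh=task_mesh,
        # Map the logical axis names used in MLPBlock onto task mesh axes
        logical_axes=[
            ("batch", "x"),
            ("data", "x"),
            ("model", "y"),
        ]
    )
)

# Second pipeline stage: the same mesh shape placed on devices 4-7
register_task(
    "layer_1",
    task=Task(
        mesh=task_mesh.place((4,5,6,7)),
        logical_axes=[
            ("batch", "x"),
            ("data", "x"),
            ("model", "y"),
        ]
    )
)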
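The logical_axes list acts as a set of rules translating the logical names used in the model (batch, data, model) into the task mesh's axis names. The sketch below illustrates that translation and the resulting per-device block shapes on a 4 x 1 task mesh. It is a hypothetical helper written for illustration, not MultiMesh's implementation, and it assumes each sharded dimension divides evenly across its mesh axis.

```python
def resolve_spec(names, logical_axes):
    """Translate logical axis names into mesh axis names (None = unsharded).

    Hypothetical helper: a name without a rule maps to None.
    """
    rules = dict(logical_axes)
    return tuple(rules.get(name) for name in names)


def shard_shape(shape, spec, axis_sizes):
    """Per-device block shape when each dimension is split over its mesh axis."""
    out = []
    for dim, mesh_axis in zip(shape, spec):
        n = axis_sizes[mesh_axis] if mesh_axis is not None else 1
        if dim % n:
            raise ValueError(f"dimension {dim} does not divide across {n} devices")
        out.append(dim // n)
    return tuple(out)


logical_axes = [("batch", "x"), ("data", "x"), ("model", "y")]
axis_sizes = {"x": 4, "y": 1}  # TaskMesh(axis_names=("x", "y"), axis_sizes=(4, 1))

kernel_spec = resolve_spec(("data", "model"), logical_axes)
print(kernel_spec)                                    # ('x', 'y')
print(shard_shape((16, 4), kernel_spec, axis_sizes))  # (4, 4)
print(resolve_spec(("batch",), logical_axes))         # ('x',)
```

A (4, 4) block per device is consistent with the inspection output later on, which shows the first-layer kernel split row-wise across devices 0-3.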

Initialize parameters and create parallel training function

So far the code looks almost the same as standard Flax code with a few extra annotations. The full power of MultiMesh for Jax comes from its auto-parallelizing compiler, which reads all the submesh annotations and transforms the global computation into a series of submesh tasks. Managing Jax training state and shardings can be difficult when calling jax.jit directly. The MultiMesh Flax wrappers provide parallelize_step, which generates MPMD auto-sharded training steps for Flax modules, similar to the previous MPMD example. The user simply passes in the model, a mesh, a standard Optax optimizer, and an example input batch. MultiMesh for Jax then derives all the shardings, compiles the model, and initializes all the sharded training state. parallelize_step returns the four objects needed to run a training loop, in this order:

  1. A step_fn for computing gradients and updating parameters

  2. The randomly initialized, sharded training state (parameters and optimizer state)

  3. The mesh context to use for training steps

  4. A function for converting input batches in host memory into device arrays

This setup differs slightly from the multimesh.jax.parallelize function, but it is more idiomatic for Flax training.

x = jnp.arange(16)
opt = optax.sgd(learning_rate=0.02)
mesh = Mesh(np.array(jax.devices()).reshape(8,1), ("x","y"))
step_fn, sharded_train_state, mesh, prepare_batch = parallelize_step(model=Pipeline(), mesh=mesh, optimizer=opt, local_batch=x)
WARNING:2025-07-17 18:58:28,206:jax._src.xla_bridge:830: Platform 'multimesh' is experimental and not all JAX functionality may be correctly supported!
2025-07-17 18:58:30.324074: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!

Run training steps

Now that we have our initial training state (parameters and optimizer state), we can run training steps that compute the loss and update the parameters.

with mesh:
    batch = prepare_batch(x)
    loss, sharded_train_state = step_fn(sharded_train_state, batch)
    print("Loss = ", loss)
Loss =  6.547683
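Conceptually, one call to step_fn runs a forward pass, backpropagates the loss, and applies the SGD update with learning rate 0.02. The NumPy sketch below performs those three steps for the same two-layer dense + relu model on a single device. It is not MultiMesh's step_fn: it ignores sharding and pipelining, its helper names are hypothetical, and it approximates nn.initializers.xavier_normal (a truncated normal) with a plain normal draw of standard deviation sqrt(2 / (fan_in + fan_out)), so it will not reproduce the loss printed above.

```python
import numpy as np


def init_params(rng, sizes=(16, 4, 4)):
    """Xavier-style normal kernels and zero biases, one (kernel, bias) pair per layer."""
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        std = np.sqrt(2.0 / (fan_in + fan_out))
        params.append((rng.normal(0.0, std, (fan_in, fan_out)), np.zeros(fan_out)))
    return params


def forward(params, x):
    """Return the loss sum(h**2) and each layer's cached (input, pre-activation)."""
    cache = []
    h = x
    for w, b in params:
        z = h @ w + b
        cache.append((h, z))
        h = np.maximum(z, 0.0)  # relu
    return np.sum(h * h), cache


def grads(params, x):
    """Backpropagate d(loss)/d(kernel, bias) through the dense + relu layers."""
    loss, cache = forward(params, x)
    dh = 2.0 * np.maximum(cache[-1][1], 0.0)  # d(sum(h**2))/dh for the last layer
    out = []
    for (w, _), (h_in, z) in zip(reversed(params), reversed(cache)):
        dz = dh * (z > 0)                   # relu derivative
        out.append((np.outer(h_in, dz), dz))  # (d kernel, d bias)
        dh = w @ dz                         # gradient w.r.t. this layer's input
    return loss, out[::-1]


def sgd_step(params, x, lr=0.02):
    """Plain SGD: p <- p - lr * grad. The loss is from the pre-update parameters."""
    loss, g = grads(params, x)
    new_params = [(w - lr * gw, b - lr * gb) for (w, b), (gw, gb) in zip(params, g)]
    return loss, new_params


rng = np.random.default_rng(0)
params = init_params(rng)
x = np.arange(16, dtype=np.float64)
loss, params = sgd_step(params, x)
print("Loss =", loss)
```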

Inspecting MPMD sharding specs

As before, to glimpse the underlying details, we can inspect the sharding specs of the training state. Again there are five tensors:

  1. A single scalar parameter that is not partitioned (empty PartitionSpec()). This is the learning rate (optimizer state for SGD).

  2. A (4,) tensor replicated across devices 0-3, corresponding to the first-layer bias

  3. A (16,4) tensor sharded along its first axis across devices 0-3, corresponding to the fully connected weights in the first layer

  4. A (4,) tensor replicated across devices 4-7, corresponding to the second-layer bias

  5. A (4,4) tensor sharded along its first axis across devices 4-7, corresponding to the fully connected weights in the second layer

def inspect(x):
    print(x.shape)
    if len(x.shape) > 0:
        jax.debug.visualize_array_sharding(x, use_color=False)
    else:
        print("REPLICATED", x.sharding)
jax.tree.map(inspect, sharded_train_state)
pass
()
REPLICATED NamedSharding(mesh=Mesh('x': 8, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=unpinned_host)
(4,)
(16, 4)
(4,)
(4, 4)
┌─────────────────┐
│MULTIMESH 0,1,2,3│
└─────────────────┘
┌───────────┐
│MULTIMESH 0│
├───────────┤
│MULTIMESH 1│
├───────────┤
│MULTIMESH 2│
├───────────┤
│MULTIMESH 3│
└───────────┘
┌─────────────────┐
│MULTIMESH 4,5,6,7│
└─────────────────┘
┌───────────────────────┐
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘
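Each stack of boxes above shows a kernel split into row blocks, one per device. Assuming an even, contiguous split along the first axis in the order the devices are listed, the hypothetical helper below reproduces which rows each device holds for both kernels.

```python
def row_blocks(n_rows, devices):
    """Rows held by each device when axis 0 is split evenly across `devices`."""
    if n_rows % len(devices):
        raise ValueError("rows must divide evenly across devices")
    block = n_rows // len(devices)
    return {d: range(i * block, (i + 1) * block) for i, d in enumerate(devices)}


# First-layer kernel (16, 4) on devices 0-3; second-layer kernel (4, 4) on devices 4-7
for devices, n_rows in [((0, 1, 2, 3), 16), ((4, 5, 6, 7), 4)]:
    for dev, rows in row_blocks(n_rows, devices).items():
        print(f"device {dev}: rows {list(rows)}")
```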

Unlike the previous example, where multimesh.jax.task was used to define a task explicitly inside Flax code, here we mapped Flax names to a task context outside the Flax code. If a Flax library uses unique names, this allows the library to be mapped to Legate MPMD execution without any modifications to the library code.
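Only globally unique names make good task scopes. The explicit names layer_0 and layer_1 qualify, but Flax's automatically generated names for the inner dense layers repeat in every block. The sketch below checks this on a hand-written stand-in for the parameter tree (leaves are shapes, not arrays). The helper names are hypothetical, and the Dense_0 scope name assumes Flax's default auto-naming of an unnamed nn.Dense inside a compact module.

```python
from collections import Counter


def module_paths(tree, prefix=()):
    """Yield the path of every nested dict (module scope) in a parameter tree."""
    for key, value in tree.items():
        if isinstance(value, dict):
            path = prefix + (key,)
            yield path
            yield from module_paths(value, path)


def non_unique_names(tree):
    """Return module names that occur at more than one path in the tree."""
    counts = Counter(path[-1] for path in module_paths(tree))
    return sorted(name for name, n in counts.items() if n > 1)


# Hand-written stand-in for the Pipeline parameter tree.
params = {
    "layer_0": {"Dense_0": {"kernel": (16, 4), "bias": (4,)}},
    "layer_1": {"Dense_0": {"kernel": (4, 4), "bias": (4,)}},
}
print(non_unique_names(params))  # ['Dense_0']
```

Registering a task under "Dense_0" would therefore be ambiguous, while "layer_0" and "layer_1" each identify exactly one scope.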

Advanced example: dynamic device lists

We can repeat the example above but use a factory function to generate the task contexts instead. Rather than specifying a separate task for each of layer_0, layer_1…, we can register a single factory, selected by a regular expression over the names, that computes the task context dynamically.

clear_tasks()

def callback(name: str, backprop: bool):
    # Each pipeline stage gets its own block of n_per_stage devices;
    # backprop is not used in this example
    n_per_stage = 4
    layer = int(name.split("_")[1])  # e.g. "layer_1" -> 1
    start = layer * n_per_stage
    end = start + n_per_stage
    task_mesh = TaskMesh(devices=range(start,end), axis_names=("x","y"), axis_sizes=(4,1))
    return Task(
      name=f"layer_{layer}",
      mesh=task_mesh,
      logical_axes=[
        ("batch", "x"),
        ("data", "x"),
        ("model", "y"),
      ]
    )

register_task(
  r"(layer_\d+)",
  callback=callback,
)
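The device arithmetic inside callback can be checked on its own. The sketch below reuses the same regular expression and stage size but returns plain device ranges instead of Task objects. How MultiMesh applies the pattern (full match or search) is an assumption here; fullmatch is used for illustration, and `stage_devices` is a hypothetical helper.

```python
import re

# Same pattern that is passed to register_task above.
LAYER_PATTERN = re.compile(r"(layer_\d+)")


def stage_devices(name: str, n_per_stage: int = 4) -> range:
    """Device IDs for a name such as "layer_1", computed as in callback."""
    layer = int(name.split("_")[1])
    start = layer * n_per_stage
    return range(start, start + n_per_stage)


for name in ["layer_0", "layer_1", "Dense_0"]:
    # Assumption: a name must match the whole pattern to select a task.
    match = LAYER_PATTERN.fullmatch(name)
    if match:
        print(name, "->", list(stage_devices(match.group(1))))
    else:
        print(name, "-> no task")
```

With the eight devices created by init(cpus=8), this scheme has room for exactly two stages: a layer_2 would be assigned devices 8-11, which do not exist.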

x = jnp.arange(16)
opt = optax.sgd(learning_rate=0.02)
mesh = Mesh(np.array(jax.devices()).reshape(8,1), ("x","y"))
step_fn, sharded_train_state, mesh, prepare_batch = parallelize_step(
    model=Pipeline(), optimizer=opt, mesh=mesh, local_batch=x)
2025-07-17 18:58:32.751950: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!

We can again inspect the training-state shardings. They are identical to those in the first example.

jax.tree.map(inspect, sharded_train_state)
pass
()
REPLICATED NamedSharding(mesh=Mesh('x': 8, 'y': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=unpinned_host)
(4,)
(16, 4)
(4,)
(4, 4)
┌─────────────────┐
│MULTIMESH 0,1,2,3│
└─────────────────┘
┌───────────┐
│MULTIMESH 0│
├───────────┤
│MULTIMESH 1│
├───────────┤
│MULTIMESH 2│
├───────────┤
│MULTIMESH 3│
└───────────┘
┌─────────────────┐
│MULTIMESH 4,5,6,7│
└─────────────────┘
┌───────────────────────┐
│      MULTIMESH 4      │
├───────────────────────┤
│      MULTIMESH 5      │
├───────────────────────┤
│      MULTIMESH 6      │
├───────────────────────┤
│      MULTIMESH 7      │
└───────────────────────┘