MultiMesh contexts for MPMD

What you need to know for this tutorial

  1. A basic knowledge of Jax and NumPy

  2. Familiarity with the concepts of tensor sharding and parallelism

  3. The “logical autosharding for SPMD” tutorial

Initialize MultiMesh for Jax

As before, the first thing we need to do is initialize MultiMesh for Jax.

from jax_plugins.multimesh import init

init(cpus=8)

Import Jax packages

Once MultiMesh for Jax has been initialized, the standard Jax imports can be done. We also import the new MultiMesh class, along with Task, which is used below to define pipeline tasks.

import jax
import jax.numpy as jnp
import numpy as np
import re
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.debug import visualize_array_sharding
from multimesh.jax import with_sharding_constraint, parallelize, MultiMesh, Task

Model code remains unmodified with a MultiMesh context

As in the other tutorials, we can write model code that names tensor axes using sharding constraints.

def f(param, x):
    return jnp.einsum("bm,mh->bh", x, param)

def model(params, x):
    x = with_sharding_constraint(x, P("batch", None))
    param_1, param_2 = params
    param_1 = with_sharding_constraint(param_1, P("model", None))
    param_2 = with_sharding_constraint(param_2, P(None, "model"))
    x = f(param_1, x)
    x = f(param_2, x)
    return with_sharding_constraint(x, P("batch", None))
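Stripped of its sharding constraints, the model is just two chained matrix multiplications: the einsum string "bm,mh->bh" contracts the shared m index. The NumPy sketch below checks this on 8x8 inputs like the ones used later. It assumes, as with jax.lax.with_sharding_constraint, that the constraints only annotate layouts and do not change values; `f_np` and `reference_model` are hypothetical helpers, not part of MultiMesh.

```python
import numpy as np

def f_np(param, x):
    # Same contraction as the tutorial's f: sum over the shared "m" index.
    return np.einsum("bm,mh->bh", x, param)

def reference_model(params, x):
    # Hypothetical unsharded reference: the tutorial's model with the
    # sharding constraints removed (assumed to affect layout only, not values).
    param_1, param_2 = params
    return f_np(param_2, f_np(param_1, x))

x = np.arange(64).reshape(8, 8)
params = (np.arange(64).reshape(8, 8), np.arange(64).reshape(8, 8))
out = reference_model(params, x)

# The einsum chain is equivalent to x @ param_1 @ param_2.
assert np.array_equal(out, x @ params[0] @ params[1])
print(out.shape)  # (8, 8)
```

A reference like this is handy for checking the outputs of the sharded model once it has been built.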

To create MPMD code, we need to assign a named scope to each layer. More commonly, we would write Flax modules that define named scopes; for this tutorial, we assign named scopes directly using jax.named_scope.

def model(params, x):
    x = with_sharding_constraint(x, P("batch", None))
    param_1, param_2 = params
    param_1 = with_sharding_constraint(param_1, P("model", None))
    param_2 = with_sharding_constraint(param_2, P(None, "model"))
    # Each named scope marks one layer, which will become one pipeline task.
    with jax.named_scope("layer0"):
        x = f(param_1, x)
    with jax.named_scope("layer1"):
        x = f(param_2, x)
        # The output constraint stays inside the layer1 scope, as in the original.
        return with_sharding_constraint(x, P("batch", None))

As before, to run the model, we will need input batches and parameters. We again use dummy functions:

def get_input_batch():
    return jnp.arange(64).reshape(8,8)

def init_params():
    param1 = jnp.arange(64).reshape(8,8)
    param2 = jnp.arange(64).reshape(8,8)
    return param1, param2

We now define a MultiMesh context that defines submeshes for the named scopes. We start by creating a standard Jax mesh with an extra “stage” axis, which will be sliced to create pipeline stages, and wrap it with MultiMesh.

global_mesh = Mesh(np.array(jax.devices()).reshape(2,2,2), ("stage", "batch", "model"))
multi_mesh = MultiMesh(global_mesh=global_mesh)
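To see which devices end up in each pipeline stage, the reshape can be mimicked with plain integer IDs. This sketch assumes `jax.devices()` returns the eight devices in ID order 0-7, which is consistent with the sharding visualizations further down. Indexing the first ("stage") axis yields one 2x2 ("batch", "model") submesh per stage.

```python
import numpy as np

# Stand-in for jax.devices(): assume eight devices with IDs 0..7, in order.
device_ids = np.arange(8)

# Same reshape as the global mesh: axes are ("stage", "batch", "model").
global_ids = device_ids.reshape(2, 2, 2)

# Slicing the "stage" axis (axis 0) gives one 2x2 ("batch", "model") submesh per stage.
for stage in range(global_ids.shape[0]):
    print(f"stage {stage}:", global_ids[stage].tolist())
# stage 0: [[0, 1], [2, 3]]
# stage 1: [[4, 5], [6, 7]]
```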

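We now register task scopes with the MultiMesh wrapper. Rather than defining a separate scope for each specific layer name, we register the regular expression layer\d+ together with a general callback that turns each matching scope into a task defined as a slice of the global mesh. The callback also receives a backprop flag, so the forward and backward computations of a layer get distinct task names.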

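def callback(name: str, backprop: bool):
    # Extract the layer index from the scope name (e.g. "layer1" -> 1).
    layer = int(re.search(r"layer(\d+)", name).group(1))
    # Forward and backward computations of a layer become separate tasks.
    suffix = "bwd" if backprop else "fwd"
    full_name = f"layer{layer}_" + suffix
    # Layer i runs on the submesh at index i of the "stage" axis.
    return Task(
        name=full_name,
        mesh_slice={"stage": layer},
    )

# Use a raw string so "\d" is not treated as a string escape.
multi_mesh.register_task(
    r"layer\d+",
    callback=callback
)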
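The callback is ordinary Python, so its behavior can be checked in isolation. The sketch below swaps in a hypothetical `TaskSketch` dataclass for `multimesh.jax.Task`, modeling only the two fields the callback sets. Because the callback uses `re.search` rather than `re.match`, the layer index is found even if a scope name carries a prefix; the `"outer/layer1"` name below is purely illustrative, since this tutorial does not show how MultiMesh formats scope names.

```python
import re
from dataclasses import dataclass, field

@dataclass
class TaskSketch:
    # Hypothetical stand-in for multimesh.jax.Task, modeling only the
    # fields the tutorial's callback sets.
    name: str
    mesh_slice: dict = field(default_factory=dict)

LAYER_RE = re.compile(r"layer(\d+)")

def callback(name: str, backprop: bool) -> TaskSketch:
    # Extract the layer index; search() scans the whole scope name.
    layer = int(LAYER_RE.search(name).group(1))
    suffix = "bwd" if backprop else "fwd"
    # Layer i is pinned to slice i of the "stage" axis.
    return TaskSketch(name=f"layer{layer}_{suffix}", mesh_slice={"stage": layer})

for scope, backprop in [("layer0", False), ("layer1", False), ("outer/layer1", True)]:
    task = callback(scope, backprop)
    print(scope, "->", task.name, task.mesh_slice)
# layer0 -> layer0_fwd {'stage': 0}
# layer1 -> layer1_fwd {'stage': 1}
# outer/layer1 -> layer1_bwd {'stage': 1}
```

Since the layer index is used directly as the stage index, this mapping presumably only works while the number of layers does not exceed the size of the stage axis (2 here); a deeper model would need a different mapping, such as grouping several layers per stage.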

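We can now create a sharded model using multimesh.jax.parallelize inside a MultiMesh context. Besides the sharded model, parallelize returns functions that produce sharded parameters and sharded input batches.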

with multi_mesh:
    sharded_model, init_sharded_params, get_sharded_batch = parallelize(
        model, init_params=init_params, get_input_batch=get_input_batch
    )
2025-07-21 15:26:30.296109: W ./xla/service/hlo_module_config.h:194] Warning: Using auto_spmd_partitioning. It is experimental and may contain bugs!

We can then inspect the generated shardings for the parameters. The layer0 and layer1 tasks each run on a submesh obtained by slicing the global mesh along the stage axis. Visualizing the shardings shows that param 1 and param 2 live on different 2x2 submeshes: param 1 is split by rows over devices 0-3, and param 2 is split by columns over devices 4-7.

param_1, param_2 = init_sharded_params()
print("Param 1")
visualize_array_sharding(param_1, use_color=False)
print("Param 2")
visualize_array_sharding(param_2, use_color=False)
Param 1
┌───────────────────────┐
│                       │
│     MULTIMESH 0,2     │
│                       │
│                       │
├───────────────────────┤
│                       │
│     MULTIMESH 1,3     │
│                       │
│                       │
└───────────────────────┘
Param 2
┌─────────────┬─────────────┐
│             │             │
│             │             │
│             │             │
│             │             │
│MULTIMESH 4,6│MULTIMESH 5,7│
│             │             │
│             │             │
│             │             │
│             │             │
└─────────────┴─────────────┘
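The device groups in these diagrams follow from the PartitionSpecs. On a ("batch", "model") submesh, an array dimension mapped to "model" is split into as many blocks as the model axis has devices, and each block is replicated over the devices that share a model index. The NumPy sketch below reproduces the groups, assuming device IDs 0-7 laid out in mesh order; `shard_groups` is a hypothetical helper, not a MultiMesh or JAX API.

```python
import numpy as np

def shard_groups(submesh, axis_names, split_axis):
    # Hypothetical helper: for an array dimension sharded over `split_axis`,
    # return the device IDs holding each block. Block k lives on every device
    # whose index along `split_axis` is k; the other mesh axes replicate it.
    axis = axis_names.index(split_axis)
    return [sorted(np.take(submesh, k, axis=axis).ravel().tolist())
            for k in range(submesh.shape[axis])]

# Assumed layout: device IDs 0..7 reshaped to ("stage", "batch", "model").
global_ids = np.arange(8).reshape(2, 2, 2)
names = ("batch", "model")

# param_1: P("model", None) on stage 0 -> rows split over "model".
print("param_1 row blocks:", shard_groups(global_ids[0], names, "model"))
# param_2: P(None, "model") on stage 1 -> columns split over "model".
print("param_2 column blocks:", shard_groups(global_ids[1], names, "model"))
# param_1 row blocks: [[0, 2], [1, 3]]
# param_2 column blocks: [[4, 6], [5, 7]]
```

These match the MULTIMESH 0,2 / 1,3 and 4,6 / 5,7 labels in the visualizations.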