multimesh.jax.flax.task

multimesh.jax.flax.task(model, name=None, *, task=None, **kwargs)

Wraps a Flax module in an auto-sharding task context.

Parameters:
  • model (Type[Module]) – Flax model to be encapsulated as a task.

  • name (str | None) – optional, a metadata name to assign to the task context.

  • mesh – optional, a Mesh context defining the devices and mesh shape.

  • devices – optional, a numpy array or list of jax devices specifying the devices to include in the task submesh. One of mesh or devices must be given. If devices is a numpy array, the mesh shape is inferred from the shape of the device array.

  • device_axes – optional, a list of names to assign to each device axis. The number of names must match the number of dimensions of mesh or devices. If None, logical sharding constraints will be translated to replicated sharding.

  • logical_axes – optional, a list of string pairs (‘logical’, ‘device’) giving the translation from logical axis names to physical device axis names. The logical names should match those passed to with_sharding_constraint calls within the task. If None, the device axis names are used directly for autosharding. Raises a ValueError if logical_axes is given but neither mesh nor device_axes is specified.

  • task (Task | None)

Returns:

A wrapped version of model usable as a submesh task.
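
The interaction between device_axes and logical_axes described under Parameters can be sketched in plain Python. This is an illustration of the documented rules only, not the library's implementation: translate_spec is a hypothetical helper, the axis names of a mesh are folded into its device_axes argument for brevity, and mapping an unlisted logical name to None (replicated) is an assumption.

```python
from typing import Optional, Sequence, Tuple

Spec = Tuple[Optional[str], ...]


def translate_spec(
    logical_spec: Sequence[Optional[str]],
    device_axes: Optional[Sequence[str]] = None,
    logical_axes: Optional[Sequence[Tuple[str, str]]] = None,
) -> Spec:
    """Hypothetical helper (not part of multimesh) mirroring the documented rules."""
    # Documented: logical_axes without any device axis names is an error.
    if logical_axes is not None and device_axes is None:
        raise ValueError("logical_axes given but no mesh or device_axes specified")
    # Documented: with no device axis names, constraints become replicated.
    if device_axes is None:
        return tuple(None for _ in logical_spec)
    # Documented: with no logical_axes, names are used directly as device axes.
    if logical_axes is None:
        return tuple(logical_spec)
    # Otherwise translate each logical name through the (logical, device) pairs.
    # Assumption: a logical name with no entry is treated as replicated (None).
    table = dict(logical_axes)
    return tuple(None if name is None else table.get(name) for name in logical_spec)


# A constraint on ("batch", None) with "batch" mapped to device axis "x".
print(translate_spec(("batch", None), ["x", "y"], [("batch", "x")]))
```

Under these assumptions the call above yields the device-level spec ("x", None), which has the same form as the P(x, None) annotation shown for the kernel in the example below.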

Example

>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from multimesh.jax.flax import shard_axes, task
>>> from flax import linen as nn
>>> from jax.sharding import Mesh
>>>
>>> # Arrange four devices into a 2x2 mesh with axes named "x" and "y".
>>> devices = np.array(jax.devices()).reshape(2, 2)
>>> mesh = Mesh(devices, ("x", "y"))
>>> DenseTask = task(nn.Dense, mesh=mesh)
>>> model = DenseTask(features=16, kernel_init=shard_axes("x", None))
>>>
>>> x = jnp.ones((16, 9))
>>> print(model.tabulate(jax.random.key(0), x))

This produces the output:

                                        Task[Dense] Summary
┏━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module      ┃ inputs        ┃ outputs        ┃ params                               ┃
┡━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      │ Task[Dense] │ float32[16,9] │ float32[16,16] │ mod:                                 │
│      │             │               │                │   params:                            │
│      │             │               │                │     bias: float32[16]                │
│      │             │               │                │     kernel: float32[9,16] P(x, None) │
│      │             │               │                │                                      │
│      │             │               │                │ 160 (640 B)                          │
├──────┼─────────────┼───────────────┼────────────────┼──────────────────────────────────────┤
│      │             │               │          Total │ 160 (640 B)                          │
└──────┴─────────────┴───────────────┴────────────────┴──────────────────────────────────────┘

                                Total Parameters: 160 (640 B)
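
The totals in the table can be checked by hand. The short calculation below assumes only what the table itself shows: an input of shape [16, 9], a Dense layer with 16 features, and float32 parameters at 4 bytes each. The variable names are illustrative.

```python
# Shapes taken from the table: kernel float32[9,16], bias float32[16].
in_features = 9      # last dimension of the input float32[16, 9]
out_features = 16    # the `features` argument of the Dense layer
bytes_per_float32 = 4

kernel_params = in_features * out_features  # 9 * 16 = 144
bias_params = out_features                  # 16
total = kernel_params + bias_params         # 160
total_bytes = total * bytes_per_float32     # 640

print(f"{total} ({total_bytes} B)")  # prints "160 (640 B)"
```

Sharding the kernel along "x" changes where its values are placed, not how many parameters exist, so the reported total is the same as for an unsharded Dense layer.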