multimesh.jax.flax.task#
- multimesh.jax.flax.task(model, name=None, *, task=None, **kwargs)#
Wraps a Flax module in an auto-sharding task context
- Parameters:
model (Type[Module]) – Flax model to be encapsulated as a task.
name (str | None) – optional, a metadata name to assign to the task context.
mesh – optional, a Mesh context defining the devices and mesh shape.
devices – optional, a numpy array or list of jax devices specifying the devices to include in the task submesh. One of mesh or devices must be given. If devices is a numpy array, the mesh shape is inferred from the shape of the device array.
device_axes – optional, a list of names to assign to each device axis. The number of names must match the shape of mesh or devices. If None, logical sharding constraints will be translated to replicated sharding.
logical_axes – optional, a list of string pairs ('logical', 'device') giving the translation from logical names to physical device names. The logical names should match those passed to with_sharding_constraint calls within the task. If None, the device axis names are used directly for autosharding. Raises a ValueError if logical_axes are given but no mesh or device_axes are specified.
task (Task | None)
- Returns:
A wrapped version of model usable as a submesh task.
Example
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from multimesh.jax.flax import shard_axes, task
>>> from flax import linen as nn
>>> from jax.sharding import Mesh
>>>
>>> devices = np.array(jax.devices()).reshape(2, 2)
>>> mesh = Mesh(devices, ("x", "y"))
>>> DenseTask = task(nn.Dense, mesh=mesh)
>>> model = DenseTask(features=16, kernel_init=shard_axes("x", None))
>>>
>>> x = jnp.ones((16, 9))
>>> model.tabulate(jax.random.key(0), x)
Which produces the output:
Task[Dense] Summary
┏━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module      ┃ inputs        ┃ outputs        ┃ params                               ┃
┡━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      │ Task[Dense] │ float32[16,9] │ float32[16,16] │ mod:                                 │
│      │             │               │                │   params:                            │
│      │             │               │                │     bias: float32[16]                │
│      │             │               │                │     kernel: float32[9,16] P(x, None) │
│      │             │               │                │                                      │
│      │             │               │                │ 160 (640 B)                          │
├──────┼─────────────┼───────────────┼────────────────┼──────────────────────────────────────┤
│      │             │               │ Total          │ 160 (640 B)                          │
└──────┴─────────────┴───────────────┴────────────────┴──────────────────────────────────────┘

Total Parameters: 160 (640 B)
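The logical_axes translation described in the parameters can be illustrated with a small, self-contained sketch. translate_logical_axes below is a hypothetical helper (not part of multimesh) showing how a list of ('logical', 'device') pairs might map logical sharding names in a PartitionSpec-like tuple to physical device-axis names:

```python
# Hypothetical helper sketching the logical_axes translation; it is NOT
# part of multimesh, only an illustration of the mapping it describes.
def translate_logical_axes(partition_spec, logical_axes):
    """Map logical axis names in a PartitionSpec-like tuple to physical
    device-axis names. None entries stay None (replicated)."""
    mapping = dict(logical_axes)
    return tuple(
        mapping.get(axis) if axis is not None else None
        for axis in partition_spec
    )

# Logical names "batch" and "embed" map onto device axes "x" and "y";
# the middle axis stays replicated.
spec = translate_logical_axes(
    ("batch", None, "embed"),
    [("batch", "x"), ("embed", "y")],
)
print(spec)  # -> ('x', None, 'y')
```

With logical_axes omitted, the device axis names would be used directly, which corresponds to an identity mapping here.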