multimesh.jax.task

multimesh.jax.task(fun, name=None, *, task=None, **kwargs)

Wraps a function in an auto-sharding task context.

Parameters:
  • fun (Callable | Type) – Function to be encapsulated as a task. fun should be pure. See the jax.jit documentation for the requirements on fun.

  • name (str | None) – Optional metadata name to assign to the task context.

  • task (Task | None) – A Task dataclass instance defining the submesh and sharding.

  • kwargs – Options passed through to the Task dataclass constructor (for example devices, device_axes, and logical_axes, as in the example below).

Returns:

A wrapped version of fun usable as a submesh task.

Examples

>>> import numpy as np
>>> import jax
>>> from multimesh.jax import task, with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>>
>>> def f(x):
...   x = with_sharding_constraint(x, P("batch", "model"))
...   out = x*x
...   return with_sharding_constraint(out, P("batch", "model"))
>>>
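>>> # The 2x2 device grid needs four devices. On CPU, setting
>>> # XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing jax provides them.
>>> devices = np.array(jax.devices()[:4]).reshape(2, 2)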
>>> task_f = task(f, devices=devices,
...               device_axes=("x", "y"),
...               logical_axes=(
...                 ("batch", "x"),
...                 ("model", "y"),
...               ))
>>>
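In the example, the logical_axes pairs appear to bind each logical name used in with_sharding_constraint ("batch", "model") to one of the device_axes ("x", "y"). The following is a minimal pure-Python sketch of that lookup. It assumes a simple one-to-one table and is not how multimesh necessarily resolves specs internally. The helper resolve_spec is hypothetical and not part of the multimesh API.

```python
# Hypothetical illustration only: multimesh's real resolution logic may differ.
from typing import Optional, Sequence, Tuple

LogicalRules = Sequence[Tuple[str, str]]


def resolve_spec(
    logical_spec: Sequence[Optional[str]], rules: LogicalRules
) -> Tuple[Optional[str], ...]:
    """Map each logical axis name in a partition spec to a device axis name.

    None entries (replicated dimensions) pass through unchanged; an unknown
    logical name raises KeyError instead of being silently replicated.
    """
    table = dict(rules)
    resolved = []
    for name in logical_spec:
        if name is None:
            resolved.append(None)
        elif name in table:
            resolved.append(table[name])
        else:
            raise KeyError(f"no device axis bound to logical axis {name!r}")
    return tuple(resolved)


# Same axis names as the doctest above.
device_axes = ("x", "y")
logical_axes = (("batch", "x"), ("model", "y"))

# Every rule should target a device axis that exists on the 2x2 grid.
assert all(dev in device_axes for _, dev in logical_axes)

print(resolve_spec(("batch", "model"), logical_axes))
print(resolve_spec(("batch", None), logical_axes))
```

Under this sketch, P("batch", "model") in f would shard the first array dimension over device axis "x" and the second over "y". Raising on an unbound logical name is a design choice made here so that typos in axis names surface early.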