multimesh.jax.register_task

multimesh.jax.register_task(matcher, *, callback=None, task=None)

Registers a name- or regex-based task autosharding context.

Parameters:
  • matcher (str) – A name, or a regular expression with a match group. It should match the name of a Flax module or a name passed to jax.named_scope.

  • callback (Callable[[str, bool], Task] | None) – optional; a callback that takes the match group from matcher and a bool indicating whether it is being invoked for backprop. The callback must return a Task defining the mesh and other attributes of the task scope.

  • task (Task | None) – optional; a Task defining the mesh and other attributes of the task scope.

Returns:

None
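The matching rules described in the parameter list can be modeled in plain Python. The sketch below is a hypothetical model, not multimesh's implementation: register_task_sketch and resolve_task are made-up names, strings stand in for Task objects, and it assumes that the pattern must match the whole scope name (re.fullmatch), that the callback receives the first match group (or the whole name when the pattern has no group), that the first registered matcher wins, and that exactly one of callback or task is given.

```python
import re
from typing import Any, Callable, Optional

# Hypothetical model of the documented matching rules; NOT multimesh's code.
# Each entry is (compiled pattern, callback or None, fixed task or None).
_REGISTRY: list = []


def register_task_sketch(
    matcher: str,
    *,
    callback: Optional[Callable[[str, bool], Any]] = None,
    task: Any = None,
) -> None:
    # Assumption: exactly one of callback= or task= must be supplied.
    if (callback is None) == (task is None):
        raise ValueError("pass exactly one of callback= or task=")
    _REGISTRY.append((re.compile(matcher), callback, task))


def resolve_task(scope_name: str, backprop: bool) -> Any:
    """Return the task for the first matcher that fully matches scope_name, else None."""
    for pattern, callback, task in _REGISTRY:
        m = pattern.fullmatch(scope_name)
        if m is None:
            continue
        if task is not None:
            return task  # a fixed task is returned as-is
        # Pass the first match group (or the whole name if there is no group).
        key = m.group(1) if pattern.groups else m.group(0)
        return callback(key, backprop)
    return None


# Usage: a regex with one match group, plus a plain name bound to a fixed task.
register_task_sketch(r"layer_(\d+)", callback=lambda idx, bp: f"layer{idx}-{'bwd' if bp else 'fwd'}")
register_task_sketch("subtask", task="subtask-task")
print(resolve_task("layer_3", False))  # layer3-fwd
print(resolve_task("subtask", True))   # subtask-task
print(resolve_task("encoder", False))  # None
```

Modeling the callback as a function of the captured group is what lets a single regex registration produce a different task per matched module, for example one per layer index.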

Examples

Basic usage with a callback:

>>> import jax
>>> from jax.lax import with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>> # Task and TaskMesh are assumed to be importable from multimesh.jax.
>>> from multimesh.jax import register_task, Task, TaskMesh
>>>
>>> def f(x):
...     with jax.named_scope("subtask"):
...         x = with_sharding_constraint(x, P("batch", "model"))
...         out = x * x
...         return with_sharding_constraint(out, P("batch", "model"))
...
>>> def callback(name: str, backprop: bool):
...     # name is the match group from the matcher; backprop flags the backward pass.
...     # The mesh axis names match the PartitionSpec axes used inside f.
...     mesh = TaskMesh(devices=(0, 1, 2, 3), axis_names=("batch", "model"), axis_sizes=(2, 2))
...     return Task(mesh=mesh, name=name)
...
>>> register_task("subtask", callback=callback)
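The mesh returned by the callback has to agree with the sharding constraints inside the scope. Assuming TaskMesh follows the same rules as jax.sharding.Mesh (the axis sizes multiply to the number of devices, and every axis named in a PartitionSpec must be an axis of the mesh), a quick pre-flight check can catch mismatches such as a mesh with axes ("x", "y") paired with P("batch", "model"). MeshSpec and check_mesh_against_specs are hypothetical helpers, and partition specs are modeled as plain tuples so the sketch runs without JAX or multimesh.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Sequence, Tuple, Union


# Hypothetical stand-in for TaskMesh with only the fields used in the example;
# the real multimesh class may differ.
@dataclass(frozen=True)
class MeshSpec:
    devices: Tuple[int, ...]
    axis_names: Tuple[str, ...]
    axis_sizes: Tuple[int, ...]


# A partition-spec entry: an axis name, a tuple of axis names, or None (replicated).
SpecEntry = Union[None, str, Tuple[str, ...]]


def check_mesh_against_specs(mesh: MeshSpec, specs: Sequence[Tuple[SpecEntry, ...]]) -> List[str]:
    """Return human-readable problems; an empty list means the mesh and specs agree."""
    problems = []
    if len(mesh.axis_names) != len(mesh.axis_sizes):
        problems.append("axis_names and axis_sizes have different lengths")
    if prod(mesh.axis_sizes) != len(mesh.devices):
        problems.append(
            f"axis sizes multiply to {prod(mesh.axis_sizes)}, "
            f"but {len(mesh.devices)} devices were given"
        )
    for spec in specs:
        for entry in spec:
            # One array dimension may be sharded over several mesh axes.
            names = entry if isinstance(entry, tuple) else (entry,)
            for axis in names:
                if axis is not None and axis not in mesh.axis_names:
                    problems.append(f"partition axis {axis!r} is not a mesh axis")
    return problems


# Usage: spec mirrors P("batch", "model") from f.
spec = ("batch", "model")
good = MeshSpec(devices=(0, 1, 2, 3), axis_names=("batch", "model"), axis_sizes=(2, 2))
bad = MeshSpec(devices=(0, 1, 2, 3), axis_names=("x", "y"), axis_sizes=(2, 2))
print(check_mesh_against_specs(good, [spec]))  # []
print(check_mesh_against_specs(bad, [spec]))   # flags 'batch' and 'model'
```

Returning a list of problems rather than raising on the first one makes it easy to report every mismatch at once before any tracing happens.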