multimesh.jax.register_task
- multimesh.jax.register_task(matcher, *, callback=None, task=None)
Registers a name- or regex-based task autosharding context.
- Parameters:
matcher (str) – A name, or a regular expression with a match group. It should match the name of a Flax module or a name passed to jax.named_scope.
callback (Callable[[str, bool], Task] | None) – Optional. A callback taking the match group from matcher and a bool indicating whether the scope is being traced for backprop. The callback must return a Task defining the mesh and other attributes of the task scope.
task (Task | None) – Optional. A Task defining the mesh and other attributes of the task scope.
- Returns:
None
Examples
Basic usage with a callback that builds the task mesh:
>>> import jax
>>> from jax.lax import with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>> from multimesh.jax import register_task, Task, TaskMesh  # import location of Task/TaskMesh assumed
>>>
>>> def f(x):
...     with jax.named_scope("subtask"):
...         x = with_sharding_constraint(x, P("batch", "model"))
...         out = x * x
...         return with_sharding_constraint(out, P("batch", "model"))
>>>
>>> def callback(name: str, backprop: bool):
...     # Place the "subtask" scope on devices 0-3 as a 2x2 mesh whose axis
...     # names match the PartitionSpec axes used in f.
...     mesh = TaskMesh(devices=(0, 1, 2, 3), axis_names=("batch", "model"), axis_sizes=(2, 2))
...     return Task(mesh=mesh, name=name)
>>>
>>> register_task("subtask", callback=callback)
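The matcher/callback contract described above can be pictured with a small pure-Python toy registry. This is not multimesh's implementation: ToyTask, toy_register_task and toy_resolve are hypothetical names, and the matching rule (the first registered pattern that re.fullmatch-es the scope name wins, match group 1 is passed to the callback, and a fixed task takes precedence over a callback) is an assumption chosen for illustration only.

```python
import re
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass(frozen=True)
class ToyTask:
    """Hypothetical stand-in for Task: a name plus the device ids of its mesh."""
    name: str
    devices: Tuple[int, ...]


TaskCallback = Callable[[str, bool], ToyTask]

# Hypothetical registry of (compiled pattern, callback, fixed task) entries.
_REGISTRY: List[Tuple[re.Pattern, Optional[TaskCallback], Optional[ToyTask]]] = []


def toy_register_task(matcher: str, *, callback: Optional[TaskCallback] = None,
                      task: Optional[ToyTask] = None) -> None:
    """Toy analogue of register_task: remember how to build a task for a scope name."""
    if callback is None and task is None:
        raise ValueError("provide a callback or a task")
    _REGISTRY.append((re.compile(matcher), callback, task))


def toy_resolve(scope_name: str, backprop: bool) -> Optional[ToyTask]:
    """Return the task for the first registered matcher that fully matches scope_name."""
    for pattern, callback, task in _REGISTRY:
        m = pattern.fullmatch(scope_name)
        if m is None:
            continue
        if task is not None:
            return task  # a fixed task ignores the match group and the backprop flag
        # The callback receives match group 1, or the whole match if there are no groups.
        key = m.group(1) if pattern.groups else m.group(0)
        return callback(key, backprop)
    return None  # no matcher applies, so the scope gets no task


# Usage: each "stage_<k>" scope gets its own block of 4 devices, picked from the match group.
def stage_callback(group: str, backprop: bool) -> ToyTask:
    k = int(group)
    suffix = "/bwd" if backprop else ""
    return ToyTask(name=f"stage{k}{suffix}", devices=tuple(range(4 * k, 4 * k + 4)))


toy_register_task(r"stage_(\d+)", callback=stage_callback)
toy_register_task("head", task=ToyTask(name="head", devices=(8,)))

print(toy_resolve("stage_1", backprop=False))
# → ToyTask(name='stage1', devices=(4, 5, 6, 7))
```

In this toy, a fixed task suits a scope that always runs on the same devices, while a callback lets one pattern such as stage_(\d+) fan out to many tasks and lets the backprop flag place the backward pass differently from the forward pass.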