multimesh.jax.task
- multimesh.jax.task(fun, name=None, *, task=None, **kwargs)
Wraps a function in an auto-sharding task context.
- Parameters:
fun (Callable | Type) – Function to be encapsulated as a task. fun should be pure; see the documentation for jax.jit for the requirements on fun.
name (str | None) – Optional metadata name to assign to the task context.
task (Task | None) – The Task dataclass defining the submesh and sharding.
kwargs – Options passed through to the Task dataclass constructor.
- Returns:
A wrapped version of fun, usable as a submesh task.
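The signature offers two ways to describe a task: an explicit task= object, or keyword options that are forwarded to the Task dataclass constructor. The sketch below illustrates that pass-through pattern in plain Python. HypotheticalTask and hypothetical_task are illustrative stand-ins, not multimesh code; their field names (devices, device_axes, logical_axes) are assumptions borrowed from the keyword arguments in the example that follows, and the real Task may differ. How multimesh handles passing both task= and extra options is not documented here, so this sketch simply rejects that combination.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass(frozen=True)
class HypotheticalTask:
    """Stand-in for multimesh's Task dataclass; field names are assumptions."""
    devices: Any = None
    device_axes: Tuple[str, ...] = ()
    logical_axes: Tuple[Tuple[str, str], ...] = ()


def hypothetical_task(fun: Callable, name: Optional[str] = None, *,
                      task: Optional[HypotheticalTask] = None, **kwargs):
    """Sketch of the documented signature; not the real multimesh implementation."""
    if task is None:
        # Forward the remaining options straight to the dataclass constructor.
        task = HypotheticalTask(**kwargs)
    elif kwargs:
        # Assumption of this sketch: an explicit task and extra options are exclusive.
        raise TypeError("pass either task= or Task keyword options, not both")

    def wrapped(*args, **fkwargs):
        # A real implementation would enter the submesh/sharding context here;
        # this sketch only calls the wrapped function.
        return fun(*args, **fkwargs)

    # Attach the configuration as metadata so callers can inspect it.
    wrapped.task = task
    wrapped.task_name = name
    return wrapped


if __name__ == "__main__":
    square = hypothetical_task(lambda x: x * x, "square", device_axes=("x", "y"))
    print(square(3), square.task.device_axes, square.task_name)
```

One consequence of forwarding kwargs to a dataclass, at least in this sketch, is that a misspelled option fails immediately with a TypeError from the constructor instead of being silently ignored.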
Examples
>>> import numpy as np
>>> import jax
>>> from multimesh.jax import task, with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>>
>>> def f(x):
...     x = with_sharding_constraint(x, P("batch", "model"))
...     out = x*x
...     return with_sharding_constraint(out, P("batch", "model"))
>>>
>>> devices = np.array(jax.devices()).reshape(2,2)
>>> task_f = task(f, devices=devices,
...               device_axes=("x", "y"),
...               logical_axes=(
...                   ("batch", "x"),
...                   ("model", "y"),
...               ))
>>>
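In the example, logical_axes pairs each logical name used inside P(...) with one of the mesh axes listed in device_axes, so P("batch", "model") would be sharded over the mesh axes ("x", "y"). The plain-Python sketch below shows that lookup, assuming each pair is (logical_name, mesh_axis). resolve_logical_spec is a hypothetical helper, not part of multimesh's documented API, and its choice to raise on an unmapped logical name is an assumption of the sketch; a None entry keeps its usual PartitionSpec meaning of "replicated along this dimension".

```python
from typing import Dict, Optional, Sequence, Tuple


def resolve_logical_spec(
    logical_spec: Sequence[Optional[str]],
    logical_axes: Sequence[Tuple[str, str]],
    device_axes: Sequence[str],
) -> Tuple[Optional[str], ...]:
    """Hypothetical helper: translate logical axis names into mesh axis names."""
    rules: Dict[str, str] = dict(logical_axes)
    for logical, mesh_axis in rules.items():
        # Every rule must point at an axis that actually exists on the mesh.
        if mesh_axis not in device_axes:
            raise ValueError(f"{logical!r} maps to unknown mesh axis {mesh_axis!r}")

    resolved = []
    for name in logical_spec:
        if name is None:
            # None means this array dimension is replicated, not sharded.
            resolved.append(None)
        elif name in rules:
            resolved.append(rules[name])
        else:
            # Assumption of this sketch: unknown logical names are an error.
            raise KeyError(f"no mesh axis registered for logical axis {name!r}")

    used = [axis for axis in resolved if axis is not None]
    if len(used) != len(set(used)):
        # JAX does not allow one mesh axis to shard two dimensions of an array.
        raise ValueError(f"mesh axis used more than once: {resolved}")
    return tuple(resolved)


if __name__ == "__main__":
    rules = (("batch", "x"), ("model", "y"))
    print(resolve_logical_spec(("batch", "model"), rules, ("x", "y")))
    print(resolve_logical_spec(("batch", None), rules, ("x", "y")))
```

The resolved tuple could then be passed to plain JAX as PartitionSpec(*resolved), together with jax.sharding.Mesh(devices, device_axes), to build a NamedSharding; whether multimesh performs the translation this way internally is not stated here.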