multimesh.jax.microbatch#
- multimesh.jax.microbatch(fun, dim, size, argnum=0, interleave=None, num_stages=None, schedule=None, unrolling=None, arg_shardings=None, custom_schedule=None)#
Unrolls a function along an axis into a microbatch loop.
The input tensors are sliced along the given axis for each microbatch, and the results of the microbatches are sum-reduced to produce the final result. The output tensors should be equivalent (modulo precision) to those of the function without the transform. No semantic checking is currently done to verify that sum-reducing the microbatch results reproduces the original function; it is the user's responsibility to ensure the transformation is equivalent.
For a single input/output, the transformation is equivalent to:
import jax
import jax.numpy as jnp

def microbatch_f(x, params):
    result = jax.eval_shape(fun, x, params)
    accumulator = jnp.zeros(result.shape, dtype=result.dtype)
    num_loops = x.shape[dim] // size
    for i in range(num_loops):
        # Slice out microbatch i along the microbatch dimension.
        microbatch_x = jax.lax.dynamic_slice_in_dim(x, i * size, size, axis=dim)
        accumulator += fun(microbatch_x, params)
    return accumulator
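For example, a minimal usage sketch (the loss function, shapes, and microbatch size below are illustrative and not part of the API; only the microbatch signature itself comes from this page):

import jax.numpy as jnp
from multimesh.jax import microbatch

# Illustrative loss: summing per-example losses is additive across the batch,
# so sum-reducing the microbatch results reproduces the full-batch value.
def loss(batch, params):
    preds = batch @ params
    return jnp.sum(preds ** 2)

# Slice dimension 0 of `batch` into microbatches of 32 rows each.
microbatched_loss = microbatch(loss, dim=0, size=32)

batch = jnp.ones((128, 16))
params = jnp.ones((16, 4))
out = microbatched_loss(batch, params)  # equivalent (modulo precision) to loss(batch, params)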
A microbatch loop will usually be a nested loop of N iterations over S stages:
for mb in range(num_microbatches):
    for stage in range(num_stages):
        ...
Microbatches are assumed to be independent, and the iteration order for mb is arbitrary. Microbatches can be tiled or unrolled in any order with the stages, e.g.:

for block in range(blocks):
    for stage in range(num_stages):
        for mb in range(unrolling):
            ...
The structure of the nested loops can be tuned by specifying the schedule, interleave, unrolling, and num_stages parameters. If left unspecified, the compiler/runtime is free to choose the microbatch schedule.
- Parameters:
fun – Function to be transformed into microbatch loops. fun should be pure. See the documentation for jax.jit for requirements on fun.
dim (int) – an int specifying which dimension of the input tensor(s) should be sliced for each microbatch
size (int) – an int specifying the size of the microbatch dimension for each microbatch
argnum (int) – optional, the argument number that will be sliced for each microbatch. If the argument is a pytree of tensors rather than a single tensor, then all tensors in the tree are sliced along the given dim. All other arguments to fun are unmodified.
interleave (int | None) – optional, a hint to the microbatch scheduler about how tasks within the loop should be interleaved.
schedule (Literal['1f1b', 'gpipe', 'wavefront', 'custom', 'zero-bubble-h2'] | None) – optional, a string identifying the schedule of microbatch iterations/stages such as ‘gpipe’ or ‘1f1b’.
unrolling (int | None) – optional, an int specifying the unrolling of the microbatch loop. By default, the microbatch loop is fully unrolled. Only relevant for the gpipe schedule.
arg_shardings (Any | None) – optional, an object or (prefix) pytree of objects matching argnum with shardings. The shardings can be any sharding-equivalent object, including partition specs or NamedSharding. If specified, this applies the sharding annotations to all sliced inputs.
custom_schedule (list[list[str]] | list[list[tuple[int, str]]] | None) – optional, a list of lists of strings or of (int, string) tuples specifying the custom pipeline schedule of microbatch iterations/stages when schedule is ‘custom’. Each list in custom_schedule corresponds to a device mesh, and each element in that list is a task name, optionally associated with a specific microbatch. There must be exactly pipeline_depth device meshes in the custom schedule, with each task appearing exactly num_microbatches times in the schedule. See the illustrative sketch at the end of this page.
num_stages (int | None) – optional, an int specifying the number of pipeline stages in the microbatch loop.
- Returns:
A wrapped version of fun that executes as a microbatch loop.
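As an illustration of the custom_schedule format described above, the following sketch lays out a hypothetical two-stage pipeline over four microbatches. The task names ("fwd", "bwd"), the step function, and the way tasks map onto fun are assumptions for illustration only; they are not defined by this page.

import jax.numpy as jnp
from multimesh.jax import microbatch

def step_fn(x, params):
    # Placeholder pipeline step; a real pipeline would contain per-stage work.
    return jnp.sum(x @ params)

# Two inner lists -> pipeline_depth of 2 device meshes.  Each (int, str) entry
# pairs a microbatch index with a task name, and each task name appears exactly
# num_microbatches (4) times per mesh.  "fwd"/"bwd" are hypothetical names.
custom = [
    [(0, "fwd"), (1, "fwd"), (0, "bwd"), (2, "fwd"),
     (1, "bwd"), (3, "fwd"), (2, "bwd"), (3, "bwd")],
    [(0, "fwd"), (0, "bwd"), (1, "fwd"), (1, "bwd"),
     (2, "fwd"), (2, "bwd"), (3, "fwd"), (3, "bwd")],
]

# With size=8 and dim=0, an input whose leading dimension is 32 yields
# the four microbatches referenced in the schedule above.
microbatched_step = microbatch(step_fn, dim=0, size=8,
                               schedule="custom", custom_schedule=custom,
                               num_stages=2)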