multimesh.jax.microbatch

multimesh.jax.microbatch(fun, dim, size, argnum=0, interleave=None, num_stages=None, schedule=None, unrolling=None, arg_shardings=None, custom_schedule=None)

Unrolls a function along an axis into a microbatch loop.

The input tensors are sliced along the given axis for each microbatch, and the results of the microbatches are sum-reduced to produce the final result. The output tensors should be equivalent (modulo precision) to those of the untransformed function. No semantic checking is currently performed to verify that sum-reducing the microbatch results reproduces the original function; it is the user's responsibility to ensure the transformation is equivalent.

For a single input/output, the transformation is equivalent to:

import jax
import jax.numpy as jnp

def microbatch_f(x, params):
    # Shape/dtype of the full result, used to build the accumulator
    result = jax.eval_shape(fun, x, params)
    accumulator = jnp.zeros(result.shape, dtype=result.dtype)
    num_loops = x.shape[dim] // size
    for i in range(num_loops):
        # Slice the next microbatch of `size` elements along `dim`
        microbatch_x = jax.lax.dynamic_slice_in_dim(x, i * size, size, axis=dim)
        accumulator += fun(microbatch_x, params)
    return accumulator
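
For reference, the transformed function above corresponds to wrapping fun with this API (a sketch; the import path follows the module name, and x and params are the same inputs as in the example above):

import multimesh.jax

# Sketch: argnum defaults to 0, so argument `x` is sliced along `dim` in
# chunks of `size`; per-microbatch results are sum-reduced.
microbatch_f = multimesh.jax.microbatch(fun, dim, size)
result = microbatch_f(x, params)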

A microbatch loop will usually be a nested loop of N iterations over S stages:

for mb in range(num_microbatches):
  for stage in range(num_stages):
      ...

Microbatches are assumed to be independent, so the iteration order over mb is arbitrary. Microbatches can be tiled or unrolled in any order relative to the stages, e.g.

for block in range(blocks):
  for stage in range(num_stages):
    for mb in range(unrolling):
      ...

The structure of the nested loops can be tuned by specifying schedule, interleave, unrolling, and num_stages parameters. If left unspecified, the compiler/runtime is free to choose the microbatch schedule.
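
For example (a sketch only; the dim, size, and stage values below are placeholders, not recommendations), a ‘gpipe’ pipeline over four stages with the microbatch loop unrolled two iterations at a time could be requested as:

import multimesh.jax

# Sketch: assumes `fun` is defined as in the example above; values are
# illustrative only.
mb_fun = multimesh.jax.microbatch(
    fun,
    dim=0,
    size=16,
    schedule="gpipe",
    num_stages=4,
    unrolling=2,
)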

Parameters:
  • fun – Function to be transformed into microbatch loops. fun should be pure. See the jax.jit documentation for the requirements on fun.

  • dim (int) – an int specifying which dimension of the input tensor(s) should be sliced for each microbatch.

  • size (int) – an int specifying the size of the microbatch dimension for each microbatch.

  • argnum (int) – optional, the argument number that will be sliced for each microbatch. If the argument is a pytree of tensors rather than a single tensor, then all tensors in the tree are sliced along the given dim. All other arguments to fun are unmodified.

  • interleave (int | None) – optional, a hint to the microbatch scheduler about how tasks within the loop should be interleaved.

  • schedule (Literal['1f1b', 'gpipe', 'wavefront', 'custom', 'zero-bubble-h2'] | None) – optional, a string identifying the schedule of microbatch iterations/stages such as ‘gpipe’ or ‘1f1b’.

  • unrolling (int | None) – optional, an int specifying the unrolling factor of the microbatch loop. By default, the microbatch loop is fully unrolled. Only relevant for the ‘gpipe’ schedule.

  • arg_shardings (Any | None) – optional, an object or (prefix) pytree of objects matching argnum with shardings. The shardings can be any sharding-equivalent object, including partition specs or NamedSharding. If specified, the sharding annotations are applied to all sliced inputs (see the sketch after this parameter list).

  • custom_schedule (list[list[str]] | list[list[tuple[int, str]]] | None) – optional, a list of lists of strings or of (int, string) tuples specifying the custom pipeline schedule of microbatch iterations/stages when schedule is ‘custom’. Each list in custom_schedule corresponds to a device mesh, and each element in that list is a task name, optionally associated with a specific microbatch. There must be exactly pipeline_depth device meshes in the custom schedule, and each task must appear exactly num_microbatches times in the schedule.

  • num_stages (int | None) – optional, an int specifying the number of stages in the microbatch loop.
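
As an illustration of arg_shardings (a sketch; the mesh axis name "data" is an assumption about your device mesh, not part of this API), a PartitionSpec can be supplied so every sliced input is annotated with the same sharding:

from jax.sharding import PartitionSpec as P
import multimesh.jax

# Sketch: shard each microbatch slice across a hypothetical "data" mesh axis;
# assumes `fun` takes a rank-2 tensor as its first (sliced) argument.
mb_fun = multimesh.jax.microbatch(
    fun,
    dim=0,
    size=8,
    arg_shardings=P("data", None),  # applied to every sliced input
)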

Returns:

A wrapped version of fun that executes as a microbatch loop.
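
A minimal end-to-end sketch (the loss function, shapes, and microbatch size are illustrative and not taken from this API's documentation; a sum-based loss is used so that sum-reducing per-microbatch results matches the untransformed function):

import jax
import jax.numpy as jnp
import multimesh.jax

def loss(x, params):
    # Sum over the batch slice, so per-microbatch sums add up to the full loss
    return jnp.sum((x @ params) ** 2)

params = jnp.ones((16, 4))
x = jnp.ones((128, 16))   # global batch of 128 examples

# Slice argument 0 along dim 0 into microbatches of 32 rows (4 microbatches).
mb_loss = multimesh.jax.microbatch(loss, dim=0, size=32)
out = jax.jit(mb_loss)(x, params)   # equivalent (modulo precision) to loss(x, params)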