Public API: multimesh.jax package#

Subpackages#

Task transformations#

task(fun[, name, task])

Wraps a function in an auto-sharding task context

microbatch(fun, dim, size[, argnum, ...])

Unrolls a function along an axis into a microbatch loop.

parallelize(fun, *[, init_params, ...])

Compiles an abstract function into a sharded function

with_sharding_constraint(x, axis_resources)

Applies logical sharding annotations to the input(s)

register_task(matcher, *[, callback, task])

Registers a name-based or regex-based task autosharding context

MultiMesh(*[, axis_sizes, axis_names, ...])

MultiMesh represents a global mesh with named submeshes

Task([name, mesh, logical_axes, mesh_slice, ...])

Task dataclass defining all options for a named task context

TaskMesh(*, devices[, axis_sizes, ...])

TaskMesh representing a submesh within the global mesh
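
The microbatch entry in the summary above describes unrolling a function along an axis into a loop. The sketch below illustrates that idea in plain NumPy rather than with the library itself: `microbatch_sketch` is a hypothetical stand-in, and concatenating the per-chunk outputs is an assumption about how results are combined (the real transformation may accumulate or stage them differently). `jax.numpy` offers `split` and `concatenate` with the same calling convention.

```python
import numpy as np


def microbatch_sketch(fun, dim, size):
    """Hypothetical stand-in: apply `fun` to chunks of `size` along axis `dim`."""

    def wrapped(x):
        n = x.shape[dim]
        if n % size != 0:
            raise ValueError(f"axis {dim} of length {n} is not divisible by {size}")
        # Split the batch into n // size equal microbatches along `dim`.
        chunks = np.split(x, n // size, axis=dim)
        # The unrolled loop: one call to `fun` per microbatch.
        outputs = [fun(chunk) for chunk in chunks]
        # Assumption: outputs are stitched back together along the same axis.
        return np.concatenate(outputs, axis=dim)

    return wrapped


def double(x):
    return x * 2.0


batch = np.arange(8.0).reshape(4, 2)
looped = microbatch_sketch(double, dim=0, size=2)(batch)
```

For an elementwise function such as `double`, the looped result matches a single call on the full batch, which makes a convenient sanity check.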

Task classes#

class multimesh.jax.MultiMesh(*, axis_sizes=None, axis_names=None, devices=None, global_mesh=None)#

MultiMesh represents a global mesh with named submeshes

MultiMesh defines a global mesh with named submeshes. The mesh can be sliced along certain dimensions to create submeshes with a specific subshape. The MultiMesh can either create a new global mesh or wrap an existing jax.sharding.Mesh.

Parameters:
  • axis_sizes (Optional[Sequence[int]]) – The sizes of each global mesh axis. Defaults to None. Ignored if global_mesh is given.

  • axis_names (Optional[Sequence[str]]) – The names of each global mesh axis. Defaults to None. Ignored if global_mesh is given.

  • devices (Optional[np.array[xc.Device]]) – The shaped array of devices to include in the global mesh. The shape should match axis_sizes. Defaults to None. If not specified, the default jax.devices() will be used and reshaped to match the axis_sizes. Ignored if global_mesh is given.

  • global_mesh (Optional[Mesh]) – A JAX Mesh defining devices, axis_sizes, and axis_names. Defaults to None. If given, all other parameters will be ignored. If not given, then axis_sizes and axis_names must be given.
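
The precedence rules in the parameter list above can be sketched as follows. This is a minimal illustration, not the library's implementation: `resolve_mesh_args` is a hypothetical helper, integer ids stand in for the devices that `jax.devices()` would return, and a `SimpleNamespace` with `devices` and `axis_names` attributes stands in for a `jax.sharding.Mesh`.

```python
from types import SimpleNamespace

import numpy as np


def resolve_mesh_args(axis_sizes=None, axis_names=None, devices=None, global_mesh=None):
    """Hypothetical helper mirroring the documented rules; returns (grid, names)."""
    if global_mesh is not None:
        # Documented: when global_mesh is given, all other parameters are ignored.
        return global_mesh.devices, tuple(global_mesh.axis_names)
    if axis_sizes is None or axis_names is None:
        # Documented: without global_mesh, both axis_sizes and axis_names are required.
        raise ValueError("axis_sizes and axis_names are required without global_mesh")
    if devices is None:
        # Stand-in for jax.devices(): one integer id per device.
        devices = np.arange(int(np.prod(axis_sizes)))
    # Documented: the device array is shaped to match axis_sizes.
    grid = np.asarray(devices).reshape(tuple(axis_sizes))
    return grid, tuple(axis_names)


# Creating a new 2 x 4 global mesh from sizes and names.
grid, names = resolve_mesh_args(axis_sizes=(2, 4), axis_names=("stage", "data"))

# Wrapping an existing mesh: the other arguments are ignored.
existing = SimpleNamespace(devices=np.arange(4).reshape(2, 2), axis_names=("x", "y"))
wrapped_grid, wrapped_names = resolve_mesh_args(
    axis_sizes=(4,), axis_names=("ignored",), global_mesh=existing
)
```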

slice(**kwargs)#

Slice the global mesh along the specified axis.

Slices a mesh along the given named dimensions as defined by the kwargs map.

Parameters:
  • kwargs (Mapping[str, int]) – A mapping defining axis: value pairs that will be sliced out of the global mesh.

Returns:

A submesh sliced along the specified dimensions.

Return type:

Mesh
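
To make the keyword-based slicing concrete, here is a small NumPy sketch. It assumes that each value selects a single index along the named axis and that sliced axes are dropped from the result. The documentation above does not spell out these semantics, so treat both points as assumptions. `slice_mesh` is a hypothetical helper operating on a plain device grid, not the library method.

```python
import numpy as np


def slice_mesh(grid, axis_names, **kwargs):
    """Hypothetical helper: pick one index along each named axis in kwargs."""
    unknown = set(kwargs) - set(axis_names)
    if unknown:
        raise KeyError(f"unknown mesh axes: {sorted(unknown)}")
    index = []
    kept_names = []
    for name in axis_names:
        if name in kwargs:
            # Assumption: the value is an index along this axis.
            index.append(kwargs[name])
        else:
            # Axes that are not mentioned are kept whole.
            index.append(slice(None))
            kept_names.append(name)
    return grid[tuple(index)], tuple(kept_names)


# A 2 x 4 grid of stand-in device ids with axes ("stage", "data").
grid = np.arange(8).reshape(2, 4)
sub, sub_names = slice_mesh(grid, ("stage", "data"), stage=1)
```

Under these assumptions, `sub` holds the four devices of the second stage and only the `data` axis remains.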

class multimesh.jax.Task(name=None, mesh=None, logical_axes=None, mesh_slice=None, extra_axes=None, split_backprop=None, loop_dependent_mesh_slice=None)#

Task dataclass defining all options for a named task context

Parameters:
class multimesh.jax.TaskMesh(*, devices, axis_sizes=None, axis_names=None, abstract=None)#

TaskMesh representing a submesh within the global mesh

TaskMesh defines the submesh devices, axis names, and axis sizes for a task that runs on a subset or slice of the global mesh.

Parameters:
  • devices (Sequence[int]) – The list of device numbers to include in the task. Device numbers correspond to a [0, N) relative numbering of devices within the executable, not global device numbers.

  • axis_sizes (Optional[Sequence[int]]) – The sizes of each task mesh axis. Defaults to None.

  • axis_names (Optional[Sequence[str]]) – The names of each task mesh axis. Defaults to None.

  • abstract (Optional[AbstractMesh]) – A JAX AbstractMesh defining both axis_sizes and axis_names. Defaults to None.

place(devices)#

Creates an equivalent task mesh placed on new devices

Parameters:

devices (Sequence[int]) – The new devices to use for the mesh

Returns:

A task mesh with the same axes placed on new devices

Return type:

TaskMesh
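
The relationship between a task mesh and its placed copy can be illustrated with a small dataclass. `TaskMeshSketch` is a hypothetical stand-in, not the library class. Requiring the new device list to have the same length as the old one is an assumption made for this sketch.

```python
from dataclasses import dataclass, replace
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class TaskMeshSketch:
    """Hypothetical stand-in holding relative device numbers and axis metadata."""

    devices: Tuple[int, ...]
    axis_sizes: Optional[Tuple[int, ...]] = None
    axis_names: Optional[Tuple[str, ...]] = None

    def place(self, devices: Sequence[int]) -> "TaskMeshSketch":
        devices = tuple(devices)
        # Assumption: an equivalent mesh needs the same number of devices.
        if len(devices) != len(self.devices):
            raise ValueError("new placement must use the same number of devices")
        # Same axes, new devices; the original mesh is left unchanged.
        return replace(self, devices=devices)


# Device numbers are relative ([0, N) within the executable), not global ids.
stage0 = TaskMeshSketch(devices=(0, 1, 2, 3), axis_sizes=(2, 2), axis_names=("data", "model"))
stage1 = stage0.place([4, 5, 6, 7])
```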

Compiler contexts#

autoshard([autoshard])

Opens a context where sharding constraints become logical shardings.

context([autoshard, enable_recomputation, ...])

Helper function to configure multiple contexts in a single manager.

enable_fast_path([enable, context_value])

Enables or disables the native JAX fallthrough path for jitted functions.

enable_recomputation([enable, context_value])

Enables or disables communication-avoiding recomputation of inter-task intermediates.

ignore_transforms([ignore])

Sets debug context where all MultiMesh transformations are no-ops.
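
The context managers listed above all toggle a setting for the duration of a `with` block. A generic way to build such a scoped flag in Python is sketched below using the standard `contextvars` and `contextlib` modules. This is not the library's implementation: `enable_recomputation_sketch` and `_recompute` are hypothetical names, and the default of `False` is an assumption.

```python
import contextlib
import contextvars

# Hypothetical flag; the real default is not stated in this reference.
_recompute = contextvars.ContextVar("recompute_sketch", default=False)


@contextlib.contextmanager
def enable_recomputation_sketch(enable=True):
    """Set the flag inside the `with` block and restore the old value on exit."""
    token = _recompute.set(enable)
    try:
        yield
    finally:
        # Restores whatever value was active before entering the block.
        _recompute.reset(token)


observed = []
with enable_recomputation_sketch(True):
    observed.append(_recompute.get())      # inside the outer block
    with enable_recomputation_sketch(False):
        observed.append(_recompute.get())  # inner block overrides the outer one
    observed.append(_recompute.get())      # outer value is restored
observed.append(_recompute.get())          # default is restored
```

Nesting works because each block restores the value that was active when it was entered, so an inner override never leaks into the enclosing scope.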