Public API: multimesh.jax package
Subpackages
Task transformations
- Wraps a function in an auto-sharding task context.
- Unrolls a function along an axis into a microbatch loop.
- Compiles an abstract function into a sharded function.
- Applies logical sharding annotations to the input(s).
- Registers a name or regex-based task autosharding context.
- MultiMesh represents a global mesh with named submeshes.
- Task dataclass defining all options for a named task context.
- TaskMesh representing a submesh within the global mesh.
Task classes
- class multimesh.jax.MultiMesh(*, axis_sizes=None, axis_names=None, devices=None, global_mesh=None)
MultiMesh represents a global mesh with named submeshes
The mesh can be sliced along selected dimensions to create submeshes with a specific subshape. A MultiMesh can either create a new global mesh or wrap an existing jax.sharding.Mesh.
- Parameters:
axis_sizes (Optional[Sequence[int]]) – The sizes of each global mesh axis. Defaults to None. Ignored if global_mesh is given.
axis_names (Optional[Sequence[str]]) – The names of each global mesh axis. Defaults to None. Ignored if global_mesh is given.
devices (Optional[np.array[xc.Device]]) – The array of devices to include in the global mesh. Its shape should match axis_sizes. Defaults to None, in which case the default jax.devices() is used and reshaped to match axis_sizes. Ignored if global_mesh is given.
global_mesh (Optional[Mesh]) – A JAX Mesh defining devices, axis_sizes, and axis_names. Defaults to None. If given, all other parameters are ignored. If not given, axis_sizes and axis_names must both be given.
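The precedence among these parameters can be sketched in plain Python. This is only an illustration of the rules stated above, not the library's implementation: `resolve_mesh_args` is a hypothetical helper, a dict stands in for a JAX Mesh, `num_devices` stands in for the devices array, and the length and device-count checks are assumptions about what "shape should match axis_sizes" implies.

```python
from math import prod
from typing import Optional, Sequence


def resolve_mesh_args(
    axis_sizes: Optional[Sequence[int]] = None,
    axis_names: Optional[Sequence[str]] = None,
    num_devices: Optional[int] = None,
    global_mesh: Optional[dict] = None,
) -> dict:
    """Hypothetical helper mirroring the documented parameter precedence."""
    if global_mesh is not None:
        # A supplied global mesh wins; every other argument is ignored.
        return dict(global_mesh)
    if axis_sizes is None or axis_names is None:
        raise ValueError("axis_sizes and axis_names are required without global_mesh")
    # Assumed check: one name per axis.
    if len(axis_sizes) != len(axis_names):
        raise ValueError("axis_sizes and axis_names must have the same length")
    # Assumed check: the devices must fill the mesh shape exactly.
    if num_devices is not None and prod(axis_sizes) != num_devices:
        raise ValueError("device count must equal the product of axis_sizes")
    return {"axis_sizes": tuple(axis_sizes), "axis_names": tuple(axis_names)}
```

For example, `resolve_mesh_args(axis_sizes=[2, 4], axis_names=["data", "model"], num_devices=8)` describes a 2x4 mesh, while passing `global_mesh` makes the other arguments irrelevant.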
- slice(**kwargs)
Slice the global mesh along the specified axis.
Slices a mesh along the given named dimensions as defined by the kwargs map.
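The idea of slicing by named dimension can be sketched without the library. The sketch below models a mesh as a mapping from coordinates to device numbers and keeps the devices whose coordinate on each named axis falls in a requested range. `build_mesh` and `slice_mesh` are hypothetical names, and passing a `range` per axis is an assumption: the source does not specify what values the kwargs of `slice` accept.

```python
from itertools import product


def build_mesh(axis_names, axis_sizes):
    """Map each mesh coordinate to a device number in row-major order."""
    coords = product(*(range(n) for n in axis_sizes))
    return {c: i for i, c in enumerate(coords)}, tuple(axis_names)


def slice_mesh(mesh, axis_names, /, **kwargs):
    """Keep devices whose coordinate on every named axis is in the given range."""
    selected = {}
    for coord, device in mesh.items():
        if all(coord[axis_names.index(name)] in rng for name, rng in kwargs.items()):
            selected[coord] = device
    return selected


# A 2x4 mesh with axes ("data", "model"): device number = data * 4 + model.
mesh, names = build_mesh(("data", "model"), (2, 4))
first_data_row = slice_mesh(mesh, names, data=range(0, 1))
left_model_half = slice_mesh(mesh, names, model=range(0, 2))
```

Here `first_data_row` holds devices 0 through 3 and `left_model_half` holds devices 0, 1, 4, and 5; axes that are not named in the kwargs are kept whole.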
- class multimesh.jax.Task(name=None, mesh=None, logical_axes=None, mesh_slice=None, extra_axes=None, split_backprop=None, loop_dependent_mesh_slice=None)
Task dataclass defining all options for a named task context
- class multimesh.jax.TaskMesh(*, devices, axis_sizes=None, axis_names=None, abstract=None)
TaskMesh representing a submesh within the global mesh
TaskMesh defines the submesh devices, axis names, and axis sizes for a task that runs on a subset or slice of the global mesh.
- Parameters:
devices (Sequence[int]) – The list of device numbers to include in the task. Device numbers use a relative [0, N) numbering of the devices within the executable, not global device numbers.
axis_sizes (Optional[Sequence[int]]) – The sizes of each task mesh axis. Defaults to None.
axis_names (Optional[Sequence[str]]) – The names of each task mesh axis. Defaults to None.
abstract (Optional[AbstractMesh]) – A JAX AbstractMesh defining both axis_sizes and axis_names. Defaults to None.
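The difference between relative and global device numbers can be illustrated with a small sketch. `relative_device_numbers` is a hypothetical helper, and it assumes the relative number of a device is its position in the executable's device list; the source does not state how the [0, N) ordering is chosen.

```python
def relative_device_numbers(executable_devices, task_devices):
    """Translate global device ids into positions within the executable."""
    # Assumption: relative number = index in the executable's device list.
    position = {global_id: i for i, global_id in enumerate(executable_devices)}
    return [position[g] for g in task_devices]


# An executable spanning global devices 8..11, with a task on the last two.
executable = [8, 9, 10, 11]
task_relative = relative_device_numbers(executable, [10, 11])
```

Under that assumption, the task would be described by the relative numbers `[2, 3]` rather than the global ids 10 and 11.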
Compiler contexts
- Opens a context where sharding constraints become logical shardings.
- Helper function to configure multiple contexts in a single manager.
- Enables or disables the native JAX fallthrough path for jitted functions.
- Disables or enables communication-avoiding recomputation of inter-task intermediates.
- Sets a debug context where all MultiMesh transformations are no-ops.