MultiMesh for JAX: MPMD workflows for JAX
MultiMesh for JAX provides a framework for creating task contexts within jitted computations, allowing different subcomputations to be placed on different GPU submeshes. These tasks can be composed inside a single global jit, and data is resharded across submeshes automatically. This makes pipeline parallelism straightforward to express. Standard JAX SPMD sharding idioms can be used within each task, enabling full N-dimensional parallelism. For more detail, see the architecture documentation, or consult the getting-started guide to begin using MultiMesh.
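MultiMesh's own API is not shown here, so the sketch below does not use it. It illustrates, in plain JAX, the manual pattern that MultiMesh is described as automating: two pipeline stages pinned to disjoint submeshes, with an explicit `jax.device_put` resharding the intermediate activation from one submesh to the other. The two-way device split, the names `stage1`, `stage2`, and `two_stage`, and the use of eight emulated CPU devices (so it runs without GPUs) are all illustrative assumptions.

```python
import os

# Assumption for this sketch: emulate 8 host (CPU) devices so it runs anywhere.
# The flag must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# On real hardware you would take GPUs from jax.devices() instead.
cpu_devices = jax.devices("cpu")[:8]
assert len(cpu_devices) == 8, "set XLA_FLAGS before JAX is first imported"
devices = np.array(cpu_devices)

# Two disjoint 4-device submeshes, one per pipeline stage (hypothetical split).
mesh_a = Mesh(devices[:4], axis_names=("data",))
mesh_b = Mesh(devices[4:], axis_names=("data",))

# Within each submesh, ordinary SPMD sharding: rows split along "data",
# weights replicated.
rows_a = NamedSharding(mesh_a, P("data", None))
rows_b = NamedSharding(mesh_b, P("data", None))
repl_a = NamedSharding(mesh_a, P())
repl_b = NamedSharding(mesh_b, P())


@jax.jit
def stage1(x, w):
    # First stage: runs wherever its (committed) inputs live.
    return jnp.tanh(x @ w)


@jax.jit
def stage2(h, w):
    # Second stage: a projection followed by a per-row reduction.
    return jnp.sum(h @ w, axis=-1)


def two_stage(x, w1, w2):
    # Stage 1 executes on submesh A because its inputs are placed there.
    h = stage1(jax.device_put(x, rows_a), jax.device_put(w1, repl_a))
    # Manual cross-submesh resharding: the transfer MultiMesh performs for you.
    h = jax.device_put(h, rows_b)
    # Stage 2 executes on submesh B.
    return stage2(h, jax.device_put(w2, repl_b))
```

The batch dimension must be divisible by the submesh size (4 here) for the `P("data", None)` layout. In a real pipeline each stage would be a model shard and microbatches would overlap across stages; this sketch only shows the placement and the explicit transfer between submeshes.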
MultiMesh for JAX uses idiomatic JAX transforms to assign submesh contexts to functions or Flax modules.
MultiMesh for JAX adds compiler and runtime functionality via the jax_plugins interface.