multimesh.jax.flax.parallelize_step

multimesh.jax.flax.parallelize_step(model, optimizer, mesh, global_batch=None, local_batch=None, replicate_inputs=True)

Combines the model's loss function and the optimizer into an autosharded training step function.

Parameters:
  • model (Type[Module]) – Flax model to automatically convert to a full training step.

  • optimizer (GradientTransformation) – The optimizer to use for the parameters.

  • mesh (Mesh) – A Mesh context defining the devices and mesh shape.

  • global_batch (Any | None) – Optional; an example batch giving the full (global) input shapes. This must be an array with global sharding information.

  • local_batch (Any | None) – Optional; an example per-process batch giving the local input shapes.

  • replicate_inputs (bool)

Returns:

If global_batch is given, returns a tuple (sharded_model, initial_params, mesh). If local_batch is given, returns a tuple (sharded_model, initial_params, mesh, prepare_batch). The returned initial_params contains fully initialized, sharded parameters; mesh is a mesh context suitable for MPMD invocations of sharded_model; prepare_batch converts a local per-process batch into a global, sharded array.

Example

See the tutorial notebooks for complete examples.