multimesh.jax.flax.parallelize_step
- multimesh.jax.flax.parallelize_step(model, optimizer, mesh, global_batch=None, local_batch=None, replicate_inputs=True)
Combines a model's loss function and an optimizer into an autosharded step function.
- Parameters:
model (Type[Module]) – Flax model to automatically convert to a full training step
optimizer (GradientTransformation) – The optimizer to use for the parameters
mesh (Mesh) – a Mesh context defining the devices and mesh shape
global_batch (Any | None) – optional, an example batch giving the full (global) input shapes. This must be an array with global sharding information.
local_batch (Any | None) – optional, an example per-process batch giving the local input shapes.
replicate_inputs (bool)
- Returns:
If global_batch is given, returns a tuple of (sharded_model, initial_params, mesh). If local_batch is given, returns a tuple of (sharded_model, initial_params, mesh, prepare_batch). The returned initial_params contains fully initialized, sharded parameters. The mesh is a mesh context suitable for MPMD invocations of sharded_model. The prepare_batch function converts a local per-process batch into a global, sharded array.
Example
See the tutorial notebooks for complete examples.
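Below is a minimal sketch of both calling conventions. The toy model, the sharding of the example batch, and the step function's invocation signature at the end are assumptions made for illustration; the tutorial notebooks show the actual interface.

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from multimesh.jax.flax import parallelize_step

# Toy module whose __call__ returns a scalar loss. That the model itself
# computes the loss is an assumption based on the description above
# ("combines a model's loss function and an optimizer").
class ToyModel(nn.Module):
    @nn.compact
    def __call__(self, batch):
        h = nn.Dense(16)(batch)
        return jnp.mean(h ** 2)  # dummy scalar loss

# 1-D device mesh with a single "data" axis.
mesh = Mesh(jax.devices(), axis_names=("data",))
optimizer = optax.adam(1e-3)

# Global-batch variant: the example batch carries global sharding
# information; here, the leading (batch) axis is sharded across "data".
global_batch = jax.device_put(
    jnp.zeros((32, 8)), NamedSharding(mesh, P("data"))
)
sharded_model, initial_params, mesh_ctx = parallelize_step(
    ToyModel, optimizer, mesh, global_batch=global_batch
)

# Local-batch variant: each process supplies its own shard, and
# prepare_batch assembles local batches into a global, sharded array.
local_batch = jnp.zeros((32 // jax.process_count(), 8))
sharded_model, initial_params, mesh_ctx, prepare_batch = parallelize_step(
    ToyModel, optimizer, mesh, local_batch=local_batch
)

with mesh_ctx:
    batch = prepare_batch(local_batch)
    # The step function's call signature below is an assumption; consult
    # the tutorial notebooks for the actual interface.
    initial_params = sharded_model(initial_params, batch)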