multimesh.jax.flax.parallelize_step
- multimesh.jax.flax.parallelize_step(model, optimizer, mesh, global_batch=None, local_batch=None, replicate_inputs=True)
Combines the model's loss function and the optimizer into an autosharded step function.
- Parameters:
model (Type[Module]) – Flax model to automatically convert to a full training step
optimizer (GradientTransformation) – The optimizer to use for the parameters
mesh (Mesh) – a Mesh context defining the devices and mesh shape
global_batch (Any | None) – optional, an example batch giving the full (global) input shapes. This must be an array with global sharding information.
local_batch (Any | None) – optional, an example per-process batch giving the local input shape.
replicate_inputs (bool)
- Returns:
If global_batch is given, returns a tuple of (sharded_model, initial_params, mesh). If local_batch is given, returns a tuple of (sharded_model, initial_params, mesh, prepare_batch). The returned initial_params contains fully initialized, sharded parameters. The mesh is a mesh context suitable for MPMD invocations of sharded_model. The prepare_batch function converts a local per-process batch into a global, sharded array.
Example
See the tutorial notebooks for complete examples.
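As a rough illustration of the kind of step function this API produces, the sketch below combines a loss, its gradient, and a plain SGD update into one jit-compiled step under a device Mesh, with the batch sharded along a mesh axis. This uses only plain JAX (not the multimesh API); the model, loss, and optimizer here are illustrative stand-ins, not the Flax/optax objects parallelize_step accepts.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all available devices, with a single "data" axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

def loss_fn(params, batch):
    # A toy linear model with mean-squared-error loss (illustrative only).
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def step(params, batch, lr=0.1):
    # One training step: gradient of the loss plus a plain SGD update.
    # parallelize_step plays an analogous role for a Flax model and an
    # optax GradientTransformation, with sharding decided automatically.
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

with mesh:
    params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
    x = jnp.ones((8, 4))
    y = jnp.ones((8, 1))
    # Shard the batch along the mesh's "data" axis before the step,
    # loosely mirroring what the returned prepare_batch does.
    x = jax.device_put(x, NamedSharding(mesh, P("data")))
    params = step(params, (x, y))
```

In the real API the mesh returned by parallelize_step is the context to use when invoking sharded_model, so no manual device_put of parameters is needed.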