Public API: multimesh.jax.flax package


Task transformations

task(model[, name, task])

Wraps a Flax module in an auto-sharding task context

parallelize_step(model, optimizer, mesh[, ...])

Combines a model's loss function and an optimizer into an auto-sharded step function
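To illustrate what a fused "step function" computes, here is a minimal plain-JAX sketch. It is not multimesh's implementation. It assumes a hypothetical linear model (`loss_fn`) and a hand-written SGD update standing in for the `optimizer` argument. It also performs no sharding, which is the part `parallelize_step` is documented to automate.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Hypothetical linear model with a mean-squared-error loss.
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def make_step(loss_fn, learning_rate):
    """Fuse loss, gradient, and a plain SGD update into one jitted step.

    Hypothetical helper: plain SGD stands in for a real optimizer, and no
    sharding is applied (multimesh's parallelize_step adds auto-sharding).
    """

    @jax.jit
    def step(params, batch):
        # Compute the loss and its gradient w.r.t. params in one pass.
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        # Apply the SGD update to every leaf of the parameter pytree.
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
        return new_params, loss

    return step
```

Calling `step(params, batch)` repeatedly returns updated parameters together with the loss computed before each update, which is the usual shape of a training-loop step.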

Logical sharding annotations

shard_axes(*args)

Helper function that simplifies naming the axes of Flax parameters

with_sharding_constraint(x, axis_resources)

Applies logical sharding annotations to the input(s)
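Logical sharding annotations name array dimensions by role (for example "batch" or "mlp") rather than by physical mesh axis. The sketch below assumes, without confirmation from this page, that multimesh follows the common Flax/JAX convention of resolving those names to mesh axes through a rules table. The helper `resolve_logical_axes` and the axis names are hypothetical, and the code is plain Python rather than multimesh's API.

```python
from typing import Dict, Optional, Sequence, Tuple


def resolve_logical_axes(
    logical_axes: Sequence[Optional[str]],
    rules: Dict[str, Optional[str]],
) -> Tuple[Optional[str], ...]:
    """Map each logical axis name to a mesh axis name (None = replicated).

    Hypothetical helper illustrating logical-to-physical axis resolution.
    """
    resolved = []
    used = set()
    for name in logical_axes:
        # Unnamed dimensions and names without a rule stay replicated.
        mesh_axis = rules.get(name) if name is not None else None
        if mesh_axis is not None:
            # A single mesh axis may shard at most one array dimension.
            if mesh_axis in used:
                raise ValueError(f"mesh axis {mesh_axis!r} used twice")
            used.add(mesh_axis)
        resolved.append(mesh_axis)
    return tuple(resolved)
```

With `rules = {"batch": "data", "mlp": "model", "embed": None}`, the logical axes `("batch", "mlp")` resolve to `("data", "model")`, while `("batch", "embed")` resolve to `("data", None)`. In JAX such a tuple would typically be wrapped in `jax.sharding.PartitionSpec`.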