Public API: multimesh.jax.flax package
Task transformations
Wraps a Flax module in an auto-sharding task context.
Combines a model's loss function and an optimizer into an auto-sharded step function.
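To illustrate what "combining a loss function and an optimizer into a step function" means, here is a minimal plain-JAX sketch. It does not use multimesh and performs no sharding. The names `loss_fn` and `make_step_fn` and the linear model are hypothetical, and a hand-written SGD update stands in for a real optimizer. The returned `step` computes the loss and gradients with `jax.value_and_grad`, applies the update to every parameter leaf, and is compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model: params holds a weight "w" and a bias "b".
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def make_step_fn(loss_fn, learning_rate):
    """Hypothetical helper: fuse a loss and a plain SGD update into one jitted step."""
    @jax.jit
    def step(params, x, y):
        # Loss value and gradients w.r.t. params in one pass.
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        # SGD update applied leaf-by-leaf across the parameter pytree.
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
        return new_params, loss
    return step

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((8, 3))
y = jnp.ones((8, 1))

step = make_step_fn(loss_fn, learning_rate=0.1)
losses = []
for _ in range(3):
    params, loss = step(params, x, y)  # loss is measured before each update
    losses.append(float(loss))
```

An auto-sharded variant would additionally decide how `params`, `x`, and `y` are laid out across devices. This sketch leaves everything on the default device.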
Logical sharding annotations
Helper function that simplifies naming the axes of Flax parameters.
Applies logical sharding annotations to the input(s).
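The idea behind logical sharding annotations is to name array dimensions by role (for example "batch" or "mlp") and then translate those names into physical mesh axes. The following is a minimal plain-JAX sketch of that idea and does not use the multimesh API. The `LOGICAL_RULES` mapping and the `logical_to_spec` helper are hypothetical. The annotations themselves use JAX's `jax.lax.with_sharding_constraint` with a `NamedSharding`. A 1x1 mesh is used so the sketch runs on a single device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical rules mapping logical axis names to physical mesh axis names.
# None means the dimension is replicated rather than split.
LOGICAL_RULES = {"batch": "data", "embed": None, "mlp": "model"}

def logical_to_spec(logical_axes, rules=LOGICAL_RULES):
    """Hypothetical helper: translate logical axis names into a PartitionSpec."""
    return PartitionSpec(*(rules[name] for name in logical_axes))

# 1x1 mesh so the example runs anywhere; real meshes match the device count.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

@jax.jit
def forward(x, w):
    # Annotate the input: batch dim -> "data" axis, embed dim replicated.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, logical_to_spec(("batch", "embed"))))
    h = x @ w
    # Annotate the output: batch dim -> "data" axis, mlp dim -> "model" axis.
    return jax.lax.with_sharding_constraint(
        h, NamedSharding(mesh, logical_to_spec(("batch", "mlp"))))

x = jnp.ones((4, 8))
w = jnp.ones((8, 16))
out = forward(x, w)
```

Keeping the rules in one mapping lets the same model code be re-sharded for a different mesh by editing only `LOGICAL_RULES`, without touching the annotation sites.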