Index

_
    __init__()
        (multimesh.jax.MultiMesh method)
        (multimesh.jax.Task method)
        (multimesh.jax.TaskMesh method)

A
    autoshard() (in module multimesh.jax)

C
    context() (in module multimesh.jax)

E
    enable_fast_path() (in module multimesh.jax)
    enable_recomputation() (in module multimesh.jax)

I
    ignore_transforms() (in module multimesh.jax)

M
    microbatch() (in module multimesh.jax)
    MultiMesh (class in multimesh.jax)

P
    parallelize() (in module multimesh.jax)
    parallelize_step() (in module multimesh.jax.flax)
    place() (multimesh.jax.TaskMesh method)

R
    register_task() (in module multimesh.jax)

S
    shard_axes() (in module multimesh.jax.flax)
    slice() (multimesh.jax.MultiMesh method)

T
    Task (class in multimesh.jax)
    task()
        (in module multimesh.jax)
        (in module multimesh.jax.flax)
    TaskMesh (class in multimesh.jax)

W
    with_sharding_constraint()
        (in module multimesh.jax)
        (in module multimesh.jax.flax)