multimesh.jax.flax.shard_axes


multimesh.jax.flax.shard_axes(*args)

Helper function that simplifies naming the axes of Flax parameters.

Parameters:

*args (str) – Sequence of axis names to apply to a Flax kernel

Returns:

A partitioning function that will apply logical names to the axes of module parameters

Return type:

Callable

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import linen as nn
>>> from multimesh.jax.flax import shard_axes
>>>
>>> # Name the two kernel axes "batch" and "model"
>>> model = nn.Dense(features=16, kernel_init=shard_axes("batch", "model"))
>>>
>>> x = jnp.ones((16, 9))
>>> print(model.tabulate(jax.random.key(0), x))

This produces the following output:

                                Dense Summary
┏━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs        ┃ outputs        ┃ params                                ┃
┡━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      │ Dense  │ float32[16,9] │ float32[16,16] │ bias: float32[16]                     │
│      │        │               │                │ kernel: float32[9,16] P(batch, model) │
│      │        │               │                │                                       │
│      │        │               │                │ 160 (640 B)                           │
├──────┼────────┼───────────────┼────────────────┼───────────────────────────────────────┤
│      │        │               │          Total │ 160 (640 B)                           │
└──────┴────────┴───────────────┴────────────────┴───────────────────────────────────────┘

                            Total Parameters: 160 (640 B)