multimesh.jax.flax.shard_axes
- multimesh.jax.flax.shard_axes(*args)
Helper function to simplify naming the axes of Flax parameters.
- Parameters:
*args (str) – Sequence of axis names to apply to a Flax kernel, one name per kernel axis, in order.
- Returns:
A partitioning function that applies the logical names to the axes of module parameters.
- Return type:
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from flax import linen as nn
>>> from multimesh.jax.flax import shard_axes
>>>
>>> # Name the two kernel axes "batch" and "model"
>>> model = nn.Dense(features=16, kernel_init=shard_axes("batch", "model"))
>>>
>>> x = jnp.ones((16, 9))
>>> print(model.tabulate(jax.random.key(0), x))
This prints the following summary:
                                      Dense Summary
┏━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs        ┃ outputs        ┃ params                                ┃
┡━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      │ Dense  │ float32[16,9] │ float32[16,16] │ bias: float32[16]                     │
│      │        │               │                │ kernel: float32[9,16] P(batch, model) │
│      │        │               │                │                                       │
│      │        │               │                │ 160 (640 B)                           │
├──────┼────────┼───────────────┼────────────────┼───────────────────────────────────────┤
│      │        │               │          Total │ 160 (640 B)                           │
└──────┴────────┴───────────────┴────────────────┴───────────────────────────────────────┘

                              Total Parameters: 160 (640 B)