multimesh.jax.flax.with_sharding_constraint
- multimesh.jax.flax.with_sharding_constraint(x, axis_resources)
Applies logical sharding annotations to the input(s).
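Matches the semantics of jax.lax.with_sharding_constraint. The shardings (axis_resources) should be PartitionSpec or NamedSharding objects that have named axes. If no autosharding context is active, this forwards to jax.lax.with_sharding_constraint.
- Parameters:
x: The input(s) to annotate.
axis_resources: The shardings to apply, given as PartitionSpec or NamedSharding objects that have named axes.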
- Returns:
The input x with sharding constraints.
Example
>>> from jax_plugins.multimesh import init
>>> init(cpus=4)
>>>
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import PartitionSpec as P, Mesh
>>> from jax.experimental.pjit import pjit
>>> from multimesh.jax import with_sharding_constraint, task, autoshard
>>>
>>> mesh = Mesh(np.array(jax.devices()), ("x",))
>>> def f():
...     x = jnp.arange(16)
...     return with_sharding_constraint(x, P("x"))
>>>
>>> jit_f = pjit(task(f, mesh=mesh))
>>> with mesh, autoshard(True):
...     y = jit_f()
>>> jax.debug.visualize_array_sharding(y)
┌────────┬────────┬────────┬────────┐
│LEGATE 0│LEGATE 1│LEGATE 2│LEGATE 3│
└────────┴────────┴────────┴────────┘
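
The description says that, when no autosharding context is active, the call forwards to jax.lax.with_sharding_constraint. For comparison, here is a minimal sketch of that underlying JAX call on its own, assuming only stock JAX with no multimesh plugin loaded. The mesh axis name "x" and the array length are illustrative choices, and the sketch passes a NamedSharding rather than a bare PartitionSpec so that it does not depend on a surrounding mesh context.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available (possibly just one).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))

# Illustrative length: a multiple of the device count so the axis splits evenly.
n = 4 * devices.size


@jax.jit
def f():
    x = jnp.arange(n)
    # A NamedSharding carries its own mesh, so no `with mesh:` block is needed.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("x")))


y = f()
# Show how the result is laid out across the available devices.
jax.debug.visualize_array_sharding(y)
```

The constraint only affects how the array is laid out across devices; the values of y are the same as those of an unconstrained jnp.arange(n).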