multimesh.jax.with_sharding_constraint

multimesh.jax.with_sharding_constraint(x, axis_resources)

Applies logical sharding annotations to the input(s).

Matches the semantics of jax.lax.with_sharding_constraint. The shardings (axis_resources) should be PartitionSpec or NamedSharding objects whose axes are named. If no autosharding context is active, the call is forwarded to jax.lax.with_sharding_constraint.
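Since the function forwards to jax.lax.with_sharding_constraint when no autosharding context is active, the upstream call is a useful reference point. The following is a minimal sketch that uses only standard JAX, not multimesh; the mesh layout, the array size, and the function name `double_and_shard` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every available device, with a single axis named "x".
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))

# Pick a length divisible by the device count so the array splits evenly.
n = 4 * len(devices)

@jax.jit
def double_and_shard(x):
    # Ask the compiler to lay out the result along mesh axis "x".
    return jax.lax.with_sharding_constraint(x * 2, NamedSharding(mesh, P("x")))

y = double_and_shard(jnp.arange(n))
```

The constraint only affects how the result is laid out across devices; the values are the same as without it.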

Parameters:
  • x (Any) – array or pytree of arrays to apply shardings to

  • axis_resources (Any) – a sharding or a pytree of shardings. Each sharding should be a PartitionSpec, a NamedSharding, or a tuple of axis names.

Returns:

The input x with the sharding constraints applied.
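Both x and axis_resources may be pytrees. The sketch below illustrates the pytree form with the upstream jax.lax.with_sharding_constraint, whose semantics this function is documented to match. It does not call multimesh itself, and the names `params`, `shardings`, and `constrain` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))

# Leading dimension is a multiple of the device count so it splits evenly.
rows = 2 * len(devices)
params = {"w": jnp.ones((rows, 3)), "b": jnp.zeros((3,))}

# One sharding per leaf, in a pytree with the same structure as params.
shardings = {
    "w": NamedSharding(mesh, P("x", None)),  # rows split along mesh axis "x"
    "b": NamedSharding(mesh, P()),           # replicated on every device
}

@jax.jit
def constrain(tree):
    return jax.lax.with_sharding_constraint(tree, shardings)

out = constrain(params)
```

The result has the same tree structure and values as the input, with each leaf annotated by its matching sharding.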

Example

>>> from jax_plugins.multimesh import init
>>> init(cpus=4)
>>>
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import PartitionSpec as P, Mesh
>>> from jax.experimental.pjit import pjit
>>> from multimesh.jax import with_sharding_constraint, task, autoshard
>>>
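>>> # One-dimensional mesh over all devices, with a single axis named "x".
>>> mesh = Mesh(np.array(jax.devices()), ("x",))
>>> def f():
...   x = jnp.arange(16)
...   # Constrain x to be sharded along mesh axis "x".
...   return with_sharding_constraint(x, P("x"))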
>>>
>>> jit_f = pjit(task(f, mesh=mesh))
>>> with mesh, autoshard(True):
...   y = jit_f()
>>> jax.debug.visualize_array_sharding(y)
  ┌────────┬────────┬────────┬────────┐
  │LEGATE 0│LEGATE 1│LEGATE 2│LEGATE 3│
  └────────┴────────┴────────┴────────┘