multimesh.jax.parallelize

multimesh.jax.parallelize(fun, *, init_params=None, get_input_batch=None, initial_batch=None, devices=None)

Compiles an abstract function into a sharded function.

Parallelizes a function with logical sharding annotations into a function with explicit device shardings for all inputs and parameters. Input and output shapes are derived from the init_params and get_input_batch functions. Functions to be parallelized must conform to a standard format in which parameters are the first argument and input batches are the second argument. This matches the convention of the grad transformation, which differentiates the first argument and treats the second (and later) arguments as input batches; see the sketch below.
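
As an illustration of this calling convention (a minimal sketch; loss_fn and its shapes are hypothetical, not part of the multimesh API):

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def loss_fn(params, batch):
...   # First argument: parameters, the quantity grad differentiates.
...   # Second argument: the input batch.
...   preds = batch @ params["w"]
...   return jnp.mean(preds ** 2)
>>>
>>> params = {"w": jnp.ones((4, 2))}
>>> batch = jnp.ones((8, 4))
>>> grads = jax.grad(loss_fn)(params, batch)  # differentiates argument 0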

Parameters:
  • fun (Callable) – Function to be parallelized. fun should be pure; see the documentation for jax.jit for the requirements on fun.

  • init_params (Callable[[], Any] | None) – optional, a function taking no arguments that generates initial parameters without shardings. May be omitted if fun does not take parameters as its first argument.

  • get_input_batch (Callable[[], Any] | None) – optional, a function taking no arguments that generates input batches without shardings. May be omitted if fun does not take an input batch as its second argument, or if initial_batch is given instead.

  • initial_batch (Any | None) – optional, an array or pytree of arrays with explicit shardings, valid as the input batch to fun.

  • devices (Sequence[Device] | ndarray | None) – optional, a numpy array or sequence of JAX devices specifying the devices to parallelize over. If not given, the function is parallelized over all available devices.

Returns:

A tuple (sharded_fun, sharded_init_params, sharded_get_input_batch) if get_input_batch is given, or a tuple (sharded_fun, sharded_init_params) if a sharded initial_batch is given instead; see the sketch below.
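
For instance, the two forms unpack as follows (a sketch only; fun, init, get_batch, and sharded_batch stand in for user-supplied values):

>>> # With get_input_batch: a three-element tuple.
>>> sh_f, sh_init_params, sh_get_batch = parallelize(
...     fun, init_params=init, get_input_batch=get_batch)
>>> # With a pre-sharded initial_batch: a two-element tuple.
>>> sh_f, sh_init_params = parallelize(
...     fun, init_params=init, initial_batch=sharded_batch)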

Examples

>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from multimesh.jax import parallelize, task, with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>>
>>> def f(x, y):
...   x = with_sharding_constraint(x, P("x",))
...   y = with_sharding_constraint(y, P("x",))
...   out = x * y
...   return with_sharding_constraint(out, P("x",))  # out is 1-D, so a rank-1 spec
>>>
>>> devices = np.array(jax.devices()).reshape(2, 2)  # assumes 4 available devices
>>> task_f = task(f, devices=devices, device_axes=("x",))
>>>
>>> def init():
...   return jnp.arange(16)
>>>
>>> sh_f, sh_init_params, sh_init_batch = parallelize(
...     task_f, init_params=init, get_input_batch=init)
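
A possible continuation of the example (a sketch, assuming sh_f accepts the sharded parameters and batch positionally, mirroring f):

>>> params = sh_init_params()
>>> batch = sh_init_batch()
>>> out = sh_f(params, batch)  # runs with explicit device shardings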