multimesh.jax.parallelize
- multimesh.jax.parallelize(fun, *, init_params=None, get_input_batch=None, initial_batch=None, devices=None)
Compiles an abstract function into a sharded function.
Parallelizes a function with logical sharding annotations into a function with explicit device shardings for all inputs and parameters. Input and output shapes are derived from the init_params and get_input_batch functions. Functions to be parallelized must conform to a standard format in which parameters are the first argument and input batches are the second argument. This matches the format of the grad transformation, which differentiates the first argument and assumes the second (and later) arguments are input batches.
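For illustration, a minimal sketch of this calling convention (loss_fn, params, and batch are hypothetical names, not part of multimesh); because parameters come first and the input batch second, the same function works directly with jax.grad:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def loss_fn(params, batch):  # parameters first, input batch second
...     preds = batch["x"] @ params["w"]
...     return jnp.mean((preds - batch["y"]) ** 2)
>>>
>>> params = {"w": jnp.ones((4, 1))}
>>> batch = {"x": jnp.ones((8, 4)), "y": jnp.zeros((8, 1))}
>>> grads = jax.grad(loss_fn)(params, batch)  # differentiates the first argument only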
- Parameters:
  - fun (Callable) – Function to be parallelized. fun should be pure. See the documentation for jax.jit for requirements on fun.
  - init_params (Callable[[], Any] | None) – optional, a function taking no arguments that generates input parameters without shardings. Optional if fun does not take parameters as a first argument.
  - get_input_batch (Callable[[], Any] | None) – optional, a function taking no arguments that generates input batches without shardings. Optional if fun does not take input batches as a second argument or if initial_batch is given instead.
  - initial_batch (Any | None) – optional, an array or pytree of arrays with shardings valid as input to fun. May be given instead of get_input_batch (see the sketch after this list).
  - devices (Sequence[Device] | ndarray | None) – optional, a numpy array or list of JAX devices specifying the devices to parallelize over. If not given, the function is parallelized over all devices.
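A hedged sketch of the initial_batch / devices path; the toy f and init, the mesh construction, and the exact way the pre-sharded batch is built are illustrative assumptions rather than requirements of this API:

>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> from multimesh.jax import parallelize
>>>
>>> def f(params, batch):          # standard format: parameters first, batch second
...     return params * batch
>>>
>>> def init():                    # generates unsharded parameters
...     return jnp.arange(16)
>>>
>>> devices = np.array(jax.devices()[:2])                                # explicit device subset
>>> mesh = Mesh(devices, ("x",))
>>> batch = jax.device_put(jnp.arange(16), NamedSharding(mesh, P("x")))  # batch with valid shardings
>>> sh_f, sh_init_params = parallelize(f, init_params=init, initial_batch=batch, devices=devices)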
- Returns:
  A tuple (sharded_fun, sharded_init_params, sharded_get_input_batch) if get_input_batch is given, or a tuple (sharded_fun, sharded_init_params) if a sharded initial_batch is given instead. A usage sketch follows the example below.
Examples
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from multimesh.jax import parallelize, task, with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>>
>>> def f(x, y):
...     x = with_sharding_constraint(x, P("x",))
...     y = with_sharding_constraint(y, P("x",))
...     out = x * y
...     return with_sharding_constraint(out, P("batch", "model"))
>>>
>>> devices = np.array(jax.devices()).reshape(2, 2)   # 2x2 array of devices
>>> task_f = task(f, devices=devices, device_axes=("x",))
>>>
>>> def init():                                        # unsharded parameter/batch initializer
...     return jnp.arange(16)
>>>
>>> sh_f, sh_init_params, sh_init_batch = parallelize(f, init_params=init, get_input_batch=init)
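A hedged continuation of the example above, assuming the returned sh_init_params and sh_init_batch callables take no arguments and produce values that can be passed directly to the compiled sharded function:

>>> params = sh_init_params()   # parameters materialized with explicit device shardings
>>> batch = sh_init_batch()     # input batch materialized with explicit device shardings
>>> out = sh_f(params, batch)   # runs the compiled, sharded computation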