multimesh.jax.parallelize
- multimesh.jax.parallelize(fun, *, init_params=None, get_input_batch=None, initial_batch=None, devices=None)
Compiles an abstract function into a sharded function
Parallelizes a function with logical sharding annotations into a function with explicit device shardings for all inputs and parameters. Input and output shapes are derived from the init_params and get_input_batch functions. Functions to be parallelized must conform to a standard format in which parameters are the first argument and input batches are the second argument. This matches the format for the grad transformation, which differentiates the first argument and assumes the second (and later) arguments are input batches (see the sketch after the Returns section below).
- Parameters:
fun (Callable) – Function to be parallelized. fun should be pure. See documentation for jax.jit for requirements for fun.
init_params (Callable[[], Any] | None) – optional, a function taking no arguments that generates input parameters without shardings. Optional if fun does not take parameters as a first argument.
get_input_batch (Callable[[], Any] | None) – optional, a function taking no arguments that generates input batches without shardings. Optional if fun does not take input batches as a second argument or if initial_batch is given instead.
initial_batch (Any | None) – optional, an array or pytree of arrays with shardings valid as input parameters to fun.
devices (Sequence[Device] | ndarray | None) – optional, a numpy array or list of jax devices specifying the devices to parallelize over. If not given, the function is parallelized over all devices.
- Returns:
A tuple of (sharded_fun, sharded_init_params, sharded_get_input_batch) if get_input_batch is given, or a tuple of (sharded_fun, sharded_init_params) if a sharded initial_batch is given.
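Because parameters are the first argument and input batches are the second, a function in this format can be differentiated with jax.grad as-is and then parallelized without changing its signature. A minimal sketch of the convention, assuming a hypothetical scalar loss (loss_fn, params, and batch are illustrative names, not part of multimesh):

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Hypothetical loss in the standard format: parameters first, batch second.
>>> def loss_fn(params, batch):
...     return jnp.mean((params["w"] * batch) ** 2)
>>>
>>> # jax.grad differentiates the first argument by default, so the same
>>> # calling convention works before and after parallelize.
>>> grads = jax.grad(loss_fn)({"w": jnp.ones(4)}, jnp.arange(4.0))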
Examples
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from multimesh.jax import parallelize, task, with_sharding_constraint
>>> from jax.sharding import PartitionSpec as P
>>>
>>> def f(x, y):
...     x = with_sharding_constraint(x, P("x",))
...     y = with_sharding_constraint(y, P("x",))
...     out = x * y
...     return with_sharding_constraint(out, P("batch", "model"))
>>>
>>> devices = np.array(jax.devices()).reshape(2, 2)
>>> task_f = task(f, devices=devices, device_axes=("x",))
>>>
>>> def init():
...     return jnp.arange(16)
>>>
>>> sh_f, sh_init_params, sh_init_batch = parallelize(f, init_params=init, get_input_batch=init)
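Assuming the returned callables mirror the originals (the sharded init functions take no arguments and sh_f keeps the params-first, batch-second signature), the result above might be used as follows; this continuation is illustrative rather than part of the documented API:

>>> params = sh_init_params()   # sharded parameters
>>> batch = sh_init_batch()     # sharded input batch
>>> out = sh_f(params, batch)   # run the compiled, sharded function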