multimesh.jax.enable_recomputation

multimesh.jax.enable_recomputation(enable=None, context_value=[True])

Enables or disables communication-avoiding recomputation of inter-task intermediates.

Inter-task intermediates add runtime overhead and may cause extra inter-task communication. This recomputation can be enabled or disabled for debugging or performance tests. It is separate from, and complementary to, JAX-level checkpointing and HLO rematerialization passes.

Parameters:
  • enable (bool | None) – optional, whether to enable inter-task recomputation

  • context_value – optional, a global variable holding the current context value. The user should never pass this value. The program starts in a context where enable is True, so recomputation is enabled by default.
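
The context_value=[True] default in the signature suggests the common Python pattern of using a mutable default argument as process-wide state. The sketch below illustrates that pattern only. The function enable_recomputation_sketch is a hypothetical stand-in, not the multimesh implementation, and its behavior (a boolean overwrites the flag, and the current flag is returned) is an assumption. Whether the real function returns a value or acts as a context manager is not stated here.

```python
def enable_recomputation_sketch(enable=None, context_value=[True]):
    """Hypothetical stand-in, NOT the multimesh implementation.

    The default list is created once, when the function is defined, so it
    persists across calls and behaves like a module-level variable.
    """
    if enable is not None:
        # Assumption: passing a bool overwrites the stored flag.
        context_value[0] = bool(enable)
    # Assumption: the current flag is returned so callers can inspect it.
    return context_value[0]


initial = enable_recomputation_sketch()          # no argument: read the initial state
after_off = enable_recomputation_sketch(False)   # switch the flag off
still_off = enable_recomputation_sketch()        # the state persists between calls
after_on = enable_recomputation_sketch(True)     # switch the flag back on
```

In this pattern, a caller who supplies their own list for context_value would read and write a private copy instead of the shared state, which is one plausible reason the parameter is documented as something the user should never pass.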