multimesh.jax.enable_recomputation
- multimesh.jax.enable_recomputation(enable=None, context_value=[True])
Enables or disables communication-avoiding recomputation of inter-task intermediates.
Inter-task intermediates add runtime overhead and may cause extra inter-task communication; recomputing them instead of transferring them avoids that communication. This recomputation can be enabled or disabled for debugging or performance testing. It is separate from, and complementary to, JAX-level checkpointing and HLO rematerialization passes.
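This page does not describe how multimesh implements the recomputation. As a rough illustration of the trade-off only, the toy model below uses plain JAX with invented names (`f`, `send`, `task_a`, `task_b_receive`, `task_b_recompute`, `run`), none of which are multimesh APIs. It assumes two tasks both hold the input `x`: task B either receives the intermediate `h` from task A, costing one transfer, or recomputes `h` from `x` locally, costing no transfer but repeating the work of `f`.

```python
import jax.numpy as jnp

# Hypothetical log of inter-task transfers (sizes of the arrays "sent").
transfers = []

def send(value):
    """Stand-in for inter-task communication: it only records the transfer."""
    transfers.append(value.size)
    return value

def f(x):
    """Cheap function that produces the inter-task intermediate h."""
    return jnp.tanh(x) * 2.0

def task_a(x):
    h = f(x)
    return h, jnp.sum(h)

def task_b_receive(h_from_a):
    # Consumes h after it has been transferred from task A.
    return jnp.mean(h_from_a ** 2)

def task_b_recompute(x):
    # Rebuilds h from x locally, so nothing has to be transferred.
    h = f(x)
    return jnp.mean(h ** 2)

def run(x, recompute):
    """Return both task outputs and the number of transfers performed."""
    transfers.clear()
    h, out_a = task_a(x)
    if recompute:
        out_b = task_b_recompute(x)
    else:
        out_b = task_b_receive(send(h))
    return out_a, out_b, len(transfers)
```

Both modes produce the same outputs; only the transfer count differs. Recomputation is worthwhile when `f` is cheap compared with moving `h` between tasks.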
- Parameters:
enable (bool | None) – optional, whether to enable inter-task recomputation
context_value – optional, a global variable holding the current context value. Users should never pass this value. The program begins in a context with enable=True, which means that recomputation is enabled.
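The default `context_value=[True]` is a one-element list created once when the function is defined, so every call shares it and it behaves like global state. The sketch below shows one plausible way such a flag could work. It is hypothetical: this page does not say whether `enable_recomputation` is used as a context manager or how `enable=None` is treated, so `recomputation_context`, `recomputation_enabled`, and `_RECOMPUTE_FLAG` are invented names that assume `None` means "inherit the current setting" and that the outer value is restored when a `with` block exits.

```python
from contextlib import contextmanager

# Hypothetical shared state: a one-element list, like context_value=[True].
# The program starts with recomputation enabled.
_RECOMPUTE_FLAG = [True]

def recomputation_enabled():
    """Return the flag for the current context (hypothetical helper)."""
    return _RECOMPUTE_FLAG[0]

@contextmanager
def recomputation_context(enable=None, context_value=_RECOMPUTE_FLAG):
    """Hypothetical sketch mirroring the documented signature.

    enable=None keeps the current setting; True or False overrides it
    until the `with` block exits, after which the outer value is restored.
    """
    previous = context_value[0]
    if enable is not None:
        context_value[0] = bool(enable)
    try:
        yield
    finally:
        context_value[0] = previous

# Example: nested contexts override and then restore the flag.
with recomputation_context(False):
    print(recomputation_enabled())      # False inside the block
print(recomputation_enabled())          # True again outside it
```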