elisa.util.config
Helper functions for configuring the computation environment.
- jax_enable_x64(use_x64: bool) → None
Changes the default floating-point precision of arrays in JAX.
- Parameters:
use_x64 (bool) – When True, JAX arrays use 64-bit precision; otherwise, they use 32-bit precision.
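A minimal sketch of the effect, assuming `jax_enable_x64` wraps JAX's own `jax_enable_x64` configuration option (not verified against the elisa source). It sets the option directly through `jax.config.update` and inspects the dtype of newly created arrays.

```python
import jax
import jax.numpy as jnp

# Assumption: elisa's jax_enable_x64(True) sets this same JAX option.
# Set it at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)      # default float dtype is now float64
y = jnp.asarray(0.5)  # Python floats are also converted to float64
print(x.dtype, y.dtype)
```

With elisa, the documented call is `jax_enable_x64(True)` from `elisa.util.config`, typically placed at the top of a script before any other JAX work.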
- set_jax_platform(platform: Literal['cpu', 'gpu', 'tpu'] | None = None)
Set the JAX platform to CPU, GPU, or TPU.
Warning
This utility only takes effect if it is called before any JAX program is run.
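For illustration, a minimal sketch using JAX's `JAX_PLATFORMS` environment variable, one standard way to restrict JAX to a single platform. We assume `set_jax_platform` has an equivalent effect but have not checked how elisa implements it. CPU is used here because it is always available.

```python
import os

# JAX consults JAX_PLATFORMS to decide which backends to initialize, so this
# must be set before any JAX computation runs; setting it before importing
# jax is the safest choice.
# Assumption: set_jax_platform("cpu") has the same effect.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (deliberately imported after the variable is set)

backend = jax.default_backend()
print(backend)        # cpu
print(jax.devices())  # only CPU devices are listed
```

This ordering requirement is why the warning above applies: once JAX has initialized a backend, changing the platform has no effect.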
- set_cpu_cores(n: int) → None
Set the number of CPU devices for JAX to use.
Warning
This utility only takes effect on the CPU platform, and only if it is called before any JAX program is run.
- Parameters:
n (int) – Number of devices to use.
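On the CPU platform, XLA decides how many host devices to expose from the `--xla_force_host_platform_device_count` flag in the `XLA_FLAGS` environment variable, which is read when JAX initializes its CPU backend. This is one plausible mechanism behind `set_cpu_cores` (NumPyro's `set_host_device_count` works this way), but elisa's implementation has not been confirmed. The sketch uses a hypothetical helper, `force_host_device_count`, that rewrites `XLA_FLAGS` while preserving any other flags already set.

```python
import os

FLAG = "--xla_force_host_platform_device_count"


def force_host_device_count(n: int) -> str:
    """Hypothetical helper: ask XLA to expose ``n`` CPU devices to JAX.

    It must run before JAX initializes its CPU backend; afterwards the
    setting is ignored. Returns the new value of ``XLA_FLAGS``.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    # Keep every other flag, but drop any earlier device-count setting.
    kept = [f for f in os.environ.get("XLA_FLAGS", "").split()
            if not f.startswith(FLAG + "=")]
    kept.append(f"{FLAG}={n}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]


# Usage: an earlier setting of the same flag is replaced, not duplicated.
os.environ["XLA_FLAGS"] = f"{FLAG}=2"
print(force_host_device_count(4))  # --xla_force_host_platform_device_count=4
```

If JAX is imported only after this call, `jax.devices("cpu")` should list four devices, which is what makes parallel sampling over several CPU cores possible.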
- jax_debug_nans(flag: bool)
Automatically detect NaNs produced while running JAX code.
See the JAX documentation on debugging NaNs for details.
- Parameters:
flag (bool) – When True, JAX raises an error when a NaN is detected.
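A minimal sketch of what the flag changes, toggling JAX's `jax_debug_nans` option directly with `jax.config.update`; we assume elisa's `jax_debug_nans` wraps this same option. With the flag off, `0/0` silently yields NaN; with it on, JAX raises `FloatingPointError`. The helper `divide_zero_by_zero` is a hypothetical name used only for this illustration.

```python
import jax
import jax.numpy as jnp


def divide_zero_by_zero():
    """Hypothetical helper: return 0/0 from JAX, or the error it raised."""
    try:
        return jnp.divide(0.0, 0.0)
    except FloatingPointError as err:
        return err


# Default (flag off): the NaN propagates without any error.
before = divide_zero_by_zero()

# Assumption: elisa's jax_debug_nans(True) sets this same JAX option.
jax.config.update("jax_debug_nans", True)
after = divide_zero_by_zero()
print(type(after).__name__)  # FloatingPointError
```

Because every output is checked for NaNs, the option adds overhead and is best enabled only while debugging.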