elisa.util.config

Helper functions for computation environment configuration.

jax_enable_x64(use_x64: bool) → None

Changes the default float precision of arrays in JAX.

Parameters:
use_x64 : bool

If True, JAX arrays default to 64-bit precision; otherwise they default to 32-bit.
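As a rough illustration of what the 32-bit default gives up, the following plain-Python sketch rounds values through single precision using only the standard library. `to_float32` is a hypothetical helper written for this example; it is not part of elisa or JAX.

```python
import struct


def to_float32(x: float) -> float:
    """Hypothetical helper: round a Python float (64-bit) to 32-bit precision."""
    # Pack as an IEEE 754 single, then unpack back into a Python float.
    return struct.unpack("<f", struct.pack("<f", x))[0]


# 2**24 + 1 needs 25 significant bits, one more than a 32-bit float holds.
print(to_float32(16777217.0))  # 16777216.0

# 0.1 is stored less accurately in 32 bits than in 64 bits.
print(to_float32(0.1) == 0.1)  # False
```

Computations that are sensitive to errors of this size are the usual reason to switch to 64-bit precision.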

set_jax_platform(platform: Literal['cpu', 'gpu', 'tpu'] | None = None)

Set JAX platform to CPU, GPU, or TPU.

Warning

This utility takes effect only if it is called before any JAX program runs.

Parameters:
platform : {'cpu', 'gpu', 'tpu'}, optional

Either 'cpu', 'gpu', or 'tpu'.
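One way to see why the call has to come first: JAX also reads its platform choice from the `JAX_PLATFORMS` environment variable when it starts up. The sketch below sets that variable with a hypothetical helper, `request_platform`. It is not elisa's implementation, which may configure JAX differently.

```python
import os

VALID_PLATFORMS = ("cpu", "gpu", "tpu")


def request_platform(platform=None):
    """Hypothetical helper: record the desired platform in JAX_PLATFORMS."""
    if platform is None:
        return  # leave JAX's default device selection untouched
    if platform not in VALID_PLATFORMS:
        raise ValueError(f"platform must be one of {VALID_PLATFORMS}, got {platform!r}")
    os.environ["JAX_PLATFORMS"] = platform


# Must happen before `import jax` for the setting to be picked up.
request_platform("cpu")
```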

set_cpu_cores(n: int) → None

Set the number of devices for JAX to use.

Warning

This utility takes effect only on the CPU platform, and only if it is called before any JAX program runs.

Parameters:
n : int

Number of devices to use.
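A common way to expose several CPU devices to JAX is the XLA flag `--xla_force_host_platform_device_count`, passed through the `XLA_FLAGS` environment variable before JAX initializes. The sketch below assumes that mechanism. `force_host_device_count` is a hypothetical helper, and this page does not say whether elisa works this way.

```python
import os


def force_host_device_count(n: int) -> None:
    """Hypothetical helper: ask XLA to expose n host (CPU) devices."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    prefix = "--xla_force_host_platform_device_count="
    # Keep unrelated flags, but drop any earlier setting of the same flag.
    kept = [
        flag
        for flag in os.environ.get("XLA_FLAGS", "").split()
        if not flag.startswith(prefix)
    ]
    os.environ["XLA_FLAGS"] = " ".join(kept + [f"{prefix}{n}"])


# Must run before JAX initializes its backends.
force_host_device_count(4)
```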

jax_debug_nans(flag: bool)

Automatically detect when NaNs are produced while running JAX code.

See JAX docs for details.

Parameters:
flag : bool

When True, raise an error when NaNs are detected in JAX.
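The idea of failing fast on NaNs can be sketched in plain Python. `raise_on_nan` is a hypothetical decorator written for this example, and `FloatingPointError` is the exception chosen for it. The sketch only mimics the behavior described above and does not use JAX's own mechanism.

```python
import math
from functools import wraps


def raise_on_nan(fn):
    """Hypothetical decorator: raise as soon as fn returns a NaN float."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, float) and math.isnan(out):
            raise FloatingPointError(f"NaN produced by {fn.__name__}")
        return out
    return wrapper


@raise_on_nan
def difference(a: float, b: float) -> float:
    return a - b


print(difference(3.0, 1.0))  # 2.0
# difference(float("inf"), float("inf")) raises FloatingPointError,
# because inf - inf is NaN.
```

Raising at the point where the NaN first appears makes it easier to locate the faulty operation than inspecting a final result that is already NaN.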

get_parallel_number(n: int | None) → int

Check and return the number of parallel processes available in JAX.

Parameters:
n : int, optional

The desired number of parallel processes in JAX.

Returns:
int

The available number of parallel processes.
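The check could look something like the sketch below. `resolve_parallel_number` and its `available` argument are hypothetical. The sketch assumes that `None` means "use everything available" and that an out-of-range request is an error. This page does not state how elisa handles either case.

```python
def resolve_parallel_number(n, available):
    """Hypothetical helper: validate a requested degree of parallelism.

    `available` stands in for the device count JAX would report.
    """
    if available < 1:
        raise ValueError("at least one device must be available")
    if n is None:
        return available  # assumption: default to everything available
    n = int(n)
    if not 1 <= n <= available:
        # assumption: out-of-range requests are rejected, not clamped
        raise ValueError(f"n must be between 1 and {available}, got {n}")
    return n


print(resolve_parallel_number(None, 4))  # 4
print(resolve_parallel_number(2, 4))     # 2
```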