elisa.util.config

Helper functions for computation environment configuration.

jax_enable_x64(use_x64: bool) → None

Changes the default float precision of arrays in JAX.

Parameters:

use_x64 (bool) – If True, JAX arrays default to 64-bit precision; otherwise they default to 32-bit precision.
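A minimal sketch of the effect, assuming jax_enable_x64 wraps JAX's own `jax_enable_x64` configuration option through `jax.config.update`. The name `enable_x64_sketch` is hypothetical and is not part of elisa.

```python
import jax
import jax.numpy as jnp


def enable_x64_sketch(use_x64: bool) -> None:
    """Hypothetical stand-in for jax_enable_x64: toggle JAX's x64 mode."""
    jax.config.update("jax_enable_x64", use_x64)


enable_x64_sketch(True)
x64_dtype = jnp.ones(3).dtype  # default float dtype is now float64

enable_x64_sketch(False)
x32_dtype = jnp.ones(3).dtype  # default float dtype is back to float32

print(x64_dtype, x32_dtype)
```

The dtype is resolved when an array is created, so arrays built before the switch keep their original precision.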

set_jax_platform(platform: Literal['cpu', 'gpu', 'tpu'] | None = None)

Set JAX platform to CPU, GPU, or TPU.

Warning

This utility takes effect only if called before any JAX program is run.

Parameters:

platform (Optional[Literal['cpu', 'gpu', 'tpu']]) – One of 'cpu', 'gpu', or 'tpu'.
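The ordering constraint in the warning can be illustrated with plain JAX. This is a minimal sketch that assumes set_jax_platform ends up setting JAX's `jax_platforms` configuration option; elisa may use a different mechanism. JAX initializes its backends lazily on first use, so the update has to come before any computation or device query.

```python
import jax

# Must run before any JAX computation: backends are initialized on first use,
# after which changing the platform no longer has an effect.
jax.config.update("jax_platforms", "cpu")

devices = jax.devices()
print(devices[0].platform)
```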

set_cpu_cores(n: int) → None

Set the number of CPU devices available to JAX.

Warning

This utility takes effect only on the CPU platform, and only if called before any JAX program is run.

Parameters:

n (int) – Number of CPU devices to use.

jax_debug_nans(flag: bool)

Enable or disable automatic detection of NaNs produced while running JAX code.

See the JAX documentation on debugging NaNs for details.

Parameters:

flag (bool) – If True, JAX raises an error whenever a NaN is detected.
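A minimal sketch of the behavior, assuming jax_debug_nans forwards its flag to JAX's `jax_debug_nans` configuration option. With the check enabled, JAX raises a `FloatingPointError` as soon as an operation produces a NaN, here the logarithm of a negative number.

```python
import jax
import jax.numpy as jnp

# Assumption: jax_debug_nans(True) amounts to this config update.
jax.config.update("jax_debug_nans", True)

caught = None
try:
    # log of a negative number yields NaN, which the checker should reject.
    jnp.log(jnp.array(-1.0)).block_until_ready()
except FloatingPointError as err:
    caught = err

# The check adds overhead to every operation, so switch it back off.
jax.config.update("jax_debug_nans", False)
print(type(caught).__name__)
```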

jax_pmap_shmap_merge(flag: bool) → Iterator[None]

Context manager that temporarily sets the JAX configuration option jax_pmap_shmap_merge to flag and restores its previous value on exit.
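The Iterator[None] return type points to a generator-based context manager. The sketch below shows the save-and-restore pattern and uses a plain dict, `_config`, as a hypothetical stand-in for JAX's configuration. With real JAX, the current value would be read from `jax.config` and written with `jax.config.update`, provided the installed JAX version defines the jax_pmap_shmap_merge option. `temporary_flag` is an illustrative name, not elisa's API.

```python
from contextlib import contextmanager
from typing import Dict, Iterator

# Hypothetical stand-in for jax.config: a plain dict holding flag values.
_config: Dict[str, bool] = {"jax_pmap_shmap_merge": False}


@contextmanager
def temporary_flag(name: str, value: bool) -> Iterator[None]:
    """Set _config[name] to value inside a with-block, then restore it."""
    previous = _config[name]
    _config[name] = value
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions, so the old value always comes back.
        _config[name] = previous


with temporary_flag("jax_pmap_shmap_merge", True):
    inside = _config["jax_pmap_shmap_merge"]
after = _config["jax_pmap_shmap_merge"]

try:
    with temporary_flag("jax_pmap_shmap_merge", True):
        raise RuntimeError("simulated failure inside the block")
except RuntimeError:
    pass
after_error = _config["jax_pmap_shmap_merge"]
```

Placing the restore in a `finally` clause keeps the global setting from leaking out of the block when the wrapped code fails.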

get_parallel_number(n: int | None) → int

Check the requested number of parallel processes and return the number available in JAX.

Parameters:

n (int | None) – The desired number of parallel processes in JAX.

Returns:

The available number of parallel processes.
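The exact checks performed by get_parallel_number are not documented here. The following is a hypothetical sketch of one plausible policy: None means "use every available device", non-positive values are rejected, and requests above the available count are capped with a warning. To keep the sketch testable without devices, the available count is passed in as `available`. With real JAX it could come from `jax.local_device_count()`.

```python
import warnings
from typing import Optional


def get_parallel_number_sketch(n: Optional[int], available: int) -> int:
    """Hypothetical check: None -> all devices; cap n at `available`."""
    if n is None:
        return available
    if not isinstance(n, int) or n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    if n > available:
        warnings.warn(
            f"requested {n} parallel processes but only {available} "
            f"are available; using {available}"
        )
        return available
    return n


print(get_parallel_number_sketch(None, 4))
print(get_parallel_number_sketch(2, 4))
print(get_parallel_number_sketch(8, 4))
```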