Source code for elisa.util.config
"""Helper functions for computation environment configuration."""
from __future__ import annotations
import os
import re
import warnings
from multiprocessing import cpu_count
from typing import TYPE_CHECKING
import jax
if TYPE_CHECKING:
from typing import Literal
[docs]
def jax_enable_x64(use_x64: bool) -> None:
    """Changes the default float precision of arrays in JAX.

    Parameters
    ----------
    use_x64 : bool
        When ``True``, JAX arrays will use 64 bits else 32 bits. When
        falsy, the ``JAX_ENABLE_X64`` environment variable decides: 64 bits
        are used unless the variable is unset, empty, or one of ``0``,
        ``false``, ``f``, ``no``, ``n``, ``off`` (case-insensitive).
    """
    if not use_x64:
        # Environment variables are strings, and ``bool('0')`` is ``True``,
        # so the value has to be parsed rather than cast.
        falsy = {'', '0', 'false', 'f', 'no', 'n', 'off'}
        env_value = os.getenv('JAX_ENABLE_X64', '')
        use_x64 = env_value.strip().lower() not in falsy

    jax.config.update('jax_enable_x64', bool(use_x64))
# if platform == 'gpu':
# # see https://jax.readthedocs.io/en/latest/gpu_performance_tips.html
# xla_gpu_flags = (
# '--xla_gpu_enable_triton_softmax_fusion=true '
# '--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
# '--xla_gpu_enable_latency_hiding_scheduler=true '
# '--xla_gpu_enable_highest_priority_async_stream=true'
# )
# xla_flags = os.getenv('XLA_FLAGS', '')
# if xla_gpu_flags not in xla_flags:
# os.environ['XLA_FLAGS'] = f'{xla_flags} {xla_gpu_flags}'
[docs]
def set_cpu_cores(n: int) -> None:
    """Set device number to use in JAX.

    The ``--xla_force_host_platform_device_count`` entry of the ``XLA_FLAGS``
    environment variable is replaced (or added) while all other flags already
    present are kept.

    .. warning::
        This utility takes effect only for CPU platform and before running any
        JAX program.

    Parameters
    ----------
    n : int
        Device number to use. Must be positive. If it exceeds the number of
        CPUs reported by :func:`multiprocessing.cpu_count`, a warning is
        issued and one fewer than the total (but at least 1) is used.

    Raises
    ------
    ValueError
        If `n` is not positive.
    """
    n = int(n)
    if n <= 0:
        raise ValueError(f'number of CPU cores must be positive, got {n}')

    total_cores = cpu_count()
    if n > total_cores:
        # Fall back to one fewer than the total, but never to 0 devices,
        # which would happen on a single-core host.
        fallback = max(1, total_cores - 1)
        msg = f'only {total_cores} CPUs available, '
        msg += f'will use {fallback} CPUs'
        warnings.warn(msg, Warning)
        n = fallback

    # Drop any previous device-count flag so that it is not duplicated.
    xla_flags = os.getenv('XLA_FLAGS', '')
    xla_flags = re.sub(
        r'--xla_force_host_platform_device_count=\S+', '', xla_flags
    ).split()
    os.environ['XLA_FLAGS'] = ' '.join(
        [f'--xla_force_host_platform_device_count={n}'] + xla_flags
    )
[docs]
def jax_debug_nans(flag: bool):
    """Toggle the ``jax_debug_nans`` configuration option of JAX.

    See JAX `docs <https://jax.readthedocs.io/en/latest/debugging/flags.html>`_
    for details.

    Parameters
    ----------
    flag : bool
        Value for the option; it is coerced with :class:`bool` before being
        passed to ``jax.config.update``. When ``True``, raise an error when
        NaNs are detected in JAX.
    """
    enabled = bool(flag)
    jax.config.update('jax_debug_nans', enabled)
[docs]
def get_parallel_number(n: int | None) -> int:
"""Check and return the available parallel number in JAX.
Parameters
----------
n : int, optional
The desired number of parallel processes in JAX.
Returns
-------
int
The available number of parallel processes.
"""
n_max = jax.local_device_count()
if n is None:
return n_max
else:
n = int(n)
if n <= 0:
raise ValueError(
f'number of parallel processes must be positive, got {n}'
)
if n > n_max:
warnings.warn(
f'number of parallel processes ({n}) is more than the number of '
f'available devices ({n_max}), reset to {n_max}',
Warning,
)
n = n_max
return n