"""Miscellaneous helper functions."""
from __future__ import annotations
import math
import re
from functools import reduce
from threading import Lock
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
from astropy.units import Unit
from jax import lax
from jax.custom_derivatives import SymbolicZero
from jax.experimental import io_callback
from jax.flatten_util import ravel_pytree
from prettytable import PrettyTable
from tqdm.auto import tqdm
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Literal, TypeVar
from numpy import ndarray as NDArray
from elisa.util.typing import CompEval
T = TypeVar('T')
# Character lookup tables used when suffixes are rendered with Unicode
# (see ``add_suffix``). Keys are the ASCII letters, digits and the symbols
# ``+-/=()``; values are Unicode stand-ins for subscript characters (genuine
# subscript code points where Unicode has them, otherwise visually similar
# characters such as small capitals or script letters). ``strict=True`` makes
# the module fail at import with ValueError if the two strings ever differ in
# length.
UNICODE_SUBSCRIPT = dict(
    zip(
        'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-/=()',
        'ᴀʙᴄᴅᴇғɢʜɪᴊᴋʟᴍɴᴏᴘǫʀsᴛᴜᴠᴡxʏᴢₐᵦ𝒸𝒹ₑ𝒻𝓰ₕᵢⱼₖₗₘₙₒₚᵩᵣₛₜᵤᵥ𝓌ₓᵧ𝓏₀₁₂₃₄₅₆₇₈₉₊₋⸝₌₍₎',
        strict=True,
    )
)
# Same as above, but mapping to superscript stand-ins.
UNICODE_SUPERSCRIPT = dict(
    zip(
        'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-/=()',
        'ᴬᴮᶜᴰᴱᶠᴳᴴᴵᴶᴷᴸᴹᴺᴼᴾᵠᴿˢᵀᵁⱽᵂᕽʸᶻᵃᵇᶜᵈᵉᶠᵍʰⁱʲᵏˡᵐⁿᵒᵖᵠʳˢᵗᵘᵛʷˣʸᶻ⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻ᐟ⁼⁽⁾',
        strict=True,
    )
)
# Suffix configuration, overwritten by ``set_suffix_config`` (via ``global``)
# and read by ``add_suffix`` at call time.
# Used by ``add_suffix`` when its ``unicode`` argument is None.
UNICODE_SUFFIX = False
# %-style templates; the suffix text is substituted for ``%s``.
PLAIN_SUBSCRIPT_TEMPLATE = '_%s'
PLAIN_SUPERSCRIPT_TEMPLATE = '[%s]'
LATEX_SUBSCRIPT_TEMPLATE = '_{%s}'
LATEX_SUPERSCRIPT_TEMPLATE = '^{%s}'
[docs]
def set_suffix_config(
    unicode: bool = False,
    plain_subscript_template: str = '_%s',
    plain_superscript_template: str = '[%s]',
    latex_subscript_template: str = '_{%s}',
    latex_superscript_template: str = '^{%s}',
) -> None:
    """Set suffix configuration.

    The configuration is stored in module-level globals that
    :py:func:`elisa.util.misc.add_suffix` reads at call time.

    Parameters
    ----------
    unicode : bool, optional
        If True, use unicode suffix. The default is False.
    plain_subscript_template : str, optional
        The template for plain subscript. The default is ``'_%s'``.
    plain_superscript_template : str, optional
        The template for plain superscript. The default is ``'[%s]'``.
    latex_subscript_template : str, optional
        The template for LaTeX subscript. The default is ``'_{%s}'``.
    latex_superscript_template : str, optional
        The template for LaTeX superscript. The default is ``'^{%s}'``.

    Raises
    ------
    ValueError
        If any template is not a %-style format string that accepts exactly
        one string argument. In that case the current configuration is left
        unchanged.
    """
    templates = (
        plain_subscript_template,
        plain_superscript_template,
        latex_subscript_template,
        latex_superscript_template,
    )
    # Validate every template before touching any global. Malformed format
    # strings raise TypeError (e.g. no or too many placeholders) or
    # ValueError (e.g. '%' at the end, unknown conversion character) from the
    # % operator; both are reported with the same message.
    try:
        for template in templates:
            _ = template % 'test'
    except (TypeError, ValueError) as e:
        raise ValueError(
            'subscript and superscript templates must be valid '
            'Python string formatting templates'
        ) from e

    global UNICODE_SUFFIX
    global PLAIN_SUBSCRIPT_TEMPLATE
    global PLAIN_SUPERSCRIPT_TEMPLATE
    global LATEX_SUBSCRIPT_TEMPLATE
    global LATEX_SUPERSCRIPT_TEMPLATE
    UNICODE_SUFFIX = bool(unicode)
    PLAIN_SUBSCRIPT_TEMPLATE = plain_subscript_template
    PLAIN_SUPERSCRIPT_TEMPLATE = plain_superscript_template
    LATEX_SUBSCRIPT_TEMPLATE = latex_subscript_template
    LATEX_SUPERSCRIPT_TEMPLATE = latex_superscript_template
[docs]
def add_suffix(
strings: str | Sequence[str],
suffix: str | Sequence[str],
subscript: bool,
unicode: bool | None = None,
latex: bool = False,
mathrm: bool = False,
) -> str | list[str]:
"""Add suffix to a sequence of strings.
Parameters
----------
strings : sequence of str
The sequence of strings.
suffix : sequence of str
The sequence of suffixes. The suffix format can be set by
:py:func:`elisa.util.misc.set_suffix_config`.
subscript : bool, optional
If True, add suffix as subscript, otherwise superscript.
The default is True.
latex : bool, optional
If True, add suffix following LaTeX format. The default is False.
unicode : bool, optional
If True, add suffix with Unicode string. Defaults to
``elisa.util.misc.UNICODE_SUFFIX``, which can be set by
:py:func:`elisa.util.misc.set_suffix_config`.
mathrm : bool, optional
If True, add suffix in mathrm when latex is True. The default is False.
Returns
-------
str or list of str
The strings with suffix added.
"""
if unicode is None:
unicode = UNICODE_SUFFIX
return_list = False
if isinstance(strings, str):
strings = [strings]
else:
strings = list(strings)
return_list = True
if isinstance(suffix, str):
suffix = [suffix]
else:
suffix = list(suffix)
return_list = True
if len(strings) != len(suffix):
raise ValueError('length of `strings` and `suffix` must be the same')
def to_unicode(string: str):
"""Replace suffix with unicode."""
if subscript:
return ''.join(f'{UNICODE_SUBSCRIPT.get(i, i)}' for i in string)
else:
return ''.join(f'{UNICODE_SUPERSCRIPT.get(i, i)}' for i in string)
if latex:
if subscript:
template = LATEX_SUBSCRIPT_TEMPLATE
else:
template = LATEX_SUPERSCRIPT_TEMPLATE
if mathrm:
suffix = [r'\mathrm{' + i + '}' if i else '' for i in suffix]
else:
suffix = ['{' + i + '}' if i else '' for i in suffix]
strings = ['{' + i + '}' for i in strings]
elif unicode:
template = '%s'
suffix = [to_unicode(i) for i in suffix]
else:
if subscript:
template = PLAIN_SUBSCRIPT_TEMPLATE
else:
template = PLAIN_SUPERSCRIPT_TEMPLATE
strings = [
i + template % j if j else i
for i, j in zip(strings, suffix, strict=True)
]
if return_list:
return strings
else:
return strings[0]
[docs]
def build_namespace(
names: Sequence[str],
latex: bool = False,
prime: bool = False,
) -> dict[str, list[str | int]]:
"""Build a namespace from a sequence of names.
Parameters
----------
names : sequence of str
A sequence of names.
latex : bool, optional
If True, `names` are assumed to be LaTeX strings. The default is False.
prime : bool, optional
If True, primes are used as suffix for duplicate names, otherwise
a number is used. The default is False.
Returns
-------
namespace: dict
A dict of non-duplicate names and suffixes in original name order.
"""
namespace = []
suffixes_n = []
counter = {}
for name in names:
if name not in namespace:
counter[name] = 1
namespace.append(name)
else:
counter[name] += 1
suffixes_n.append(counter[name])
if prime:
suffixes = [i - 1 for i in suffixes_n]
suffixes = ["'" * n for n in suffixes]
else:
template = '_{%d}' if latex else '_%d'
suffixes = [template % n if n > 1 else '' for n in suffixes_n]
return {
'namespace': list(map(''.join, zip(names, suffixes, strict=True))),
'suffix_num': [str(n) if n > 1 else '' for n in suffixes_n],
}
[docs]
def define_fdjvp(
    fn: CompEval,
    method: Literal['central', 'forward'] = 'central',
) -> CompEval:
    """Define JVP using finite differences.

    `fn` is wrapped with :func:`jax.custom_jvp`. The derivative with respect
    to the parameters is then approximated by finite differences instead of
    being traced through `fn`.

    Parameters
    ----------
    fn : CompEval
        Function called as ``fn(egrid, params)``. ``egrid.dtype`` is passed
        to ``jnp.finfo`` to obtain the machine epsilon that sets the step
        size. ``params`` is a pytree whose differentiated leaves must be
        scalars.
    method : {'central', 'forward'}, optional
        Finite-difference scheme. The default is ``'central'``.

    Returns
    -------
    CompEval
        `fn` wrapped as a :func:`jax.custom_jvp` function.

    Raises
    ------
    ValueError
        If `method` is not ``'central'`` or ``'forward'``.

    Notes
    -----
    When differentiated, the JVP rule raises NotImplementedError if a
    tangent for ``egrid`` is requested or any parameter tangent is not a
    scalar.
    """
    if method not in {'central', 'forward'}:
        raise ValueError(
            f"supported methods are 'central' and 'forward', but got "
            f"'{method}'"
        )

    def fdjvp(primals, tangents):
        # JVP rule registered with symbolic_zeros=True below: tangents that
        # JAX knows to be zero arrive as SymbolicZero instances.
        egrid, params = primals
        egrid_tangent, params_tangent = tangents
        if not isinstance(egrid_tangent, SymbolicZero):
            raise NotImplementedError('JVP for energy grid is not implemented')
        primals_out = fn(egrid, params)
        tvals, _ = jax.tree.flatten(params_tangent)
        if any(jnp.shape(v) != () for v in tvals):
            raise NotImplementedError(
                'JVP for non-scalar parameter is not implemented'
            )
        # Only parameters with a non-zero tangent are perturbed. Leaf i of
        # the tangents is matched with element i of ravel_pytree(params);
        # presumably both trees share the same structure, as JAX tangents
        # do - confirm.
        non_zero_tangents = [not isinstance(v, SymbolicZero) for v in tvals]
        idx = [i for i, v in enumerate(non_zero_tangents) if v]
        idx_arr = jnp.array(idx)
        nbatch = sum(non_zero_tangents)
        # All leaves are scalars (checked above), so this equals the length
        # of the flattened parameter vector.
        nparam = len(tvals)
        params_ravel, revert = ravel_pytree(params)
        free_params_values = params_ravel[idx_arr]
        # Step scale is |x|, with 1 used for parameters equal to zero so the
        # step never vanishes.
        free_params_abs = jnp.where(
            jnp.equal(free_params_values, 0.0),
            jnp.ones_like(free_params_values),
            jnp.abs(free_params_values),
        )
        # Shape (nbatch, 1) so the step broadcasts across each output row;
        # NOTE(review): this presumably assumes fn returns a 1-D array.
        free_params_abs = jnp.expand_dims(free_params_abs, axis=-1)
        # Row k is a one-hot mask selecting the k-th perturbed parameter.
        row_idx = jnp.arange(nbatch)
        perturb_idx = jnp.zeros((nbatch, nparam)).at[row_idx, idx_arr].set(1.0)
        # nbatch copies of the flattened parameter vector.
        params_batch = jnp.full((nbatch, nparam), params_ravel)
        eps = jnp.finfo(egrid.dtype).eps
        # Evaluate fn once per perturbed parameter set (batch on axis 0).
        f_vmap = jax.vmap(fn, in_axes=(None, 0), out_axes=0)
        # Batched inverse of ravel_pytree: turns each row back into a pytree.
        revert = jax.vmap(revert, in_axes=0, out_axes=0)
        # See Numerical Recipes Chapter 5.7
        if method == 'central':
            # Central difference with step |x| * eps**(1/3).
            perturb = free_params_abs * eps ** (1.0 / 3.0)
            params_pos_perturb = revert(params_batch + perturb_idx * perturb)
            out_pos_perturb = f_vmap(egrid, params_pos_perturb)
            params_neg_perturb = revert(params_batch - perturb_idx * perturb)
            out_neg_perturb = f_vmap(egrid, params_neg_perturb)
            d_out = (out_pos_perturb - out_neg_perturb) / (2.0 * perturb)
        else:
            # Forward difference with step |x| * sqrt(eps).
            perturb = free_params_abs * jnp.sqrt(eps)
            params_perturb = revert(params_batch + perturb_idx * perturb)
            out_perturb = f_vmap(egrid, params_perturb)
            d_out = (out_perturb - primals_out) / perturb
        # Row k of d_out approximates d fn / d params[idx[k]]; contracting
        # with the matching tangent values yields the JVP.
        free_params_tangent = jnp.array([tvals[i] for i in idx])
        tangents_out = free_params_tangent @ d_out
        return primals_out, tangents_out

    # Route forward-mode differentiation of fn through fdjvp.
    fn = jax.custom_jvp(fn)
    fn.defjvp(fdjvp, symbolic_zeros=True)
    return fn
[docs]
def get_unit_latex(unit: str, throw: bool = True) -> str:
    r"""Get latex string of a unit.

    The unit string is parsed with :class:`astropy.units.Unit`. Its base
    units are listed in the order in which their name (or one of their
    aliases) first appears in the input string; bases that are not found
    are placed last. Powers other than 1 are written as ``^{p}`` and bases
    are separated by ``\ ``. Any scale factor is prepended, and the whole
    result is wrapped in ``$\mathrm{...}$``.

    Parameters
    ----------
    unit : str
        The unit string. Non-str objects are converted with ``str``.
    throw : bool, optional
        If True, raise ValueError if the unit is invalid. The default is True.

    Returns
    -------
    str
        The latex string of the unit. An empty string is returned for an
        empty `unit`, or for an invalid `unit` when `throw` is False.

    Raises
    ------
    ValueError
        If a ValueError occurs while parsing or formatting the unit and
        `throw` is True.
    """
    ustr = str(unit)
    if ustr:
        try:
            unit = Unit(ustr)
            # Position used for bases whose names are not found in ustr.
            max_index = len(ustr)
            # NOTE(review): both `[^a-zA-Z]*` groups may match zero
            # characters, so this effectively finds `s` anywhere in ustr,
            # including inside a longer word (e.g. 's' inside 'photons').
            # `s` is also inserted without re.escape. Confirm whether
            # letter-boundary lookarounds were intended.
            pattern = r'(?:[^a-zA-Z]*){}(?:[^a-zA-Z]*)'
            # For each base: earliest match position of its name or aliases.
            index = [
                min(
                    (
                        r.start(0)
                        if (r := re.search(pattern.format(s), ustr))
                        is not None
                        else max_index
                    )
                    for s in [b.name] + b.aliases
                )
                for b in unit.bases
            ]
            # Permutation ordering the bases by that position (stable sort,
            # so ties keep astropy's order).
            index = sorted(range(len(index)), key=index.__getitem__)
            bases = [unit.bases[i].name for i in index]
            powers = [unit.powers[i] for i in index]
            ustr = r'\ '.join(
                b + (f'^{{{p}}}' if p != 1 else '')
                for b, p in zip(bases, powers, strict=True)
            )
            # The [9:-2] slice strips a fixed-width wrapper from astropy's
            # 'latex_inline' output; presumably '$\mathrm{' and '}$' -
            # verify against the astropy version in use.
            scale = Unit(unit.scale).to_string('latex_inline')[9:-2]
            if scale != '':
                scale = scale.replace(r'1 \times ', '')
                scale += r'\ '
            ustr = rf'$\mathrm{{{scale}{ustr}}}$'
        except ValueError as ve:
            if throw:
                raise ve
            ustr = ''
    return ustr
[docs]
def make_pretty_table(fields: Sequence[str], rows: Sequence) -> PrettyTable:
    """Create a centered :class:`prettytable.PrettyTable` with box borders.

    Parameters
    ----------
    fields : sequence of str
        The column names.
    rows : sequence
        The rows to add, each laid out in the order of `fields`.

    Returns
    -------
    table : PrettyTable
        The table populated with every entry of `rows`.
    """
    # Box-drawing characters for the borders, with rounded corners.
    border_chars = dict(
        vertical_char='│',
        horizontal_char='─',
        junction_char='┼',
        top_junction_char='┬',
        bottom_junction_char='┴',
        right_junction_char='┤',
        left_junction_char='├',
        top_right_junction_char='╮',
        top_left_junction_char='╭',
        bottom_right_junction_char='╯',
        bottom_left_junction_char='╰',
    )
    table = PrettyTable(
        fields,
        align='c',
        hrules=1,  # rule style 1 draws lines between all rows (0: frame)
        vrules=1,
        padding_width=1,
        **border_chars,
    )
    table.add_rows(rows)
    return table
[docs]
def replace_string(value: T, mapping: dict[str, str]) -> T:
"""Replace all strings in `value` appeared in `mapping`.
Parameters
----------
value : str, iterable or mapping
Object whose str value needs to be replaced.
mapping : dict
Mapping of str value to be replaced and replacement.
Returns
-------
replaced : iterable or mapping
Value of `value` replaced with `mapping`.
"""
mapping = mapping.items()
def replace_with_mapping(s: str):
"""Replace all k in s with v, as in mapping."""
return reduce(lambda x, kv: x.replace(*kv), mapping, s)
def replace_dict(d: dict):
"""Replace key and value of a dict."""
return {replace(k): replace(v) for k, v in d.items()}
def replace_sequence(it: tuple | list):
"""Replace elements of iterable."""
return type(it)(map(replace, it))
def replace(v):
"""The main replace function."""
if isinstance(v, str):
return replace_with_mapping(v)
elif isinstance(v, list | tuple):
return replace_sequence(v)
elif isinstance(v, dict):
return replace_dict(v)
else:
return v
return replace(value)
[docs]
def report_interval(
    vmid: float,
    vmin: float,
    vmax: float,
    precision: int = 2,
    min_exponent: int = 1,
    max_exponent: int = 2,
) -> str:
    r"""Report parameter interval in :math:`\LaTeX` format.

    The interval is written as `vmid`, with ``vmin - vmid`` as subscript
    and ``vmax - vmid`` as superscript. Scientific notation is used when
    ``log10(|vmid|) <= -min_exponent`` or ``log10(|vmid|) >= max_exponent``.
    In that case both offsets share the power of ten of `vmid`.

    Parameters
    ----------
    vmid : float
        The mid value.
    vmin : float
        The lower bound.
    vmax : float
        The upper bound.
    precision : int, optional
        The number of decimal places of the mid value (of its mantissa in
        scientific notation). The default is 2.
    min_exponent : int, optional
        Scientific notation is used if ``log10(|vmid|) <= -min_exponent``.
        The default is 1.
    max_exponent : int, optional
        Scientific notation is used if ``log10(|vmid|) >= max_exponent``.
        The default is 2.

    Returns
    -------
    str
        The interval in :math:`\LaTeX` format, e.g.
        ``$1.50_{-0.30}^{+0.40}$`` or
        ``$5.00_{-1.00}^{+2.00} \times 10^{-2}$``.
    """
    vmid = float(vmid)
    vmin = float(vmin)
    vmax = float(vmax)
    precision = int(precision)
    min_exponent = int(min_exponent)
    max_exponent = int(max_exponent)
    # The ordering vmin <= vmid <= vmax is not enforced.
    assert precision > 0
    assert min_exponent > 0
    assert max_exponent > 0

    def get_sci_notation_exponent(num: float) -> int:
        """Get the exponent of `num` as printed in scientific notation."""
        return int(f'{num:.{precision}e}'.split('e')[1])

    def get_sci_notation_significand(num: float, exp: int) -> str:
        """Format `num` as a signed multiple of ``10**exp``."""
        significand = num * 10**-exp
        if abs(num) < 10 ** (exp - precision):
            # The offset is below the last displayed digit of the mid value:
            # show enough decimals to reach its leading significant digit.
            p = abs(get_sci_notation_exponent(num) - exp)
            return f'{significand:+.{p}f}'.rstrip('0')
        else:
            p = precision
            return f'{significand:+.{p}f}'

    lower = vmin - vmid
    upper = vmax - vmid

    # log10 is undefined for 0 and meaningless for inf/nan; use fixed-point
    # notation for those values.
    if vmid != 0.0 and math.isfinite(vmid):
        exponent = math.log10(abs(vmid))
    else:
        exponent = 0.0

    if exponent <= -min_exponent or exponent >= max_exponent:
        str_mid = f'{vmid:.{precision}e}'.split('e')[0]
        # Use the exponent of the *formatted* value: rounding the mantissa
        # can carry into the next decade (999.6 -> 1.00e+03), where
        # floor(log10(|vmid|)) would give the wrong power of ten.
        base_exponent = get_sci_notation_exponent(vmid)
        suffix = rf' \times 10^{{{base_exponent}}}'
    else:
        str_mid = f'{vmid:.{precision}f}'
        base_exponent = 0
        suffix = ''

    if lower != 0:
        str_lower = get_sci_notation_significand(lower, base_exponent)
    else:
        str_lower = '-0'
    if upper != 0:
        str_upper = get_sci_notation_significand(upper, base_exponent)
    else:
        str_upper = '+0'
    return f'${str_mid}_{{{str_lower}}}^{{{str_upper}}}{suffix}$'
[docs]
def progress_bar_factory(
    neval: int,
    ncores: int,
    init_str: str | None = None,
    run_str: str | None = None,
    update_rate: int = 50,
) -> Callable[[Callable], Callable]:
    """Add a progress bar to JAX ``fori_loop`` kernel, see [1]_ for details.

    A single :class:`tqdm.auto.tqdm` bar with `neval` steps is created as
    soon as this factory is called. Each of the `ncores` loops is expected
    to run ``neval // ncores`` iterations and reports progress to the shared
    bar through host callbacks. The bar is closed once all `ncores` loops
    have reported their last iteration.

    Parameters
    ----------
    neval : int
        The total number of evaluations across all cores.
    ncores : int
        The number of cores, i.e. the number of loops sharing the bar.
        Must be positive.
    init_str : str, optional
        The string displayed before the progress bar during initialization.
        The default is ``'Compiling... '``.
    run_str : str, optional
        The string displayed before the progress bar while running.
        The default is ``'Running'``.
    update_rate : int, optional
        The approximate number of progress-bar updates per loop.
        The default is 50.

    Returns
    -------
    progress_bar_fori_loop : callable
        Decorator that wraps a ``fori_loop`` body function so that it
        updates the progress bar.

    Raises
    ------
    ValueError
        If `ncores` is not positive or `neval` is not a multiple of `ncores`.

    References
    ----------
    .. [1] `How to add a progress bar to JAX scans and loops
       <https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/>`_
    """
    neval = int(neval)
    ncores = int(ncores)
    # Validate before dividing: ncores <= 0 would otherwise give a
    # ZeroDivisionError or a negative per-loop iteration count.
    if ncores <= 0:
        raise ValueError('ncores must be a positive integer')
    if neval % ncores != 0:
        raise ValueError('neval must be multiple of ncores')
    neval_single = neval // ncores  # iterations run by each loop

    if init_str is None:
        init_str = 'Compiling... '
    else:
        init_str = str(init_str)

    if run_str is None:
        run_str = 'Running'
    else:
        run_str = str(run_str)

    # Report every `print_rate` iterations, i.e. about `update_rate` times
    # per loop.
    if neval > update_rate:
        print_rate = max(1, int(neval_single / update_rate))
    else:
        print_rate = 1

    # lock serializes access to idx_counter since callbacks are multithreaded
    lock = Lock()
    idx_counter = 0  # number of loops that have finished
    # Iterations not covered by the periodic updates; added when a loop ends.
    remainder = neval_single % print_rate

    bar = tqdm(range(neval))
    bar.set_description(init_str, refresh=True)

    def _update_tqdm(increment):
        bar.set_description(run_str, refresh=False)
        bar.update(int(increment))

    def _close_tqdm():
        nonlocal idx_counter
        bar.update(remainder)
        with lock:
            idx_counter += 1
            # Checked under the lock so exactly one callback closes the bar.
            if idx_counter == ncores:
                bar.close()

    def _update_progress_bar(iter_num):
        # First iteration: switch the description to `run_str`.
        _ = lax.cond(
            iter_num == 1,
            lambda _: io_callback(_update_tqdm, None, 0),
            lambda _: None,
            operand=None,
        )
        _ = lax.cond(
            iter_num % print_rate == 0,
            lambda _: io_callback(_update_tqdm, None, print_rate),
            lambda _: None,
            operand=None,
        )
        _ = lax.cond(
            iter_num == neval_single,
            lambda _: io_callback(_close_tqdm, None),
            lambda _: None,
            operand=None,
        )

    def progress_bar_fori_loop(fn):
        """Decorator that adds a progress bar to `body_fun` used in
        `lax.fori_loop`.

        The wrapped function is called as ``fn(i, vals)`` and reports
        iteration ``i + 1`` to the bar, so the loop index is assumed to run
        from 0 to ``neval // ncores - 1``.
        """

        def _wrapper_progress_bar(i, vals):
            result = fn(i, vals)
            _update_progress_bar(i + 1)
            return result

        return _wrapper_progress_bar

    return progress_bar_fori_loop
[docs]
def to_native_byteorder(arr: NDArray) -> NDArray:
    """Convert an array to native byte order.

    Parameters
    ----------
    arr : ndarray
        Input array.

    Returns
    -------
    ndarray
        `arr` itself if its dtype is already in native byte order, or byte
        order does not apply to it (e.g. ``uint8`` or bytes). This includes
        every field of a structured dtype. Otherwise a copy cast to the
        native-order equivalent dtype.
    """
    # newbyteorder('=') also rewrites the fields of structured dtypes, and
    # dtype equality takes byte order into account. This catches non-native
    # fields that ``dtype.byteorder`` (which reports '|' for structured
    # dtypes) would hide. It also avoids copying byte-order-independent
    # ('|') arrays.
    native = arr.dtype.newbyteorder('=')
    if arr.dtype == native:
        return arr
    return arr.astype(native)