"""The spectral model bases."""
from __future__ import annotations
import inspect
import warnings
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Mapping, Sequence
from enum import Enum
from functools import reduce
from typing import TYPE_CHECKING, NamedTuple
import astropy.units as u
import jax
import jax.numpy as jnp
import numpy as np
from astropy.cosmology import Planck18
from scipy.sparse import sparray
from elisa.data.base import ObservationData, ResponseData, SpectrumData
from elisa.models.parameter import Parameter, UniformParameter
from elisa.util.misc import build_namespace, define_fdjvp, make_pretty_table
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any, Literal
from astropy.cosmology import LambdaCDM
from numpyro.distributions import Distribution
from elisa.models.parameter import ParamInfo
from elisa.util.integrate import IntegralFactory
from elisa.util.typing import (
AdditiveFn,
ArrayLike,
CompEval,
CompID,
CompIDParamValMapping,
CompIDStrMapping,
CompName,
CompParamName,
ConvolveEval,
JAXArray,
JAXFloat,
ModelCompiledFn,
ModelEval,
NameValMapping,
ParamID,
ParamIDStrMapping,
ParamIDValMapping,
ParamNameValMapping,
)
NDArray = np.ndarray
[docs]
class Model(ABC):
"""Base model class."""
_comps: tuple[Component, ...]
__initialized: bool = False
def __init__(
self,
name: str,
latex: str,
comps: list[Component],
):
self._name = str(name)
self._latex = str(latex)
self._comps = tuple(comps)
self._comps_id = tuple(comp._id for comp in comps)
cname = build_namespace([c.name for c in comps])['namespace']
cid_to_cname = dict(zip(self._comps_id, cname, strict=True))
self._cid_to_cname = cid_to_cname
self.__name = self._id_to_label(cid_to_cname, 'name')
for name, comp in zip(cid_to_cname.values(), comps, strict=True):
setattr(self, name, comp)
self.__initialized = True
@property
def _cid_to_clatex(self) -> CompIDStrMapping:
clatex = [c.latex for c in self._comps]
clatex = build_namespace(clatex, latex=True)['namespace']
return dict(zip(self._comps_id, clatex, strict=True))
[docs]
def compile(self, *, model_info: ModelInfo | None = None) -> CompiledModel:
"""Compile the model for fast evaluation.
Parameters
----------
model_info : ModelInfo, optional
Optional model information used to compile the model.
Returns
-------
CompiledModel
The compiled model.
"""
if self.type == 'conv':
raise RuntimeError(f'cannot compile convolution model {self.name}')
if model_info is None:
model_info = get_model_info(
self._comps, self._cid_to_cname, self._cid_to_clatex
)
else:
if not isinstance(model_info, ModelInfo):
raise TypeError('`model_info` must be a ModelInfo instance')
if not set(self._comps_id).issubset(set(model_info.cid_to_params)):
raise ValueError('inconsistent model information')
# model name
name = self._id_to_label(model_info.cid_to_name, 'name')
# model parameter id
params_id = []
fixed_id = []
for c in self._comps:
for p in c.param_names:
for pid in c[p]._nodes_id:
if pid in model_info.sample:
params_id.append(pid)
elif pid in model_info.fixed:
fixed_id.append(pid)
# compiled model evaluation function
fn = self._compile_model_fn(model_info)
additive_fn = self._compile_additive_fn(model_info)
# model type
mtype = self.type
return CompiledModel(
name, params_id, fixed_id, fn, additive_fn, mtype, model_info
)
@property
@abstractmethod
def eval(self) -> ModelEval:
"""Get side-effect free model evaluation function."""
pass
@property
def name(self) -> str:
"""Model name."""
return self.__name
@property
def latex(self) -> str:
r""":math:`\LaTeX` format of the model."""
latex = [c.latex for c in self._comps]
cid_to_latex = dict(
zip(
self._comps_id,
build_namespace(latex, latex=True)['namespace'],
strict=True,
)
)
return self._id_to_label(cid_to_latex, 'latex')
@property
@abstractmethod
def type(self) -> Literal['add', 'mul', 'conv']:
"""Model type."""
pass
@property
def comp_names(self) -> tuple[CompName, ...]:
"""Component names."""
return tuple(self._cid_to_cname.values())
@property
def _additive_comps(self) -> tuple[Model, ...]:
if self.type != 'add':
return ()
else:
return (self,)
def _id_to_label(
self,
mapping: CompIDStrMapping,
label_type: Literal['name', 'latex'],
) -> str:
"""Get the label of the parameter."""
if label_type not in {'name', 'latex'}:
raise ValueError(f'unknown label type: {label_type}')
label = self._name if label_type == 'name' else self._latex
for k, v in mapping.items():
label = label.replace(k, v)
return label
def _compile_model_fn(self, model_info: ModelInfo) -> ModelCompiledFn:
"""Get the model evaluation function."""
pid_to_value = {
c._id: model_info.cid_to_params[c._id] for c in self._comps
}
eval_fn = jax.jit(self.eval)
@jax.jit
def fn(egrid: JAXArray, params: ParamIDValMapping) -> JAXArray:
"""The model evaluation function"""
comps_params = jax.tree.map(lambda f: f(params), pid_to_value)
return eval_fn(egrid, comps_params)
for integrate in model_info.integrate.values():
fn = jax.jit(integrate(fn))
return fn
def _compile_additive_fn(self, model_info: ModelInfo) -> AdditiveFn | None:
"""Get the additive model evaluation function."""
additive_comps = self._additive_comps
if len(additive_comps) <= 1:
return None
fns = [comp._compile_model_fn(model_info) for comp in additive_comps]
comps_labels = [
comp._id_to_label(model_info.cid_to_latex, 'latex')
for comp in additive_comps
]
@jax.jit
def _fn(egrid: JAXArray, params: ParamIDValMapping) -> JAXArray:
return [f(egrid, params) for f in fns]
@jax.jit
def fn(
egrid: JAXArray, params: ParamIDValMapping
) -> dict[tuple[str, str], JAXArray]:
"""The evaluation function of additive components."""
return dict(zip(comps_labels, _fn(egrid, params), strict=True))
return fn
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
fields = ['No.', 'Component', 'Parameter', 'Value', 'Bound', 'Prior']
info = get_model_info(
self._comps, self._cid_to_cname, self._cid_to_clatex
).info
tab_params = make_pretty_table(fields, info)
mtype = get_mtype_str(self)
return f'{mtype} Model: {self.name}\n{tab_params.get_string()}'
def _repr_html_(self) -> str:
fields = ['No.', 'Component', 'Parameter', 'Value', 'Bound', 'Prior']
info = get_model_info(
self._comps, self._cid_to_cname, self._cid_to_clatex
).info
tab_params = make_pretty_table(fields, info)
mtype = get_mtype_str(self)
return (
'<details open>'
f'<summary><b>{mtype} Model: {self.name}</b></summary>'
f'{tab_params.get_html_string(format=True)}'
'</details>'
)
def __delattr__(self, key: str):
if self.__initialized and hasattr(self, key):
raise AttributeError("can't delete attribute")
super().__delattr__(key)
def __setattr__(self, key: str, value: Any):
if self.__initialized and (
not hasattr(self, key) or key in self._cid_to_cname.values()
):
raise AttributeError("can't set attribute")
super().__setattr__(key, value)
def __getitem__(self, key: str) -> Component:
if key not in self._cid_to_cname.values():
raise KeyError(key)
return getattr(self, key)
def __add__(self, other: Model) -> CompositeModel:
return CompositeModel(self, other, '+')
def __radd__(self, other: Model) -> CompositeModel:
return CompositeModel(other, self, '+')
def __mul__(self, other: Model) -> CompositeModel:
return CompositeModel(self, other, '*')
def __rmul__(self, other: Model) -> CompositeModel:
return CompositeModel(other, self, '*')
[docs]
class CompiledModel:
"""Model with fast evaluation and fixed configuration."""
__initialized: bool = False
def __init__(
self,
name: str,
params_id: Sequence[ParamID],
fixed_id: Sequence[ParamID],
fn: ModelCompiledFn,
additive_fn: AdditiveFn | None,
mtype: Literal['add', 'mul'],
model_info: ModelInfo,
):
pname_to_pid = {model_info.name[pid]: pid for pid in params_id}
self.name = name
self._pname_to_pid = pname_to_pid
self._params_name = tuple(pname_to_pid)
self._params_id = params_id
self._params_default = dict(model_info.default)
self._value_sequence_to_params: Callable[
[Sequence[JAXFloat]], ParamIDValMapping
] = jax.jit(
lambda sequence: dict(zip(params_id, sequence, strict=True))
)
fixed_name_to_pid = {model_info.name[pid]: pid for pid in fixed_id}
pname_pid_all = pname_to_pid | fixed_name_to_pid
self._value_mapping_to_params: Callable[
[ParamNameValMapping], ParamIDValMapping
] = jax.jit(
lambda mapping: {
v: mapping[k] for k, v in pname_pid_all.items() if k in mapping
}
)
self._fn = fn
self._additive_fn = additive_fn
self._type = mtype
self._model_info = model_info
self._nparam = len(pname_to_pid)
self.__initialized = True
@property
def params_name(self) -> tuple[str, ...]:
"""Parameter names."""
return self._params_name
@property
def type(self) -> Literal['add', 'mul']:
"""Model type."""
return self._type
@property
def has_comps(self) -> bool:
"""Whether the model has additive subcomponents."""
return self._additive_fn is not None
def _prepare_eval(self, params: ArrayLike | Sequence | Mapping | None):
"""Check if `params` is valid for the model."""
if isinstance(params, np.ndarray | jax.Array | Sequence):
params = jnp.atleast_1d(jnp.asarray(params, float))
if params.shape[-1] != self._nparam:
raise ValueError(
f'expected params shape (..., {self._nparam}), got '
f'{jnp.shape(params)}'
)
params = jnp.moveaxis(params, -1, 0)
params = self._value_sequence_to_params(params)
elif isinstance(params, Mapping):
# if not set(self.params_name).issubset(params):
# missing = set(self.params_name) - set(params)
# raise ValueError(f'missing parameters: {", ".join(missing)}')
params = jax.tree.map(jnp.asarray, params)
params = self._value_mapping_to_params(params)
elif params is None:
params = self._params_default
else:
raise TypeError('params must be a array, sequence or mapping')
fn = self._fn
add_fn = self._additive_fn
shapes = jax.tree.flatten(
tree=jax.tree.map(jnp.shape, params),
is_leaf=lambda i: isinstance(i, tuple),
)[0]
if self.params_name:
if not shapes:
raise ValueError('params are empty')
shape = shapes[0]
if any(s != shape for s in shapes[1:]):
raise ValueError('all params must have the same shape')
# iteratively vmap and jit over params dimensions
# use the nested-jit trick to reduce the compilation time
for _ in range(len(shape)):
fn = jax.jit(jax.vmap(fn, in_axes=(None, 0)))
if add_fn is not None:
for _ in range(len(shape)):
add_fn = jax.jit(jax.vmap(add_fn, in_axes=(None, 0)))
return fn, add_fn, params
[docs]
def eval(
self,
egrid: ArrayLike,
params: Sequence | Mapping | None = None,
) -> JAXArray:
"""Evaluate the model.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
Returns
-------
jax.Array
The model value.
"""
f, _, params = self._prepare_eval(params)
return f(jnp.asarray(egrid, float), params)
[docs]
def ne(
self,
egrid: ArrayLike,
params: Sequence | Mapping | None = None,
comps: bool = False,
) -> JAXArray | dict[str, JAXArray]:
r"""Calculate :math:`N(E)`.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
comps : bool, optional
Whether to return the result of each component. Defaults to False.
Returns
-------
jax.Array, or dict[str, jax.Array]
The differential photon flux in units of ph cm⁻² s⁻¹ keV⁻¹.
"""
if self.type != 'add':
msg = f'N(E) is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
if comps and not self.has_comps:
raise RuntimeError(f'{self} has no sub-models with additive type')
egrid = jnp.asarray(egrid, float)
de = egrid[1:] - egrid[:-1]
if comps:
_, additive_fn, params = self._prepare_eval(params)
comps_value = additive_fn(egrid, params)
ne = jax.tree.map(lambda v: v / de, comps_value)
else:
ne = self.eval(egrid, params) / de
return ne
[docs]
def ene(
self,
egrid: ArrayLike,
params: Sequence | Mapping | None = None,
comps: bool = False,
) -> JAXArray | dict[str, JAXArray]:
r"""Calculate :math:`E N(E)`, i.e. :math:`F(\nu)`.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
comps : bool, optional
Whether to return the result of each component. Defaults to False.
Returns
-------
jax.Array, or dict[str, jax.Array]
The differential energy flux in units of erg cm⁻² s⁻¹ keV⁻¹.
"""
if self.type != 'add':
msg = f'EN(E) is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
if comps and not self.has_comps:
raise RuntimeError(f'{self} has no sub-models with additive type')
keV_to_erg = 1.602176634e-9
egrid = jnp.asarray(egrid, float)
emid = jnp.sqrt(egrid[:-1] * egrid[1:])
factor = keV_to_erg * emid
ne = self.ne(egrid, params, comps)
if comps:
ene = jax.tree.map(lambda v: factor * v, ne)
else:
ene = factor * ne
return ene
[docs]
def eene(
self,
egrid: ArrayLike,
params: Sequence | Mapping | None = None,
comps: bool = False,
) -> JAXArray | dict[str, JAXArray]:
r"""Calculate :math:`E^2 N(E)`, i.e. :math:`\nu F(\nu)`.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
comps : bool, optional
Whether to return the result of each component. Defaults to False.
Returns
-------
jax.Array, or dict[str, jax.Array]
The energy flux in units of erg cm⁻² s⁻¹.
"""
if self.type != 'add':
msg = f'EEN(E) is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
if self._additive_fn is None and comps:
raise RuntimeError(f'{self} has no sub-models with additive type')
keV_to_erg = 1.602176634e-9
egrid = jnp.asarray(egrid, float)
e2 = egrid[:-1] * egrid[1:]
factor = keV_to_erg * e2
ne = self.ne(egrid, params, comps)
if comps:
eene = jax.tree.map(lambda v: factor * v, ne)
else:
eene = factor * ne
return eene
[docs]
def ce(
self,
egrid: ArrayLike,
resp_matrix: ArrayLike,
channel_width: ArrayLike,
params: Sequence | Mapping | None = None,
comps: bool = False,
) -> JAXArray | dict[str, JAXArray]:
r"""Calculate the folded model, i.e. :math:`C(E)`.
Parameters
----------
egrid : ndarray
Photon energy grid of `resp_matrix`, in units of keV.
resp_matrix : ndarray
Instrumental response matrix used to fold the model.
channel_width : ndarray
Measured energy channel width of `resp_matrix`.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
comps : bool, optional
Whether to return the result of each component. Defaults to False.
Returns
-------
jax.Array, or dict[str, jax.Array]
The folded model in units of counts s⁻¹ keV⁻¹.
"""
if self.type != 'add':
msg = f'C(E) is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
if comps and not self.has_comps:
raise RuntimeError(f'{self} has no sub-models with additive type')
egrid = jnp.asarray(egrid, float)
resp_matrix = jnp.asarray(resp_matrix, float)
channel_width = jnp.asarray(channel_width, float)
ne = self.ne(egrid, params, comps)
de = egrid[1:] - egrid[:-1]
fn = jax.jit(lambda v: (v * de) @ resp_matrix / channel_width)
if comps:
return jax.tree.map(fn, ne)
else:
return fn(ne)
[docs]
def flux(
self,
emin: float | int | JAXFloat,
emax: float | int | JAXFloat,
params: Sequence | Mapping | None = None,
energy: bool = True,
comps: bool = False,
ngrid: int = 1000,
log: bool = True,
) -> jax.Array | dict[str, jax.Array]:
r"""Calculate the flux of model between `emin` and `emax`.
.. warning::
The flux is calculated by trapezoidal rule, which may not be
accurate if not enough energy bins are used when the difference
between `emin` and `emax` is large.
Parameters
----------
emin : float or int
Minimum value of energy range, in units of keV.
emax : float or int
Maximum value of energy range, in units of keV.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
energy : bool, optional
When True, calculate energy flux in units of erg cm⁻² s⁻¹;
otherwise calculate photon flux in units of ph cm⁻² s⁻¹.
The default is True.
comps : bool, optional
Whether to return the result of each component. The default is
False.
ngrid : int, optional
The energy grid number to use in integration. The default is 1000.
log : bool, optional
Whether to use logarithmically regular energy grid. The default is
True.
Returns
-------
jax.Array, or dict[str, jax.Array]
The model flux.
"""
if self.type != 'add':
msg = f'flux is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
if log:
egrid = jnp.geomspace(emin, emax, ngrid)
else:
egrid = jnp.linspace(emin, emax, ngrid)
if energy:
f = self.ene(egrid, params, comps)
else:
f = self.ne(egrid, params, comps)
de = jnp.diff(egrid)
fn = jax.jit(lambda v: jnp.sum(v * de, axis=-1))
if comps:
return jax.tree.map(fn, f)
else:
return fn(f)
[docs]
def lumin(
self,
emin_rest: float | int,
emax_rest: float | int,
z: float | int,
params: Sequence | Mapping | None = None,
comps: bool = False,
ngrid: int = 1000,
log: bool = True,
cosmo: LambdaCDM = Planck18,
) -> JAXArray | dict[str, JAXArray]:
"""Calculate the luminosity of model.
.. warning::
The luminosity is calculated by trapezoidal rule, and is accurate
only if enough numbers of energy grids are used.
Parameters
----------
emin_rest : float or int
Minimum value of rest-frame energy range, in units of keV.
emax_rest : float or int
Maximum value of rest-frame energy range, in units of keV.
z : float or int
Redshift of the source.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
comps : bool, optional
Whether to return the result of each component. The default is
False.
ngrid : int, optional
The energy grid number to use in integration. The default is 1000.
log : bool, optional
Whether to use logarithmically regular energy grid. The default is
True.
cosmo : LambdaCDM, optional
Cosmology model used to calculate luminosity. The default is
Planck18.
Returns
-------
jax.Array, or dict[str, jax.Array]
The model luminosity.
"""
if self.type != 'add':
msg = f'lumin is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
z = float(z)
flux = self.flux(
emin=float(emin_rest) / (1.0 + z),
emax=float(emax_rest) / (1.0 + z),
params=params,
energy=True,
comps=bool(comps),
ngrid=int(ngrid),
log=bool(log),
)
flux_unit = u.Unit('erg cm^-2 s^-1')
factor = 4.0 * np.pi * cosmo.luminosity_distance(z) ** 2
to_lumin = lambda x: (x * flux_unit * factor).to('erg s^-1')
if comps:
return jax.tree.map(to_lumin, flux)
else:
return to_lumin(flux)
[docs]
def eiso(
self,
emin_rest: float | int,
emax_rest: float | int,
z: float | int,
duration: float | int,
params: Sequence | Mapping | None = None,
comps: bool = False,
ngrid: int = 1000,
log: bool = True,
cosmo: LambdaCDM = Planck18,
) -> JAXArray | dict[str, JAXArray]:
r"""Calculate the isotropic emission energy of model.
.. warning::
The :math:`E_\mathrm{iso}` is calculated by trapezoidal rule,
and is accurate only if enough numbers of energy grids are used.
Parameters
----------
emin_rest : float or int
Minimum value of rest-frame energy range, in units of keV.
emax_rest : float or int
Maximum value of rest-frame energy range, in units of keV.
z : float or int
Redshift of the source.
duration : float or int
Observed duration of the source, in units of second.
comps : bool, optional
Whether to return the result of each component. The default is
False.
ngrid : int, optional
The energy grid number to use in integration. The default is
1000.
log : bool, optional
Whether to use logarithmically regular energy grid. The default
is True.
params : dict, optional
Parameters dict to overwrite the fitted parameters.
cosmo : LambdaCDM, optional
Cosmology model used to calculate luminosity. The default is
Planck18.
Returns
-------
jax.Array, or dict[str, jax.Array]
The isotropic emission energy of the model.
"""
if self.type != 'add':
msg = f'eiso is undefined for {self.type} type model "{self}"'
raise TypeError(msg)
lumin = self.lumin(
emin_rest, emax_rest, z, params, comps, ngrid, log, cosmo
)
# This includes correction for time dilation.
factor = duration / (1 + z) * u.s
to_eiso = lambda x: (x * factor).to('erg')
if comps:
return jax.tree.map(to_eiso, lumin)
else:
return to_eiso(lumin)
[docs]
def simulate_from_data(
self,
data: ObservationData,
spec_exposure: float | None = None,
back_exposure: float | None = None,
params: Sequence | Mapping | None = None,
seed: int = 42,
**kwargs: dict,
) -> ObservationData:
"""Simulate observation based on the configuration of existing data.
Parameters
----------
data : ObservationData
Observation data to read observation configuration.
spec_exposure : float, optional
Exposure time of the source. Defaults to the exposure time of
`spec_data`.
back_exposure : float, optional
Exposure time of the background. Defaults to the exposure time of
`back_data`.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
seed : int, optional
Random seed for the simulation. The default is 42.
**kwargs : dict, optional
Additional keyword arguments passed to :class:`ObservationData`.
Returns
-------
ObservationData
Simulated observation data.
"""
if self.type != 'add':
msg = f'cannot simulate data from {self.type} type model "{self}"'
raise TypeError(msg)
if not isinstance(data, ObservationData):
raise TypeError('data must be an ObservationData instance')
if spec_exposure is None:
spec_exposure = data.spec_exposure
else:
spec_exposure = float(spec_exposure)
if data.has_back:
back_counts = data.back_data.counts
back_errors = data.back_data.errors
back_poisson = data.back_poisson
back_area_scale = data.back_data.area_scale
back_back_scale = data.back_data.back_scale
if back_exposure is None:
back_exposure = data.back_exposure
else:
back_exposure = float(back_exposure)
else:
back_counts = back_errors = back_exposure = back_poisson = None
back_area_scale = back_back_scale = 1.0
spec_data = data.spec_data
resp_data = data.resp_data
simulation = self.simulate(
photon_egrid=resp_data.photon_egrid,
channel_emin=resp_data.channel_emin,
channel_emax=resp_data.channel_emax,
response_matrix=resp_data.sparse_matrix,
spec_exposure=spec_exposure,
spec_poisson=spec_data.poisson,
spec_errors=spec_data.errors,
back_counts=back_counts,
back_errors=back_errors,
back_exposure=back_exposure,
back_poisson=back_poisson,
spec_area_scale=spec_data.area_scale,
spec_back_scale=spec_data.back_scale,
back_area_scale=back_area_scale,
back_back_scale=back_back_scale,
quality=np.where(data.good_quality, 0, 1),
grouping=data.grouping,
channel=resp_data.channel,
channel_type=resp_data.channel_type,
response_sparse=resp_data.sparse,
params=params,
name=data.name,
seed=seed,
**kwargs,
)
simulation.set_erange(data.erange)
return simulation
[docs]
def simulate(
self,
photon_egrid: NDArray,
channel_emin: NDArray,
channel_emax: NDArray,
response_matrix: NDArray | sparray,
spec_exposure: float,
spec_poisson: bool = True,
spec_errors: NDArray | None = None,
back_counts: NDArray | None = None,
back_errors: NDArray | None = None,
back_exposure: float | None = None,
back_poisson: bool | None = None,
spec_area_scale: float | NDArray = 1.0,
spec_back_scale: float | NDArray = 1.0,
back_area_scale: float | NDArray = 1.0,
back_back_scale: float | NDArray = 1.0,
quality: NDArray | None = None,
grouping: NDArray | None = None,
channel: NDArray | None = None,
channel_type: str = 'Ch',
response_sparse: bool = False,
params: Sequence | Mapping | None = None,
name: str = 'simulation',
seed: int = 42,
**kwargs: dict,
) -> ObservationData:
"""Simulate observation data.
Parameters
----------
photon_egrid : ndarray
Photon energy grid in units of keV.
channel_emin : ndarray
Lower energy bounds of the detector channels.
channel_emax : ndarray
Upper energy bounds of the detector channels.
response_matrix : ndarray
Response matrix of the detector.
spec_exposure : float
Exposure time of the source.
spec_poisson : bool, optional
Whether the source spectrum is Poisson distributed. If false,
`spec_errors` must be provided. The default is True.
spec_errors : ndarray, optional
Errors of the source spectrum. Must be provided if `spec_poisson`
is False.
back_counts : ndarray, optional
Background counts in each channel.
back_errors : ndarray, optional
Errors of the background counts. Must be provided if `back_poisson`
is False.
back_exposure : float, optional
Exposure time of the background.
back_poisson : bool, optional
Whether the background spectrum is Poisson distributed. If false,
`back_errors` must be provided.
spec_area_scale : float or ndarray, optional
Area scale factor of the source. The default is 1.0.
spec_back_scale : float or ndarray, optional
Background scale factor of the source. The default is 1.0.
back_area_scale : float or ndarray, optional
Area scale factor of the background. The default is 1.0.
back_back_scale : float or ndarray, optional
Background scale factor of the background. The default is 1.0.
quality : ndarray, optional
Quality flags of the data.
grouping : ndarray, optional
Grouping flags of the data.
channel : ndarray, optional
Channel numbers.
channel_type : str, optional
Channel type. The default is 'Ch'.
response_sparse : bool, optional
Whether the response matrix is sparse. The default is False.
params : sequence or mapping, optional
Parameter sequence or mapping for the model.
name : str, optional
Name of the simulation data. The default is 'simulation'.
seed : int, optional
Random seed for simulation. The default is 42.
**kwargs : dict, optional
Additional keyword arguments passed to :class:`ObservationData`.
Returns
-------
ObservationData
Simulated observation data.
"""
if self.type != 'add':
msg = f'cannot simulate data from {self.type} type model "{self}"'
raise TypeError(msg)
if channel is None:
channel = np.arange(len(channel_emin)).astype(str)
rng = np.random.default_rng(int(seed))
resp_data = ResponseData(
photon_egrid=photon_egrid,
channel_emin=channel_emin,
channel_emax=channel_emax,
response_matrix=response_matrix,
channel=channel,
channel_type=channel_type,
sparse=response_sparse,
)
if not spec_poisson:
if spec_errors is None:
raise ValueError(
'spec_errors must be provided if spec_poisson is False'
)
else:
spec_errors = np.asarray(spec_errors, float)
if spec_errors.size != resp_data.channel_number:
raise ValueError(
f'spec_errors size ({np.size(spec_errors)}) must be '
f'channel number ({resp_data.channel_number})'
)
if not (
back_counts is None
and back_exposure is None
and back_poisson is None
and back_errors is None
):
if back_counts is None:
raise ValueError('back_counts must be also provided')
if back_exposure is None:
raise ValueError('back_exposure must be also provided')
if back_poisson is None:
raise ValueError('back_poisson must be also provided')
back_counts = np.asarray(back_counts, np.float64)
if back_counts.size != resp_data.channel_number:
raise ValueError(
f'back_counts size ({np.size(back_counts)}) must be '
f'channel number ({resp_data.channel_number})'
)
if spec_poisson or back_poisson:
if np.any(back_counts < 0):
warnings.warn(
'negative background counts is set to 0 for '
'Poisson counts simulation',
Warning,
stacklevel=2,
)
back_counts = np.clip(back_counts, 0, None)
if not back_poisson:
if back_errors is None:
raise ValueError(
'back_errors must be provided if back_poisson is False'
)
else:
back_errors = np.asarray(back_errors, float)
if back_errors.size != resp_data.channel_number:
raise ValueError(
f'back_errors size ({np.size(back_errors)}) must '
f'be channel number ({resp_data.channel_number})'
)
has_back = True
else:
has_back = False
if has_back:
if back_poisson:
back_sim = np.array(rng.poisson(back_counts), np.float64)
back_errors = np.sqrt(back_sim)
else:
# TODO: should fractional systematic errors be added?
back_sim = rng.normal(back_counts, back_errors)
back_data = SpectrumData(
counts=back_sim,
errors=back_errors,
poisson=back_poisson,
exposure=back_exposure,
quality=quality,
grouping=grouping,
area_scale=back_area_scale,
back_scale=back_back_scale,
)
else:
back_data = None
matrix = resp_data.sparse_matrix.T
folded_rate = matrix @ self.eval(resp_data.photon_egrid, params)
spec_counts = folded_rate * spec_exposure * spec_area_scale
if has_back:
back_ratio = (
spec_exposure * spec_area_scale * spec_back_scale
) / (back_exposure * back_area_scale * back_back_scale)
spec_counts += back_ratio * back_counts
if spec_poisson:
spec_sim = np.array(rng.poisson(spec_counts), np.float64)
spec_errors = np.sqrt(spec_sim)
else:
# TODO: should fractional systematic errors be added?
spec_sim = rng.normal(spec_counts, spec_errors)
spec_data = SpectrumData(
counts=spec_sim,
errors=spec_errors,
poisson=spec_poisson,
exposure=spec_exposure,
quality=quality,
grouping=grouping,
area_scale=spec_area_scale,
back_scale=spec_back_scale,
)
return ObservationData(
name=name,
erange=[resp_data.channel_emin[0], resp_data.channel_emax[-1]],
spec_data=spec_data,
resp_data=resp_data,
back_data=back_data,
**kwargs,
)
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
fields = ['No.', 'Component', 'Parameter', 'Value', 'Bound', 'Prior']
info = self._model_info.info
tab_params = make_pretty_table(fields, info)
mtype = get_mtype_str(self)
return f'{mtype} Model: {self.name}\n{tab_params.get_string()}'
def _repr_html_(self) -> str:
fields = ['No.', 'Component', 'Parameter', 'Value', 'Bound', 'Prior']
info = self._model_info.info
tab_params = make_pretty_table(fields, info)
mtype = get_mtype_str(self)
return (
'<details open>'
f'<summary><b>{mtype} Model: {self.name}</b></summary>'
f'{tab_params.get_html_string(format=True)}'
'</details>'
)
def __delattr__(self, key: str):
if self.__initialized and hasattr(self, key):
raise AttributeError("can't delete attribute")
super().__delattr__(key)
def __setattr__(self, key: str, value: Any):
if self.__initialized:
raise AttributeError("can't set attribute")
super().__setattr__(key, value)
[docs]
class UniComponentModel(Model):
"""Model defined by a single additive or multiplicative component."""
_component: Component
def __init__(self, component: Component):
self._component = component
super().__init__(component._id, component._id, [component])
@property
def eval(self) -> ModelEval:
comp_id = self._component._id
_fn = self._component.eval
def fn(egrid: JAXArray, params: CompIDParamValMapping) -> JAXArray:
"""The model evaluation function"""
return _fn(egrid, params[comp_id])
return jax.jit(fn)
@property
def type(self) -> Literal['add', 'mul']:
return self._component.type
[docs]
class CompositeModel(Model):
"""Model defined by sum or product of two models."""
_operands: tuple[Model, Model]
_op: Callable[[JAXArray, JAXArray], JAXArray]
_op_symbol: Literal['+', '*']
_type: Literal['add', 'mul']
__additive_comps: tuple[Model, ...] | None = None
def __init__(self, lhs: Model, rhs: Model, op: Literal['+', '*']):
# check if the type of lhs and rhs are both model
if not (isinstance(lhs, Model) and isinstance(rhs, Model)):
raise TypeError(
f'unsupported operand types for {op}: '
f"'{type(lhs).__name__}' and '{type(rhs).__name__}'"
)
self._operands = (lhs, rhs)
self._op_symbol = op
# check if operand is legal for the operator
type1 = lhs.type
type2 = rhs.type
if op == '+':
if type1 != 'add':
raise TypeError(f'{lhs} is not an additive model')
if type2 != 'add':
raise TypeError(f'{rhs} is not an additive model')
self._type = 'add'
self._op = jnp.add
op_name = '{} + {}'
op_latex = '{{{}}} + {{{}}}'
elif op == '*':
if 'add' == type1 == type2:
raise TypeError(
f'unsupported model type for *: {lhs} (additive) '
f'and {rhs} (additive)'
)
if 'conv' == type1 or 'conv' == type2:
raise TypeError(
f'unsupported model type for *: {lhs} ({type1}) '
f'and {rhs} ({type2})'
)
self._type = 'add' if type1 == 'add' or type2 == 'add' else 'mul'
self._op = jnp.multiply
op_name = '{} * {}'
op_latex = r'{{{}}} \times {{{}}}'
else:
raise NotImplementedError(f'op {op}')
lhs_name = lhs._name
lhs_latex = lhs._latex
rhs_name = rhs._name
rhs_latex = rhs._latex
if op == '*':
if isinstance(lhs, CompositeModel) and lhs._op_symbol == '+':
lhs_name = f'({lhs_name})'
lhs_latex = rf'\left({lhs_latex}\right)'
if isinstance(rhs, CompositeModel) and rhs._op_symbol == '+':
rhs_name = f'({rhs_name})'
rhs_latex = rf'\left({rhs_latex}\right)'
comps = list(lhs._comps)
for c in rhs._comps:
if c not in comps:
comps.append(c)
super().__init__(
op_name.format(lhs_name, rhs_name),
op_latex.format(lhs_latex, rhs_latex),
comps,
)
@property
def eval(self) -> ModelEval:
op = self._op
lhs = self._operands[0].eval
rhs = self._operands[1].eval
def fn(egrid: JAXArray, params: CompIDParamValMapping) -> JAXArray:
"""The model evaluation function"""
return op(lhs(egrid, params), rhs(egrid, params))
return jax.jit(fn)
@property
def _additive_comps(self) -> tuple[Model, ...]:
if self.__additive_comps is not None:
return self.__additive_comps
if self.type != 'add':
comps = ()
else:
op = self._op_symbol
lhs = self._operands[0]
rhs = self._operands[1]
if op == '+': # add + add
comps = lhs._additive_comps + rhs._additive_comps
else:
if lhs.type == 'add': # add * mul
lhs_comps = lhs._additive_comps
if len(lhs_comps) >= 2:
comps = tuple(c * rhs for c in lhs_comps)
else:
comps = (self,)
else: # mul * add
rhs_comps = rhs._additive_comps
if len(rhs_comps) >= 2:
comps = tuple(lhs * c for c in rhs_comps)
else:
comps = (self,)
self.__additive_comps = comps
return self.__additive_comps
@property
def type(self) -> Literal['add', 'mul']:
return self._type
def _param_getter(name: str) -> Callable[[Component], Parameter]:
def getter(self: Component) -> Parameter:
"""Get parameter."""
return getattr(self, f'__{name}')
return getter
def _param_setter(name: str, idx: int) -> Callable[[Component, Any], None]:
def setter(self: Component, param: Any) -> None:
"""Set parameter."""
if param is not None and getattr(self, f'__{name}', None) is param:
return
cfg = self._config[idx]
if param is None:
# use default configuration
param = UniformParameter(
name=cfg.name,
default=cfg.default,
min=cfg.min,
max=cfg.max,
log=cfg.log,
fixed=cfg.fixed,
latex=cfg.latex,
)
elif isinstance(
param, float | int | jnp.number | jnp.ndarray | np.ndarray
):
# fixed to the given value
if jnp.shape(param) != ():
raise ValueError('scalar value is expected')
param = UniformParameter(
name=cfg.name,
default=param,
min=cfg.min,
max=cfg.max,
log=cfg.log,
fixed=True,
latex=cfg.latex,
)
elif isinstance(param, Parameter):
# given parameter instance
param = param
# set the latex format as specified in model config
param.latex = cfg.latex
elif isinstance(param, Mapping):
# given mapping
if not {'default', 'min', 'max'}.issubset(param.keys()):
raise ValueError(
f'{type(self).__name__}.{cfg.name} expected dict with keys'
f' "default", "min", "max", with optional "log" and'
f' "fixed", but got {param}'
)
param = UniformParameter(
name=cfg.name,
default=param['default'],
min=param['min'],
max=param['max'],
log=param.get('log', False),
fixed=param.get('fixed', False),
latex=cfg.latex,
)
elif isinstance(param, Sequence):
# given sequence
if len(param) not in {1, 3, 4}:
raise ValueError(
f'{type(self).__name__}.{cfg.name} expected sequence of '
'length 1, 3, or 4: '
'[default (float)], '
'[default (float), min (float), max (float)], or '
'[default (float), min (float), max (float), log (bool)], '
f'but got {param}'
)
if len(param) == 1:
(default,) = param
min_ = cfg.min
max_ = cfg.max
log = cfg.log
elif len(param) == 3:
default, min_, max_ = param
log = cfg.log
else:
default, min_, max_, log = param
param = UniformParameter(
name=cfg.name,
default=default,
min=min_,
max=max_,
log=log,
fixed=cfg.fixed,
latex=cfg.latex,
)
else:
# other input types are not supported yet
raise TypeError(
f'{type(self).__name__}.{cfg.name} got unsupported value '
f'{param} ({type(param).__name__})'
)
prev_param: Parameter | None = getattr(self, f'__{name}', None)
if prev_param is not None:
prev_param._tracker.remove(self._id, name)
setattr(self, f'__{name}', param)
param._tracker.append(self._id, name)
return setter
[docs]
class ParamConfig(NamedTuple):
"""Configuration of a uniform parameter for spectral component."""
name: str
"""Plain name of the parameter."""
latex: str
r""":math:`\LaTeX` of the parameter."""
unit: str
"""Physical unit."""
default: float
"""Default value of the parameter."""
min: float
"""Minimum value of the parameter."""
max: float
"""Maximum value of the parameter."""
log: bool = False
"""Whether the parameter is parameterized in a logarithmic scale."""
fixed: bool = False
"""Whether the parameter is fixed."""
[docs]
class Component(ABC, metaclass=ComponentMeta):
"""Base class to define a spectral component."""
_config: tuple[ParamConfig, ...]
_id: CompID
_args: tuple[str, ...] = () # extra args passed to subclass __init__
_kwargs: tuple[str, ...] = () # extra kwargs passed to subclass __init__
_staticmethod: tuple[str, ...] = () # methods need to be static
__initialized: bool = False
def __init__(self, params: dict, latex: str | None):
self._id = hex(id(self))[2:]
if latex is None:
latex = rf'\mathrm{{{self.__class__.__name__}}}'
self.latex = latex
self._name = self.__class__.__name__
# parse parameters from params, which is a dict of parameters
for cfg in self._config:
setattr(self, cfg.name, params[cfg.name])
self._param_names = tuple(cfg.name for cfg in self._config)
self.__initialized = True
@property
@abstractmethod
def eval(self) -> CompEval:
"""Get side-effect free component evaluation function."""
pass
@property
def name(self) -> str:
"""Component name."""
return self._name
@property
def latex(self) -> str:
r""":math:`\LaTeX` format of the component."""
return self._latex
@latex.setter
def latex(self, latex: str):
self._latex = str(latex)
@property
@abstractmethod
def type(self) -> Literal['add', 'mul', 'conv']:
"""Component type."""
pass
@property
def param_names(self) -> tuple[str, ...]:
"""Component's parameter names."""
return self._param_names
@property
@abstractmethod
def _config(self) -> tuple[ParamConfig, ...]:
"""Default configuration of parameters."""
pass
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
# TODO: show the component name, LaTeX, type and parameters
return self.name
def __getitem__(self, key: str) -> Parameter:
if key not in self.param_names:
raise KeyError(key)
return getattr(self, key)
def __setitem__(self, key: str, value: Any):
if key not in self.param_names:
raise KeyError(key)
setattr(self, key, value)
def __delattr__(self, key: str):
if self.__initialized and hasattr(self, key):
raise AttributeError("can't delete attribute")
super().__delattr__(key)
def __setattr__(self, key: str, value: Any):
if self.__initialized and not hasattr(self, key):
raise AttributeError("can't set attribute")
super().__setattr__(key, value)
[docs]
class AdditiveComponent(Component):
"""Prototype class to define an additive component."""
@property
def type(self) -> Literal['add']:
"""Component type is additive."""
return 'add'
[docs]
class MultiplicativeComponent(Component):
"""Prototype class to define a multiplicative component."""
@property
def type(self) -> Literal['mul']:
"""Component type is multiplicative."""
return 'mul'
[docs]
class AnalyticalIntegral(Component):
"""Prototype component to calculate model integral analytically."""
_integral_jit: CompEval | None = None
_staticmethod = ('integral',)
def __init__(self, params: dict, latex: str | None):
super().__init__(params, latex)
@property
def eval(self) -> CompEval:
if self._integral_jit is None:
self._integral_jit = jax.jit(self.integral)
return self._integral_jit
[docs]
@staticmethod
@abstractmethod
def integral(*args, **kwargs) -> JAXArray:
"""Calculate the model value over grid."""
pass
[docs]
class NumericalIntegral(Component):
"""Prototype component to calculate model integral numerically."""
_kwargs = ('method',)
_continuum_jit: CompEval | None = None
_staticmethod = ('continuum',)
def __init__(
self,
params: dict,
latex: str | None,
method: Literal['trapz', 'simpson'] | None = None,
):
self.method = 'trapz' if method is None else method
super().__init__(params, latex)
@property
def eval(self) -> CompEval:
if self._continuum_jit is None:
# _continuum is assumed to be a pure function, independent of self
self._continuum_jit = jax.jit(self.continuum)
return self._make_integral(self._continuum_jit)
def _make_integral(self, continuum: CompEval):
mtype = self.type
if self.method == 'trapz':
def fn(egrid: JAXArray, params: NameValMapping) -> JAXArray:
"""Numerical integration using trapezoidal rule."""
if mtype == 'add':
factor = 0.5 * (egrid[1:] - egrid[:-1])
else:
factor = 0.5
f_grid = continuum(egrid, params)
return factor * (f_grid[:-1] + f_grid[1:])
elif self.method == 'simpson':
def fn(egrid: JAXArray, params: NameValMapping) -> JAXArray:
"""Numerical integration using Simpson's 1/3 rule."""
if mtype == 'add':
factor = (egrid[1:] - egrid[:-1]) / 6.0
else:
factor = 1.0 / 6.0
e_mid = 0.5 * (egrid[:-1] + egrid[1:])
f_grid = continuum(egrid, params)
f_mid = continuum(e_mid, params)
return factor * (f_grid[:-1] + 4.0 * f_mid + f_grid[1:])
else:
raise NotImplementedError(f"integration method '{self.method}'")
return jax.jit(fn)
[docs]
@staticmethod
@abstractmethod
def continuum(*args, **kwargs) -> JAXArray:
"""Calculate the model value at the energy grid."""
pass
@property
def method(self) -> Literal['trapz', 'simpson']:
"""Numerical integration method."""
return self._method
@method.setter
def method(self, method: Literal['trapz', 'simpson']):
if method not in ('trapz', 'simpson'):
raise ValueError(
f"available integration methods are 'trapz' and 'simpson', "
f"but got '{method}'"
)
self._method = method
[docs]
class AnaIntAdditive(AnalyticalIntegral, AdditiveComponent):
"""Prototype additive component with integral expression defined."""
[docs]
@staticmethod
@abstractmethod
def integral(*args, **kwargs) -> JAXArray:
"""Calculate the photon flux integrated over the energy grid.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : dict
Parameter dict for the model.
Returns
-------
jax.Array
The photon flux integrated over `egrid`, in units of ph cm⁻² s⁻¹.
"""
pass
[docs]
class NumIntAdditive(NumericalIntegral, AdditiveComponent):
"""Prototype additive component with continuum expression defined."""
[docs]
@staticmethod
@abstractmethod
def continuum(*args, **kwargs) -> JAXArray:
"""Calculate the photon flux at the energy grid.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : dict
Parameter dict for the model.
Returns
-------
jax.Array
The photon flux at `egrid`, in units of ph cm⁻² s⁻¹ keV⁻¹.
"""
pass
[docs]
class AnaIntMultiplicative(AnalyticalIntegral, MultiplicativeComponent):
"""Prototype multiplicative component with integral expression defined."""
[docs]
@staticmethod
@abstractmethod
def integral(*args, **kwargs) -> JAXArray:
"""Calculate the average value of the model between the grid.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : dict
Parameter dict for the model.
Returns
-------
jax.Array
The model value, dimensionless.
"""
pass
[docs]
class NumIntMultiplicative(NumericalIntegral, MultiplicativeComponent):
"""Prototype multiplicative component with continuum expression defined."""
[docs]
@staticmethod
@abstractmethod
def continuum(*args, **kwargs) -> JAXArray:
"""Calculate the model value at the energy grid.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : dict
Parameter dict for the model.
Returns
-------
jax.Array
The model value at `egrid`, dimensionless.
"""
pass
[docs]
class ConvolutionModel(Model):
"""Model defined by a convolution component."""
def __init__(self, component: ConvolutionComponent):
self._component = component
super().__init__(component._id, component._id, [component])
def __call__(self, model: Model) -> ConvolvedModel:
if model.type not in self._component._supported:
accepted = [f"'{i}'" for i in self._component._supported]
raise TypeError(
f'{self.name} convolution model supports models with type: '
f"{', '.join(accepted)}; got '{model.type}' type model {model}"
)
return ConvolvedModel(self._component, model)
@property
def eval(self):
# this is not supposed to be called
raise RuntimeError('convolution model has no eval defined')
@property
def type(self) -> Literal['conv']:
"""Model type is convolution."""
return 'conv'
[docs]
class ConvolvedModel(Model):
"""Model created by convolution."""
_op: ConvolutionComponent
_model: Model
def __init__(self, op: ConvolutionComponent, model: Model):
self._op = op
self._model = model
name = f'{op._id}({model._name})'
latex = rf'{{{op._id}}}\left({model._latex}\right)'
comps = list(model._comps)
if op not in comps:
comps = comps + [op]
self.__additive_comps = None
super().__init__(name, latex, comps)
@property
def _additive_comps(self) -> tuple[Model, ...]:
if self.__additive_comps is not None:
return self.__additive_comps
if self.type != 'add':
comps = ()
else:
op = self._op
model_comps = self._model._additive_comps
if len(model_comps) >= 2:
comps = tuple(ConvolvedModel(op, m) for m in model_comps)
else:
comps = (self,)
self.__additive_comps = comps
return self.__additive_comps
@property
def eval(self) -> ModelEval:
comp_id = self._op._id
_fn = self._op.eval
_model_fn = self._model.eval
def fn(egrid: JAXArray, params: CompIDParamValMapping) -> JAXArray:
"""The convolved model evaluation function."""
return _fn(egrid, params[comp_id], lambda e: _model_fn(e, params))
return jax.jit(fn)
@property
def type(self) -> Literal['add', 'mul']:
return self._model.type
[docs]
class ConvolutionComponent(Component):
"""Prototype class to define a convolution component."""
_supported: frozenset = frozenset({'add', 'mul'})
_convolve_jit: ConvolveEval | None = None
_staticmethod = ('convolve',)
@property
def type(self) -> Literal['conv']:
return 'conv'
@property
def eval(self) -> ConvolveEval:
if self._convolve_jit is None:
self._convolve_jit = jax.jit(self.convolve, static_argnums=2)
return self._convolve_jit
[docs]
@staticmethod
@abstractmethod
def convolve(*args, **kwargs) -> JAXArray:
"""Convolve a model function.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : dict
Parameter dict for the convolution model.
model_fn : callable
The model function to be convolved, which takes the energy grid as
input and returns the model value over the grid.
Returns
-------
jax.Array
The convolved model value over `egrid`.
"""
pass
[docs]
class PyComponent(Component):
"""Prototype component with pure Python expression defined."""
_kwargs: tuple[str, ...] = ('grad_method',)
def __init__(
self,
params: dict,
latex: str | None,
grad_method: Literal['central', 'forward'] | None,
):
self.grad_method = grad_method
super().__init__(params, latex)
@property
def grad_method(self) -> Literal['central', 'forward']:
"""Numerical differentiation method."""
return self._grad_method
@grad_method.setter
def grad_method(self, value: Literal['central', 'forward'] | None):
if value is None:
value: Literal['central'] = 'central'
if value not in {'central', 'forward'}:
raise ValueError(
f"supported methods are 'central' and 'forward', but got "
f"'{value}'"
)
self._grad_method = value
[docs]
class PyAnaInt(PyComponent, AnalyticalIntegral):
"""Prototype component with python integral expression defined."""
@property
def eval(self) -> CompEval:
if self._integral_jit is None:
integral_fn = self.integral
def eval_integral(egrid, params):
egrid = np.asarray(egrid)
params = {k: np.asarray(v) for k, v in params.items()}
return integral_fn(egrid, params)
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
shape_dtype = jax.ShapeDtypeStruct(
(egrid.size - 1,), egrid.dtype
)
return jax.pure_callback(
eval_integral,
shape_dtype,
egrid,
params,
vmap_method='sequential',
)
self._integral_jit = jax.jit(
define_fdjvp(jax.jit(integral), self.grad_method)
)
return self._integral_jit
[docs]
class PyNumInt(PyComponent, NumericalIntegral):
"""Prototype component with python continuum expression defined."""
_kwargs = ('method', 'grad_method')
def __init__(
self,
params: dict,
latex: str | None,
method: Literal['trapz', 'simpson'] | None,
grad_method: Literal['central', 'forward'] | None,
):
super().__init__(params, latex, grad_method)
self.method = 'trapz' if method is None else method
@property
def eval(self) -> CompEval:
if self._continuum_jit is None:
# continuum is assumed to be a pure function, independent of self
continuum_fn = self.continuum
def eval_continuum(egrid, params):
egrid = np.asarray(egrid)
params = {k: np.asarray(v) for k, v in params.items()}
return continuum_fn(egrid, params)
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
shape_dtype = jax.ShapeDtypeStruct(egrid.shape, egrid.dtype)
return jax.pure_callback(
eval_continuum,
shape_dtype,
egrid,
params,
vmap_method='sequential',
)
self._continuum_jit = jax.jit(
define_fdjvp(jax.jit(continuum), self.grad_method)
)
return self._make_integral(self._continuum_jit)
[docs]
def get_mtype_str(model: Model | CompiledModel) -> str:
"""Get the model type string."""
match model.type:
case 'add':
return 'Additive'
case 'mul':
return 'Multiplicative'
case 'conv':
return 'Convolution'
case _:
raise ValueError(f"unknown model type '{model.type}'")
[docs]
def get_model_info(
comps: Sequence[Component],
cid_to_name: CompIDStrMapping,
cid_to_latex: CompIDStrMapping,
) -> ModelInfo:
"""Get the model information.
Parameters
----------
comps : sequence of Component
The sequence of components.
cid_to_name : mapping
The mapping of component id to component name.
cid_to_latex : mapping
The mapping of component id to component LaTeX format.
Returns
-------
ModelInfo
The model parameter information dict.
"""
# get all parameter information
params_info: dict[ParamID, ParamInfo] = reduce(
lambda x, y: x | y,
[comp[name]._info for comp in comps for name in comp.param_names],
)
# mapping from pid to assigned component id and parameter name
comp_param: dict[ParamID, tuple[CompID, CompParamName] | None] = {
pid: tmp
for pid, info in params_info.items()
if (tmp := info.tracker.get_comp_param(cid_to_name.keys()))
}
# names of non-composite params with no component assigned (aux params)
aux_params = {
pid: info
for pid, info in params_info.items()
if (pid not in comp_param) and (not info.composite)
}
aux_names = [i.name for i in aux_params.values()]
aux_names = build_namespace(aux_names, prime=True)['namespace']
# name mapping from aux params id to name
name_mapping: ParamIDStrMapping = dict(
zip(aux_params.keys(), aux_names, strict=True)
)
# record name mapping of params directly assigned to a component
name_mapping |= {
pid: f'{cid_to_name[cid]}.{pname}'
for pid, (cid, pname) in comp_param.items()
}
# record the component LaTeX format of parameters
comp_latex_mapping = {
pid: cid_to_latex[cid] for pid, (cid, _) in comp_param.items()
}
comp_latex_mapping |= dict.fromkeys(aux_params, '')
# record the LaTeX format of aux parameters
aux_latex = [params_info[pid].latex for pid in aux_params]
aux_latex = build_namespace(aux_latex, latex=True, prime=True)['namespace']
latex_mapping = dict(zip(aux_params.keys(), aux_latex, strict=True))
# record the LaTeX format of component parameters from _config
# latex_mapping |= {
# comp[name]._id: comp._config[i].latex
# for comp in comps
# for (i, name) in enumerate(comp.param_names)
# }
# record the LaTeX format of component parameters
latex_mapping |= {
comp[name]._id: comp[name].latex
for comp in comps
for (i, name) in enumerate(comp.param_names)
}
# TODO: record aux params unit & unit consistency check
# record the unit of component parameters from _config
unit_mapping = {
comp[name]._id: comp._config[i].unit
for comp in comps
for (i, name) in enumerate(comp.param_names)
}
unit_mapping |= {pid: '' for pid, info in aux_params.items()}
# record whether the parameter is logarithmic
log = {pid: params_info[pid].log for pid in name_mapping}
# record the value of fixed parameters
fixed = {
pid: info.default
for pid, info in params_info.items()
if (pid in name_mapping) and info.fixed and (not info.integrate)
}
# integral operator
integrate = {
pid: info.integrate
for pid, info in params_info.items()
if info.integrate and not info.composite
}
# record the sample distribution of free parameters
sample: dict[ParamID, Distribution] = {
pid: info.dist for pid, info in params_info.items() if info.dist
}
# record the default value of free parameters
default: dict[ParamID, JAXFloat] = {
pid: params_info[pid].default for pid in sample
}
# ========== generate component parameter value getter function ===========
def factory1(
id_to_value: Callable[[ParamIDValMapping], JAXFloat],
) -> Callable[[ParamIDValMapping], JAXFloat]:
"""Fill in the default parameter values for fn."""
def fn(value_mapping: ParamIDValMapping) -> JAXFloat:
"""id_to_value with default values."""
return id_to_value(fixed_values | value_mapping)
return fn
def factory2(
pname_to_pid: dict[CompParamName, ParamID],
) -> Callable[[ParamIDValMapping], NameValMapping]:
"""Generate component parameter value getter."""
def fn(value_mapping: ParamIDValMapping) -> NameValMapping:
"""Component parameter value getter."""
return {
pname: pid_to_value[pid](value_mapping)
for pname, pid in pname_to_pid.items()
}
return fn
fixed_values = fixed
pid_to_value = {
pid: factory1(params_info[pid].id_to_value) for pid in comp_param
}
cid_to_params = {
comp._id: factory2({name: comp[name]._id for name in comp.param_names})
for comp in comps
}
# ========== generate component parameter value getter function ===========
# record non-interval params which has a component assigned
deterministic = {}
# record the setup of component parameters
setup: dict[CompParamName, (ParamID, ParamSetup)] = {}
# model_info is list of (No., Component, Parameter, Value, Bound, Prior)
model_info = []
sample_order = []
n = 0
for comp in comps: # iterate through components
cid = comp._id
cname = cid_to_name[cid]
for pname in comp.param_names: # iterate through parameters of comp
model_pname = f'{cname}.{pname}'
pid = comp[pname]._id
info = params_info[pid]
idx = ''
bound = info.bound
prior = info.prior
if (cid, pname) != comp_param[pid]: # param is linked to another
idx = ''
value = name_mapping[pid]
setup[model_pname] = (value, ParamSetup.Forwarded)
elif info.composite: # param is composite
mapping = dict(name_mapping)
del mapping[pid]
value = info.name(mapping)
if info.integrate: # param is composite interval
setup[model_pname] = (model_pname, ParamSetup.Integrated)
elif info.fixed: # param is composite, but fixed
setup[model_pname] = (model_pname, ParamSetup.Fixed)
else: # param is composite but free to vary, add it to determ
idx = '*'
setup[model_pname] = (model_pname, ParamSetup.Composite)
deterministic[pid] = pid_to_value[pid]
elif info.integrate: # param is interval
value = f'[{info.default[0]:.4g}, {info.default[1]:.4g}]'
setup[model_pname] = (model_pname, ParamSetup.Integrated)
else: # a single param
value = f'{info.default:.4g}'
if info.fixed:
setup[model_pname] = (model_pname, ParamSetup.Fixed)
else:
setup[model_pname] = (model_pname, ParamSetup.Free)
n += 1
idx = str(n)
sample_order.append(pid)
model_info.append((idx, cname, pname, value, bound, prior))
for pid, info in aux_params.items(): # iterate through aux params
idx = ''
bound = info.bound
prior = info.prior
if info.integrate: # aux param is interval
value = f'[{info.default[0]:.4g}, {info.default[1]:.4g}]'
else: # aux param is not interval, record its default value
value = f'{info.default:.4g}'
if not info.fixed: # aux param is free to vary
n += 1
idx = str(n)
sample_order.append(pid)
model_info.append((idx, '', name_mapping[pid], value, bound, prior))
if set(sample_order) != set(sample):
raise RuntimeError('sample_order and sample are not consistent')
sample = {pid: sample[pid] for pid in sample_order}
return ModelInfo(
info=tuple(model_info),
name=name_mapping,
latex=latex_mapping,
unit=unit_mapping,
sample=sample,
default=default,
deterministic=deterministic,
fixed=fixed,
log=log,
pid_to_comp_latex=comp_latex_mapping,
cid_to_params=cid_to_params,
cid_to_name=cid_to_name,
cid_to_latex=cid_to_latex,
integrate=integrate,
setup=setup,
)
[docs]
class ModelInfo(NamedTuple):
"""Model information."""
info: tuple[tuple[str, str, str, str, str, str], ...]
"""The model parameter information.
Each row contains (No., Component, Parameter, Value, Bound, Prior).
"""
name: ParamIDStrMapping
"""The mapping of parameter id to parameter name."""
latex: ParamIDStrMapping
r"""The mapping of component parameters name to :math:`\LaTeX` format."""
unit: ParamIDStrMapping
"""The mapping of component parameters name to physical unit."""
sample: dict[ParamID, Distribution]
"""The mapping of free parameter id to numpyro Distribution."""
default: ParamIDValMapping
"""The mapping of free parameter id to default value."""
deterministic: dict[ParamID, Callable[[ParamIDValMapping], JAXFloat]]
"""The mapping of deterministic parameters id to value getter."""
fixed: ParamIDValMapping
"""The mapping of fixed parameter id to fixed parameter value."""
log: dict[ParamID, bool]
"""The mapping of parameter id to logarithmic flag."""
pid_to_comp_latex: dict[ParamID, str]
"""The mapping of parameter id to component LaTeX format."""
cid_to_name: dict[CompID, CompName]
"""The mapping of component id to component name."""
cid_to_latex: dict[CompID, str]
"""The mapping of component id to component LaTeX format."""
cid_to_params: dict[CompID, Callable[[ParamIDValMapping], NameValMapping]]
"""The mapping of component id to parameter value getter function."""
integrate: dict[ParamID, IntegralFactory]
"""The mapping of interval parameter id to integral operator."""
setup: dict[CompParamName, tuple[ParamID, ParamSetup]]
"""The mapping of component parameter setup."""
[docs]
class ParamSetup(Enum):
"""Model parameter setup."""
Free = 0
"""Parameter is free to vary."""
Composite = 1
"""Parameter is composed of other free parameters."""
Forwarded = 2
"""Parameter is directly forwarded to another model parameter."""
Fixed = 3
"""Parameter is fixed to a value."""
Integrated = 4
"""Parameter is integrated out."""