"""XSPEC model library API."""
from __future__ import annotations
import keyword
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
import numpy as np
from elisa.models.model import (
Component,
ComponentMeta,
ConvolutionModel,
ConvolvedModel,
ParamConfig,
)
from elisa.util.misc import define_fdjvp
# XSPEC support is optional: the helpers below are re-exported only when the
# ``xspex`` bindings can be imported.
try:
    import xspex as _xx
    from xspex import (
        abund,
        abund_file,
        chatter,
        clear_mstr,
        clear_xflt,
        cosmo,
        list_models,
        mstr,
        xflt,
        xsect,
        xspec_version,
    )
    from xspex._xspec.types import (
        XspecModelType as _XspecModelType,
    )

    # Public API: the re-exported helpers plus one entry per model name
    # reported by ``xspex.list_models()``.
    __all__ = [
        'xspec_version',
        'list_models',
        'abund',
        'abund_file',
        'xsect',
        'cosmo',
        'mstr',
        'clear_mstr',
        'xflt',
        'clear_xflt',
        'chatter',
        *_xx.list_models(),
    ]
    _HAS_XSPEC = True
except ImportError as e:
    # Import failed: export nothing and record the absence so that
    # ``_generate_xspec_models`` returns no classes.
    __all__ = []
    _HAS_XSPEC = False
    # NOTE: ``ImportWarning`` is ignored by Python's default warning filters,
    # so this message is only shown when such warnings are enabled.
    warnings.warn(f'XSPEC model library is not available: {e}', ImportWarning)
# Names used only in annotations; this branch is skipped at runtime.
if TYPE_CHECKING:
    from typing import Literal

    from xspex._xspec.types import (
        XspecParam as XspecParamInfo,
    )

    from elisa.models.model import Model, UniComponentModel
    from elisa.util.typing import (
        Callable,
        CompEval,
        CompIDParamValMapping,
        JAXArray,
        ModelEval,
        ParamNameValMapping,
    )

    # Signature of a convolution evaluator. In this module it is called as
    # ``convolve(egrid, params, flux)`` and returns an array.
    XspecConvolveEval = Callable[
        [JAXArray, ParamNameValMapping, JAXArray],
        JAXArray,
    ]
class XspecComponentMeta(ComponentMeta):
    """Metaclass controlling what instantiating an Xspec component returns."""

    def __call__(
        cls, *args, **kwargs
    ) -> UniComponentModel | XspecConvolutionModel:
        """Instantiate ``cls`` and wrap the result.

        Classes that are not subclasses of :class:`XspecConvolution` are
        handled by ``ComponentMeta.__call__`` unchanged. For convolution
        classes, ``ComponentMeta.__call__`` is bypassed and the bare
        component is wrapped in an :class:`XspecConvolutionModel`.
        """
        if not issubclass(cls, XspecConvolution):
            return super().__call__(*args, **kwargs)

        # Start the lookup *after* ``ComponentMeta`` in the MRO so that the
        # component instance itself is obtained here.
        bare_component = super(ComponentMeta, cls).__call__(*args, **kwargs)
        return XspecConvolutionModel(bare_component)
class XspecComponent(Component, metaclass=XspecComponentMeta):
    """Xspec model wrapper.

    Stores the numerical differentiation method and the spectrum number
    shared by all Xspec components.
    """

    # Extra keyword arguments accepted in addition to the parameters.
    _kwargs: tuple[str, ...] = ('grad_method', 'spec_num')
    # Cache slot for the evaluation function built by subclasses.
    _eval: CompEval | None = None

    def __init__(
        self,
        params: dict,
        latex: str | None,
        grad_method: Literal['central', 'forward'] | None,
        spec_num: int | None,
    ):
        """Initialise the component.

        Parameters
        ----------
        params : dict
            Forwarded to ``Component.__init__``.
        latex : str or None
            Forwarded to ``Component.__init__``.
        grad_method : {'central', 'forward'} or None
            Numerical differentiation method; ``None`` means ``'central'``.
        spec_num : int or None
            Spectrum number; ``None`` means ``1``.
        """
        # Assignment goes through the validating property setter.
        self.grad_method = grad_method
        self._spec_num = 1 if spec_num is None else int(spec_num)
        super().__init__(params, latex)

    @property
    def grad_method(self) -> Literal['central', 'forward']:
        """Numerical differentiation method."""
        return self._grad_method

    @grad_method.setter
    def grad_method(self, value: Literal['central', 'forward'] | None):
        if value is None:
            value = 'central'

        if value in {'central', 'forward'}:
            self._grad_method = value
            return

        raise ValueError(
            f"supported methods are 'central' and 'forward', but got "
            f"'{value}'"
        )

    @property
    def spec_num(self) -> int:
        """Spectrum number."""
        return self._spec_num
class XspecAdditive(XspecComponent):
    """Base class of Xspec additive components.

    The evaluation function multiplies the ``norm`` parameter by the
    finite-difference-differentiable ``_integral`` supplied by subclasses.
    """

    @property
    def type(self) -> Literal['add']:
        """Component type, always ``'add'``."""
        return 'add'

    @property
    def eval(self) -> CompEval:
        """Jitted evaluation function, built on first access and cached."""
        if self._eval is not None:
            return self._eval

        _integral = jax.jit(define_fdjvp(self._integral, self.grad_method))

        def integral(egrid, params):
            # Split off ``norm`` without mutating the mapping passed in:
            # when jit is disabled that mapping is the caller's own dict,
            # and popping from it would break subsequent calls.
            norm = params['norm']
            others = {k: v for k, v in params.items() if k != 'norm'}
            return norm * _integral(egrid, others)

        self._eval = jax.jit(integral)
        return self._eval

    @property
    @abstractmethod
    def _integral(self) -> CompEval:
        """Function ``(egrid, params)`` evaluating the model without norm."""
        pass
class XspecMultiplicative(XspecComponent):
    """Base class of Xspec multiplicative components."""

    @property
    def type(self) -> Literal['mul']:
        """Component type, always ``'mul'``."""
        return 'mul'

    @property
    def eval(self) -> CompEval:
        """Jitted evaluation function, built on first access and cached."""
        cached = self._eval
        if cached is not None:
            return cached

        # Wrap the subclass-provided function with the finite-difference
        # JVP rule for the configured method, then compile it.
        differentiable = define_fdjvp(self._integral, self.grad_method)
        self._eval = jax.jit(differentiable)
        return self._eval

    @property
    @abstractmethod
    def _integral(self) -> CompEval:
        """Function ``(egrid, params)`` evaluating the model."""
        pass
class XspecConvolutionModel(ConvolutionModel):
    """Callable wrapper around an Xspec convolution component."""

    def __call__(self, model: Model) -> XspecConvolvedModel:
        """Apply the convolution to ``model``.

        Raises
        ------
        TypeError
            If ``model.type`` is not among the types supported by the
            wrapped component.
        """
        supported = self._component._supported
        if model.type in supported:
            return XspecConvolvedModel(self._component, model)

        accepted = [f"'{i}'" for i in supported]
        raise TypeError(
            f'{self.name} convolution model supports models with type: '
            f"{', '.join(accepted)}; got '{model.type}' type model {model}"
        )
class XspecConvolvedModel(ConvolvedModel):
    """Model obtained by applying an Xspec convolution to another model."""

    _op: XspecConvolution

    @property
    def eval(self) -> ModelEval:
        """Build the jitted evaluation function of the convolved model.

        The energy grid is extended below and above using the settings
        stored on the convolution component, the wrapped model and the
        convolution are evaluated on the extended grid, and the entries
        belonging to the extension are dropped from the result.

        The extension callbacks raise ``RuntimeError`` when the configured
        extension limit does not lie strictly outside the energy grid.
        """
        model = self._model.eval
        comp_id = self._op._id
        convolve = self._op.eval
        elow = self._op._low_energy
        nlow = self._op._low_ngrid
        loglow = self._op._low_log
        ehigh = self._op._high_energy
        nhigh = self._op._high_ngrid
        loghigh = self._op._high_log

        def extend_low(egrid):
            # ``<=``: an extension limit equal to the grid edge would create
            # zero-width bins, and the message demands strict inequality.
            if egrid[0] <= elow:
                raise RuntimeError(
                    f'for Xspec convolution model {self}, the lower limit '
                    f'of the energy extension ({elow}) must be less than '
                    f'the minimum energy grid ({egrid[0]})'
                )
            if loglow:
                low_extension = np.geomspace(elow, egrid[0], nlow + 1)[:-1]
            else:
                low_extension = np.linspace(elow, egrid[0], nlow + 1)[:-1]
            return np.concatenate((low_extension, egrid)).astype(egrid.dtype)

        def extend_high(egrid):
            # ``>=`` for the same reason as in ``extend_low``.
            if egrid[-1] >= ehigh:
                raise RuntimeError(
                    f'for Xspec convolution model {self}, the upper limit '
                    f'of the energy extension ({ehigh}) must be greater than '
                    f'the maximum energy grid ({egrid[-1]})'
                )
            if loghigh:
                high_extension = np.geomspace(egrid[-1], ehigh, nhigh + 1)[1:]
            else:
                high_extension = np.linspace(egrid[-1], ehigh, nhigh + 1)[1:]
            return np.concatenate((egrid, high_extension)).astype(egrid.dtype)

        def fn(egrid: JAXArray, params: CompIDParamValMapping) -> JAXArray:
            """The convolved model evaluation function."""
            # The grid extension runs in NumPy on the host; the result
            # shape must be declared up front for ``pure_callback``.
            rtype = jax.ShapeDtypeStruct((egrid.size + nlow,), egrid.dtype)
            egrid = jax.pure_callback(extend_low, rtype, egrid)
            rtype = jax.ShapeDtypeStruct((egrid.size + nhigh,), egrid.dtype)
            egrid = jax.pure_callback(extend_high, rtype, egrid)
            conv_params = params[comp_id]
            flux = model(egrid, params)
            result = convolve(egrid, conv_params, flux)
            # Use an explicit stop index: ``result[nlow:-nhigh]`` would be
            # empty for ``nhigh == 0`` because ``-0`` is ``0``.
            return result[nlow : result.shape[0] - nhigh]

        fn = define_fdjvp(jax.jit(fn), self._op.grad_method)
        return jax.jit(fn)
class XspecConvolution(XspecComponent):
    """Base class of Xspec convolution components.

    In addition to the settings of :class:`XspecComponent`, it stores how
    the energy grid is extended below and above before convolving.
    """

    # Model types the convolution may be applied to; set by subclasses.
    _supported: frozenset[Literal['add', 'mul']]
    # Cache slot for the jitted convolution function.
    _convolve_jit = None
    _kwargs = (
        'low_energy',
        'low_ngrid',
        'low_log',
        'high_energy',
        'high_ngrid',
        'high_log',
        'grad_method',
        'spec_num',
    )

    def __init__(
        self,
        params: dict,
        latex: str | None,
        low_energy: float | int | None,
        low_ngrid: int | None,
        low_log: bool | None,
        high_energy: float | int | None,
        high_ngrid: int | None,
        high_log: bool | None,
        grad_method: Literal['central', 'forward'] | None,
        spec_num: int | None,
    ):
        """Initialise the convolution component.

        Parameters
        ----------
        params : dict
            Forwarded to ``XspecComponent.__init__``.
        latex : str or None
            Forwarded to ``XspecComponent.__init__``.
        low_energy : float or None
            Lower limit of the energy extension; ``None`` means ``0.01``.
        low_ngrid : int or None
            Number of grid points added below; ``None`` means ``100``.
        low_log : bool or None
            Whether the lower extension is geometrically spaced; ``None``
            means ``True``.
        high_energy : float or None
            Upper limit of the energy extension; ``None`` means ``100.0``.
        high_ngrid : int or None
            Number of grid points added above; ``None`` means ``100``.
        high_log : bool or None
            Whether the upper extension is geometrically spaced; ``None``
            means ``True``.
        grad_method : {'central', 'forward'} or None
            Validated and defaulted by ``XspecComponent.__init__``.
        spec_num : int or None
            Validated and defaulted by ``XspecComponent.__init__``.
        """
        self._low_energy = 0.01 if low_energy is None else float(low_energy)
        self._low_ngrid = 100 if low_ngrid is None else int(low_ngrid)
        self._low_log = True if low_log is None else bool(low_log)

        self._high_energy = (
            100.0 if high_energy is None else float(high_energy)
        )
        self._high_ngrid = 100 if high_ngrid is None else int(high_ngrid)
        self._high_log = True if high_log is None else bool(high_log)

        # ``grad_method`` and ``spec_num`` are normalised and stored by the
        # base class, so they are not handled a second time here.
        super().__init__(params, latex, grad_method, spec_num)

    @property
    def type(self) -> Literal['conv']:
        """Component type, always ``'conv'``."""
        return 'conv'

    @property
    def eval(self) -> XspecConvolveEval:
        """Jitted convolution function, built on first access and cached."""
        if self._convolve_jit is None:
            self._convolve_jit = jax.jit(self._convolve)
        return self._convolve_jit

    @property
    @abstractmethod
    def _convolve(self) -> XspecConvolveEval:
        """Function ``(egrid, params, flux)`` performing the convolution."""
        pass
# Source template of an additive model class. It is filled with ``name``,
# ``link``, ``desc`` and ``params_config`` and then executed, so the text
# must be validly indented Python. In the generated method, ``{name}``
# resolves to the model function supplied in the exec globals.
_XSPEC_MODEL_TEMPLATE_ADD = '''
class {name}(XspecAdditive):
    """Xspec additive model `{name} <{link}>`_: {desc}."""
    _config = (
        {params_config},
    )
    @property
    def _integral(self):
        spec_num = self.spec_num
        params_names = [p.name for p in self._config if p.name != 'norm']
        def integral(egrid, params):
            params = [params[p] for p in params_names]
            if params:
                params = jnp.stack(params)
            else:
                params = jnp.empty(0)
            return {name}(params, egrid, spec_num)
        return integral
'''
# Source template of a multiplicative model class. It is filled with
# ``name``, ``link``, ``desc`` and ``params_config`` and then executed, so
# the text must be validly indented Python. In the generated method,
# ``{name}`` resolves to the model function supplied in the exec globals.
_XSPEC_MODEL_TEMPLATE_MUL = '''
class {name}(XspecMultiplicative):
    """Xspec multiplicative model `{name} <{link}>`_: {desc}."""
    _config = (
        {params_config},
    )
    @property
    def _integral(self):
        spec_num = self.spec_num
        params_names = [p.name for p in self._config]
        def integral(egrid, params):
            params = [params[p] for p in params_names]
            if params:
                params = jnp.stack(params)
            else:
                params = jnp.empty(0)
            return {name}(params, egrid, spec_num)
        return integral
'''
# Source template of a convolution model class. It is filled with ``name``,
# ``link``, ``desc``, ``supported`` and ``params_config`` and then executed,
# so the text must be validly indented Python. In the generated method,
# ``{name}`` resolves to the model function supplied in the exec globals.
_XSPEC_MODEL_TEMPLATE_CON = '''
class {name}(XspecConvolution):
    """Xspec convolution model `{name} <{link}>`_: {desc}."""
    _supported = frozenset(['{supported}'])
    _config = (
        {params_config},
    )
    @property
    def _convolve(self):
        spec_num = self.spec_num
        params_names = [p.name for p in self._config]
        def convolve(egrid, params, flux):
            params = [params[p] for p in params_names]
            if params:
                params = jnp.stack(params)
            else:
                params = jnp.empty(0)
            return {name}(params, egrid, flux, spec_num)
        return convolve
'''
# Source template of one ``ParamConfig(...)`` call in a generated class.
# ``unit`` is inserted through ``!r`` so that a unit containing quotes or
# backslashes still yields a valid string literal with the same value.
_XSPEC_MODEL_PARAM_CONFIG_TEMPLATE = (
    r'ParamConfig(name="{name}", latex=r"\mathrm{{{name}}}", unit={unit!r}, '
    'default={default}, min={min}, max={max}, fixed={fixed})'
)
def _generate_xspec_models():
    """Generate Xspec model classes.

    For every model name reported by ``xspex.list_models()``, the matching
    source template is filled in and executed to define a class named after
    the model.

    Returns
    -------
    dict
        Mapping from model name to the generated class. Empty when the
        XSPEC model library is not available.

    Raises
    ------
    ValueError
        If a model has a type other than additive, multiplicative or
        convolution.
    """
    models = {}

    if not _HAS_XSPEC:
        return models

    # Convolution models that should be applied for multiplicative models
    conv_for_mul = ('partcov', 'vmshift', 'zmshift')

    # Names the generated class source refers to
    env = {
        'ParamConfig': ParamConfig,
        'XspecAdditive': XspecAdditive,
        'XspecMultiplicative': XspecMultiplicative,
        'XspecConvolution': XspecConvolution,
        'jnp': jnp,
    }

    # Normalization parameter appended to every additive model
    norm_config = _XSPEC_MODEL_PARAM_CONFIG_TEMPLATE.format(
        name='norm',
        unit='',
        default=1.0,
        min=1e-10,
        max=1e10,
        fixed=False,
    )

    def gen_param_config(param_info: XspecParamInfo) -> str:
        """Generate parameter configuration string given parameter info."""
        name = param_info.name
        default = param_info.default
        unit = param_info.unit or ''
        pmin = param_info.min
        pmax = param_info.max

        # smaug model's params names have dots
        name = name.replace('.', '_')

        # avoid keyword conflict
        if keyword.iskeyword(name):
            name = f'{name}_'

        # min and max is None for switch and scale type parameters
        if pmin is None:
            pmin = -1
        if pmax is None:
            pmax = 1e10

        return _XSPEC_MODEL_PARAM_CONFIG_TEMPLATE.format(
            name=name,
            unit=unit,
            default=default,
            min=pmin,
            max=pmax,
            fixed=param_info.fixed,
        )

    def make_xspec_model(name: str):
        """Make Xspec model class and store it in ``models``."""
        fn, model_info = _xx.get_model(name)
        vars_map = {
            'name': name,
            'desc': model_info.desc,
            'link': model_info.link,
        }
        params_config = list(map(gen_param_config, model_info.parameters))

        model_type = model_info.type
        if model_type == _XspecModelType.Add:
            params_config.append(norm_config)
            template = _XSPEC_MODEL_TEMPLATE_ADD
        elif model_type == _XspecModelType.Mul:
            template = _XSPEC_MODEL_TEMPLATE_MUL
        elif model_type == _XspecModelType.Con:
            vars_map['supported'] = 'mul' if name in conv_for_mul else 'add'
            template = _XSPEC_MODEL_TEMPLATE_CON
        else:
            raise ValueError(f'Unsupported model type: {model_type.name}')

        vars_map['params_config'] = ',\n '.join(params_config)
        model_code = template.format_map(vars_map)

        # NOTE(review): the executed source interpolates name, description,
        # link and parameter metadata reported by xspex verbatim; this is
        # only safe as long as that metadata is trusted.
        # The class is bound in ``models`` (exec locals), while ``name`` in
        # the exec globals stays bound to the model function ``fn``, which
        # is what the generated method calls.
        exec(model_code, env | {name: fn}, models)

    for model_name in _xx.list_models():
        make_xspec_model(model_name)

    # Make the generated classes report this module as their origin
    for m in models.values():
        m.__module__ = __name__

    return models
# Publish every generated model class as a module-level name. At module
# scope ``locals()`` is the module namespace, so updating it defines the
# model names that were added to ``__all__``.
locals().update(_generate_xspec_models())