# Source code for elisa.models.conv

"""Convolution models."""

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp

from elisa.models.model import ConvolutionComponent, ParamConfig

if TYPE_CHECKING:
    from collections.abc import Callable

    from elisa.util.typing import ConvolveEval, JAXArray, NameValMapping

# Names exported by ``from elisa.models.conv import *``.
__all__ = ['EnFlux', 'PhFlux', 'ZAShift', 'ZMShift', 'VAShift', 'VMShift']


class NormConvolution(ConvolutionComponent):
    """Base class of convolution components that re-normalize a model.

    Subclasses implement :meth:`convolve`, which receives an extra energy
    grid ``flux_egrid`` spanning ``[emin, emax]``. The grid is rebuilt, and
    the jitted evaluation function re-created, only when one of `emin`,
    `emax`, `ngrid` or `elog` has changed since the last access of
    :attr:`eval`.

    Parameters
    ----------
    emin : float or int
        Lower bound of the flux energy grid.
    emax : float or int
        Upper bound of the flux energy grid. Must be larger than `emin`.
    params : dict
        Forwarded unchanged to the parent class constructor.
    latex : str or None
        Forwarded unchanged to the parent class constructor.
    ngrid : int or None
        Number of points of the flux energy grid. ``None`` means 1000.
    elog : bool or None
        Whether the flux energy grid is geometrically (True) or linearly
        (False) spaced. ``None`` means True.

    Raises
    ------
    ValueError
        If `emax` is not larger than `emin`, or if `ngrid` is less than 2.
    """

    _args = ('emin', 'emax')
    _kwargs = ('ngrid', 'elog')
    _supported = frozenset({'add'})

    def __init__(
        self,
        emin: float | int,
        emax: float | int,
        params: dict,
        latex: str | None,
        ngrid: int | None,
        elog: bool | None,
    ):
        # ``_emin`` is assigned directly because the ``emin`` setter compares
        # against ``emax``, which does not exist yet. The ``emax`` setter
        # then validates the pair.
        self._emin = float(emin)
        self.emax = emax

        self.ngrid = 1000 if ngrid is None else ngrid
        self.elog = True if elog is None else bool(elog)

        # Cache of the last grid configuration and the function built for it.
        self._prev_config: tuple | None = None
        self._convolve_jit: ConvolveEval | None = None

        super().__init__(params, latex)

    @staticmethod
    @abstractmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
        flux_egrid: JAXArray,
    ) -> JAXArray:
        """Convolve a model function.

        Parameters
        ----------
        egrid : ndarray
            Photon energy grid in units of keV.
        params : dict
            Parameter dict for the convolution model.
        model_fn : callable
            The model function to be convolved, which takes the energy grid as
            input and returns the model flux over the grid.
        flux_egrid : ndarray
            Photon energy grid used to calculate flux in units of keV.

        Returns
        -------
        value : ndarray
            The re-normalized model over `egrid`, in units of
            ph cm⁻² s⁻¹ keV⁻¹.
        """

    @property
    def eval(self) -> ConvolveEval:
        """Jitted evaluation function bound to the current flux energy grid.

        Raises
        ------
        ValueError
            If `elog` is True and `emin` is not positive, since a geometric
            grid cannot start at a non-positive energy.
        """
        config = (self.emin, self.emax, self.ngrid, self.elog)
        if self._prev_config == config:
            # Nothing changed since the last build: reuse the jitted function.
            return self._convolve_jit

        if self.elog:
            if self.emin <= 0.0:
                # geomspace would silently produce NaNs in this case.
                raise ValueError('emin must be positive when elog is True')
            flux_egrid = jnp.geomspace(self.emin, self.emax, self.ngrid)
        else:
            flux_egrid = jnp.linspace(self.emin, self.emax, self.ngrid)

        fn = self.convolve

        def convolve(
            egrid: JAXArray,
            params: NameValMapping,
            model_fn: Callable[[JAXArray], JAXArray],
        ) -> JAXArray:
            # TODO: egrid can be reused to reduce computation
            return fn(egrid, params, model_fn, flux_egrid)

        # ``model_fn`` (argument index 2) is a Python callable, so it is
        # marked static for jit. The cache key is recorded only after the
        # function has been built, keeping both attributes consistent.
        self._convolve_jit = jax.jit(convolve, static_argnums=2)
        self._prev_config = config

        return self._convolve_jit

    @property
    def emin(self) -> float:
        """Minimum value of photon energy grid"""
        return self._emin

    @emin.setter
    def emin(self, value: float | int):
        value = float(value)
        if value >= self.emax:
            raise ValueError('emin must be less than emax')
        self._emin = value

    @property
    def emax(self) -> float:
        """Maximum value of photon energy grid"""
        return self._emax

    @emax.setter
    def emax(self, value: float | int):
        value = float(value)
        if value <= self._emin:
            raise ValueError('emax must be larger than emin')
        self._emax = value

    @property
    def ngrid(self) -> int:
        """Photon energy grid number."""
        return self._ngrid

    @ngrid.setter
    def ngrid(self, value: int):
        value = int(value)
        if value < 2:
            # A single grid point defines no energy interval to integrate.
            raise ValueError('ngrid must be at least 2')
        self._ngrid = value

    @property
    def elog(self) -> bool:
        """Whether to use logarithmically regular energy grids."""
        return self._elog

    @elog.setter
    def elog(self, value: bool):
        self._elog = bool(value)


class PhFlux(NormConvolution):
    r"""Normalize an additive model by photon flux between `emin` and `emax`.

    .. math::
        N'(E) = \mathcal{F}_\mathrm{ph}
            \left[\int_{E_\mathrm{min}}^{E_\mathrm{max}} N(E) \, dE\right]^{-1}
            N(E)

    .. warning::
        The normalization of one of the additive components **must** be fixed
        to a positive value.

    .. warning::
        The flux is calculated by trapezoidal rule, and is accurate only if
        enough numbers of energy grids are used.

    Parameters
    ----------
    emin : float or int
        Minimum energy of the band to calculate the flux, in units of keV.
    emax : float or int
        Maximum energy of the band to calculate the flux, in units of keV.
    F : Parameter, optional
        Photon flux :math:`\mathcal{F}_\mathrm{ph}`, in units of ph cm⁻² s⁻¹.
    latex : str, optional
        :math:`\LaTeX` format of the component. Defaults to class name.
    ngrid : int, optional
        The energy grid number to use. The default is 1000.
    elog : bool, optional
        Whether to use logarithmically regular energy grids.
        The default is True.
    """

    _config = (
        ParamConfig(
            'F', r'\mathcal{F}_\mathrm{ph}', 'ph cm^-2 s^-1', 1.0, 0.01, 1e10
        ),
    )

    @staticmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
        flux_egrid: JAXArray,
    ) -> JAXArray:
        F = params['F']
        # Total of the model over the flux grid; used as the normalization.
        mflux = jnp.sum(model_fn(flux_egrid))
        flux = model_fn(egrid)
        # Rescale so that the total over the flux grid equals F.
        return F / mflux * flux
class EnFlux(NormConvolution):
    r"""Normalize an additive model by energy flux between `emin` and `emax`.

    .. math::
        N'(E) = \mathcal{F}_\mathrm{en}
            \left[\int_{E_\mathrm{min}}^{E_\mathrm{max}} EN(E)\,dE\right]^{-1}
            N(E)

    .. warning::
        The normalization of one of the additive components **must** be fixed
        to a positive value.

    .. warning::
        The flux is calculated by trapezoidal rule, and is accurate only if
        enough numbers of energy grids are used.

    Parameters
    ----------
    emin : float or int
        Minimum energy of the band to calculate the flux, in units of keV.
    emax : float or int
        Maximum energy of the band to calculate the flux, in units of keV.
    F : Parameter, optional
        Energy flux :math:`\mathcal{F}_\mathrm{en}`, in units of
        erg cm⁻² s⁻¹.
    latex : str, optional
        :math:`\LaTeX` format of the component. Defaults to class name.
    ngrid : int, optional
        The energy grid number to use. The default is 1000.
    elog : bool, optional
        Whether to use logarithmically regular energy grids.
        The default is True.
    """

    _config = (
        ParamConfig(
            'F',
            r'\mathcal{F}_\mathrm{en}',
            'erg cm^-2 s^-1',
            1e-12,
            1e-30,
            1e30,
            log=True,
        ),
    )

    @staticmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
        flux_egrid: JAXArray,
    ) -> JAXArray:
        F = params['F']
        keV_to_erg = 1.602176634e-9
        # Geometric mean of each pair of adjacent grid points, used as the
        # energy weight of the interval between them. It is used for both
        # log- and linearly-spaced grids.
        mid = jnp.sqrt(flux_egrid[:-1] * flux_egrid[1:])
        # NOTE(review): presumably one value per grid interval, so that it
        # lines up element-wise with ``mid`` -- confirm against model_fn.
        _flux = model_fn(flux_egrid)
        # Energy-weighted total over the flux grid, converted from keV to erg.
        mflux = jnp.sum(keV_to_erg * mid * _flux)
        flux = model_fn(egrid)
        # Rescale so that the energy-weighted total equals F.
        return F / mflux * flux
class ZAShift(ConvolutionComponent):
    r"""Redshifts an additive model.

    Consider a source with an emission area of radius :math:`R` at redshift
    :math:`z`. Given the flux function :math:`N(E)` [ph s⁻¹ cm⁻² keV⁻¹] at the
    radius :math:`R`, the observed number of photons :math:`n` between the
    energy range :math:`e_1` [keV] and :math:`e_2` [keV] during an exposure
    time of :math:`\Delta t` [s] is calculated as follows:

    .. math::
        n &= \frac{R^2}{{D_\mathrm{c}}^2} \frac{\Delta t}{1+z}
            \int_{e_1(1+z)}^{e_2(1+z)} N(E) \, \mathrm{d}E
        \\\\
          &= \frac{R^2}{{D_\mathrm{c}}^2} \frac{\Delta t}{1+z}
            \int_{E_1}^{E_2} N(E) \, \mathrm{d}E,

    where :math:`E_1 = e_1 (1+z)` [keV], :math:`E_2 = e_2 (1+z)` [keV] and
    :math:`D_\mathrm{c}` is the comoving distance of the source at redshift
    :math:`z`. Note that the :math:`\frac{R^2}{{D_\mathrm{c}}^2}` factor is
    absorbed into the normalization of :math:`N(E)` in practice.

    Parameters
    ----------
    z : Parameter, optional
        Redshift :math:`z`, dimensionless.
    latex : str, optional
        :math:`\LaTeX` format of the component. Defaults to class name.
    """

    _supported = frozenset({'add'})
    _config = (ParamConfig('z', 'z', '', 0.0, -0.999, 15.0, fixed=True),)

    @staticmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
    ) -> JAXArray:
        factor = 1.0 + params['z']
        # Evaluate the model on the grid scaled by (1 + z), then apply the
        # 1 / (1 + z) factor from the formula in the class docstring.
        return model_fn(egrid * factor) / factor
class ZMShift(ConvolutionComponent):
    r"""Redshifts a multiplicative model.

    Consider a source at redshift :math:`z`. Given the dimensionless model
    function :math:`M(E)`, the observed value between the energy range
    :math:`e_1` [keV] and :math:`e_2` [keV] is calculated as follows:

    .. math::
        m &= \frac{1}{(e_2 - e_1)(1+z)}
            \int_{e_1(1+z)}^{e_2(1+z)} M(E) \, \mathrm{d}E
        \\\\
          &= \frac{1}{E_2 - E_1} \int_{E_1}^{E_2} M(E) \, \mathrm{d}E,

    where :math:`E_1 = e_1 (1+z)` [keV] and :math:`E_2 = e_2 (1+z)` [keV].

    Parameters
    ----------
    z : Parameter, optional
        Redshift :math:`z`, dimensionless.
    latex : str, optional
        :math:`\LaTeX` format of the component. Defaults to class name.
    """

    _supported = frozenset({'mul'})
    _config = (ParamConfig('z', 'z', '', 0.0, -0.999, 15.0, fixed=True),)

    @staticmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
    ) -> JAXArray:
        factor = 1.0 + params['z']
        # Evaluate the model on the grid scaled by (1 + z); unlike the
        # additive case, the result is not divided by (1 + z).
        return model_fn(egrid * factor)
class VAShift(ConvolutionComponent):
    r"""Velocity shifts an additive model.

    Consider a source with an emission area of radius :math:`R`, moving with
    speed :math:`v` along line of sight. Given the flux function :math:`N(E)`
    [ph s⁻¹ cm⁻² keV⁻¹] at the radius :math:`R`, the observed number of
    photons :math:`n` between the energy range :math:`e_1` [keV] and
    :math:`e_2` [keV] during an exposure time of :math:`\Delta t` [s] is
    calculated as follows:

    .. math::
        n &= \Delta t \int_{fe_1}^{fe_2} N(E) \, \mathrm{d}E
        \\\\
          &= \Delta t \int_{E_1}^{E_2} N(E) \, \mathrm{d}E,

    where :math:`E_1 = f e_1` [keV], :math:`E_2 = f e_2` [keV], and
    :math:`f = 1 - v/c`.

    Parameters
    ----------
    v : Parameter, optional
        Velocity :math:`v`, in units of km s⁻¹.
    latex : str, optional
        :math:`\LaTeX` format of the component. Defaults to class name.
    """

    _supported = frozenset({'add'})
    _config = (ParamConfig('v', 'v', 'km s^-1', 0.0, -1e4, 1e4, fixed=True),)

    @staticmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
    ) -> JAXArray:
        v = params['v']  # unit: km/s
        c = 299792.458  # unit: km/s
        # Energy scaling factor f = 1 - v/c from the class docstring.
        f = 1.0 - v / c
        return model_fn(egrid * f)
class VMShift(ConvolutionComponent):
    r"""Velocity shifts a multiplicative model.

    Consider a source moving with speed :math:`v` along line of sight. Given
    the dimensionless model function :math:`M(E)`, the observed value between
    the energy range :math:`e_1` [keV] and :math:`e_2` [keV] is calculated as
    follows:

    .. math::
        m &= \frac{1}{f (e_2 - e_1)}
            \int_{f e_1}^{f e_2} M(E) \, \mathrm{d}E
        \\\\
          &= \frac{1}{E_2 - E_1} \int_{E_1}^{E_2} M(E) \, \mathrm{d}E,

    where :math:`E_1 = f e_1` [keV], :math:`E_2 = f e_2` [keV], and
    :math:`f = 1 - v/c`.

    Parameters
    ----------
    v : Parameter, optional
        Velocity :math:`v`, in units of km s⁻¹.
    latex : str, optional
        :math:`\LaTeX` format of the component. Defaults to class name.
    """

    _supported = frozenset({'mul'})
    _config = (ParamConfig('v', 'v', 'km s^-1', 0.0, -1e4, 1e4, fixed=True),)

    @staticmethod
    def convolve(
        egrid: JAXArray,
        params: NameValMapping,
        model_fn: Callable[[JAXArray], JAXArray],
    ) -> JAXArray:
        v = params['v']  # unit: km/s
        c = 299792.458  # unit: km/s
        # Energy scaling factor f = 1 - v/c from the class docstring.
        f = 1.0 - v / c
        return model_fn(egrid * f)