Source code for elisa.infer.results

"""Subsequent analysis of maximum likelihood or Bayesian fit."""

from __future__ import annotations

import bz2
import gzip
import lzma
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, NamedTuple

import arviz as az
import astropy.units as u
import dill
import jax
import jax.numpy as jnp
import numpy as np
import scipy.stats as stats
from arviz import InferenceData
from astropy.cosmology import Planck18
from iminuit import Minuit
from iminuit.util import Matrix as CovarMatrix
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

from elisa.infer.helper import check_params
from elisa.plot.plotter import MLEResultPlotter, PosteriorResultPlotter
from elisa.util.config import get_parallel_number
from elisa.util.misc import make_pretty_table

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Sequence
    from typing import Any, Literal

    from arviz.stats.stats_utils import ELPDData
    from astropy.cosmology import LambdaCDM
    from astropy.units import Quantity as Q
    from iminuit.util import FMin
    from xarray import DataArray

    from elisa.infer.helper import Helper
    from elisa.plot.plotter import Plotter
    from elisa.util.typing import JAXArray


class FitResult(ABC):
    """Common base of fit results.

    The constructor stores the helper, builds a jitted flux integrator
    from the models held by the helper, and records which datasets share
    the same model object.
    """

    _helper: Helper
    _plotter: Plotter | None
    _flux_fn: Callable
    _lumin_fn: Callable
    _eiso_fn: Callable
    _n_parallel: int | None

    def __init__(self, helper: Helper):
        self._helper = helper
        self._n_parallel = None

        models = helper.model
        photon_fns = {key: mdl.ne for key, mdl in models.items()}
        energy_fns = {key: mdl.ene for key, mdl in models.items()}

        def _flux(
            egrid: JAXArray,
            params: dict[str, JAXArray],
            energy: bool,
            comps: bool,
        ) -> dict[str, JAXArray] | dict[str, dict[str, JAXArray]]:
            """Sum model values times the grid spacing, per dataset."""
            spectra = energy_fns if energy else photon_fns
            widths = jnp.diff(egrid)

            def integrate(values):
                return jnp.sum(values * widths, axis=-1)

            result = {}
            for key, spec in spectra.items():
                values = spec(egrid, params, comps)
                if comps:
                    # per-component values: integrate every leaf
                    result[key] = jax.tree.map(integrate, values)
                else:
                    result[key] = integrate(values)
            return result

        # `energy` and `comps` select Python branches, hence static
        self._flux_fn = jax.jit(_flux, static_argnums=(2, 3))

        # Several datasets may be fitted with the same model object; map
        # each dataset to the first dataset using that model, so flux,
        # lumin and eiso are evaluated once per distinct model.
        groups = {}
        for data_name, model in models.items():
            groups.setdefault(model, []).append(data_name)
        self._model_mapping = {
            member: group[0] for group in groups.values() for member in group
        }

    @abstractmethod
    def __repr__(self):
        """Plain-text representation, provided by subclasses."""

    @abstractmethod
    def _repr_html_(self):
        """HTML representation, provided by subclasses."""

    @property
    @abstractmethod
    def plot(self) -> Plotter:
        """Result plotter."""
def summary(self, file=None) -> None:
    """Write the textual summary of the fit result.

    Parameters
    ----------
    file : file-like, optional
        Object with a ``write(string)`` method; forwarded to
        :py:func:`print`. With ``None``, :py:func:`print` uses its
        default output stream.
    """
    text = repr(self)
    print(text, file=file)
@abstractmethod
def flux(
    self, *args, **kwargs
) -> dict[str, jax.Array] | dict[str, dict[str, jax.Array]]:
    """Flux of the model; the computation is supplied by subclasses."""
@abstractmethod
def lumin(
    self, *args, **kwargs
) -> dict[str, jax.Array] | dict[str, dict[str, jax.Array]]:
    """Luminosity of the model; the computation is supplied by subclasses."""
@abstractmethod
def eiso(
    self, *args, **kwargs
) -> dict[str, jax.Array] | dict[str, dict[str, jax.Array]]:
    """Isotropic emission energy; the computation is supplied by subclasses."""
@property def ndata(self) -> dict[str, int]: """Data points number.""" return self._helper.ndata @property def dof(self) -> int: """Degree of freedom.""" return self._helper.dof @property @abstractmethod def gof(self) -> dict[str, float]: """Goodness of fit p-value.""" pass @property @abstractmethod def _params_dist(self) -> dict[str, JAXArray]: pass
def save(
    self,
    path: str,
    compress: Literal['gzip', 'bz2', 'lzma'] = 'gzip',
    **kwargs: dict,
) -> None:
    """Serialize the fit result with ``dill`` into a compressed file.

    Parameters
    ----------
    path : str
        The file path to save fit result.
    compress : {'gzip', 'bz2', 'lzma'}
        The compression algorithm to use.
    **kwargs : dict
        Extra parameters passed to :py:func:`gzip.open`,
        :py:func:`bz2.open`, or :py:func:`lzma.open`.

    Raises
    ------
    ValueError
        If `compress` is not one of the supported algorithms.
    """
    openers = (
        ('gzip', gzip.open),
        ('bz2', bz2.open),
        ('lzma', lzma.open),
    )
    open_ = next(
        (opener for name, opener in openers if compress == name), None
    )
    if open_ is None:
        raise ValueError(f'unsupported compression algorithm {compress}')

    with open_(path, 'wb', **kwargs) as f:
        dill.dump(self, f)
[docs] @staticmethod def load( path: str, decompress: Literal['gzip', 'bz2', 'lzma'] = 'gzip', ) -> FitResult: """Load a previously saved fit result. Parameters ---------- path : str The file path of previously saved fit result. decompress : {'gzip', 'bz2', 'lzma'} The decompression algorithm used to load the fit result. Returns ------- FitResult The loaded fit result. """ if decompress == 'gzip': open_ = gzip.open elif decompress == 'bz2': open_ = bz2.open elif decompress == 'lzma': open_ = lzma.open else: raise ValueError( f'unsupported decompression algorithm {decompress}' ) with open_(path, 'rb') as f: return dill.load(f)
def _to_unit_cl(self, cl: int | float): """Convert cl into unit.""" if cl <= 0: raise ValueError('cl must be positve') elif cl < 1: return cl else: return 1.0 - 2.0 * stats.norm.sf(cl) def _check_fn(self, fn: dict[str, Callable] | None): """Check user provided function.""" if fn is None: return {} else: helper = self._helper fn = {str(k): v for k, v in fn.items()} if fn.keys() & helper.params_default.keys(): raise ValueError( 'names of fn must not overlap with model parameters' ) fn_checked = {} if not isinstance(fn, dict): raise TypeError('fn must be dict of functions') for k, v in fn.items(): msg = ( f"fn['{k}'] must be a function of these parameters: " f'{", ".join(helper.params_names["all"])}' ) if not callable(v): raise TypeError(msg) else: try: jitted = jax.jit(v) fn_result = jitted(helper.params_default) fn_checked[str(k)] = jitted except Exception as e: raise ValueError(msg) from e if np.shape(fn_result) != (): raise ValueError( f"fn['{k}'] must return a scalar value" ) return fn_checked
class MLEResult(FitResult):
    """Result of maximum likelihood fit."""

    _plotter: MLEResultPlotter | None = None

    def __init__(self, minuit: Minuit, helper: Helper):
        super().__init__(helper)

        self._minuit = minuit
        self._mle_unconstr = jnp.array(minuit.values, float)

        mle, covar = helper.get_mle(self._mle_unconstr)

        # The covariance from the helper is kept only when it is
        # symmetric and admits a Cholesky factorization.
        pos_def = False
        if np.allclose(covar, covar.T):
            try:
                np.linalg.cholesky(covar)
            except np.linalg.LinAlgError:
                pass
            else:
                pos_def = True

        # otherwise fall back on Minuit's covariance, when it has one
        if not pos_def and minuit.covariance is not None:
            covar_unconstr = jnp.array(minuit.covariance, float)
            covar = helper.params_covar(self._mle_unconstr, covar_unconstr)

        all_names = helper.params_names['all']
        var2pos = dict(zip(all_names, range(len(mle)), strict=True))
        self._covar = CovarMatrix(var2pos)
        self._covar[:] = covar

        err = jnp.sqrt(jnp.diagonal(covar))

        # MLE of model params in constrained space, as (value, error)
        self._mle = {
            name: (value, error)
            for name, value, error in zip(all_names, mle, err, strict=True)
        }

        # model deviance at MLE
        self._deviance = jax.jit(helper.deviance)(self._mle_unconstr)

        # model values at MLE
        sites = jax.jit(helper.get_sites)(self._mle_unconstr)
        self._model_values = sites['models']

        # model comparison statistics
        k = self._helper.nparam
        n = self._helper.ndata['total']
        stat = self._deviance['total']
        small_sample = 1 + (k + 1) / (n - k - 1)
        self._aic = float(stat + k * 2 * small_sample)
        self._bic = float(stat + k * np.log(n))

        # parametric bootstrap result
        self._boot: BootstrapResult | None = None

    def __repr__(self):
        tabs = self._tabs()
        sections = (
            f'Parameters\n{tabs["params"]}',
            f'Fit Statistics\n{tabs["stat"]}',
            f'Information Criterion\n{tabs["ic"]}',
            f'Fit Status\n{self.status}',
        )
        return '\n\n'.join(sections) + '\n'

    def _repr_html_(self):
        """The repr in Jupyter notebook environment."""
        tabs = self._tabs()
        params_tab = tabs['params'].get_html_string(format=True)
        stat_tab = tabs['stat'].get_html_string(format=True)
        ic_tab = tabs['ic'].get_html_string(format=True)
        status_tab = self.status._repr_html_()

        pieces = [
            '<details open><summary><b>MLE Result</b></summary>',
            '<details open style="padding-left: 1em">',
            f'<summary><b>Parameters</b></summary>{params_tab}</details>',
            '<details open style="padding-left: 1em">',
            f'<summary><b>Fit Statistics</b></summary>{stat_tab}</details>',
            '<details open style="padding-left: 1em">',
            '<summary><b>Information Criterion</b></summary>',
            f'{ic_tab}</details>',
            # the status section is the only one collapsed by default
            '<details style="padding-left: 1em">',
            f'<summary><b>Fit Status</b></summary>{status_tab}</details>',
            '</details>',
        ]
        return ''.join(pieces)

    def _tabs(self):
        """Build the parameter, statistic and information criterion tables."""
        param_rows = []
        for name, (value, error) in self.mle.items():
            param_rows.append((name, f'{value:.4g}', f'{error:.4g}'))
        params_tab = make_pretty_table(
            ['Parameter', 'MLE', 'Error'], param_rows
        )

        stat_type = self._helper.statistic
        deviance = self.deviance
        ndata = self.ndata
        stat_rows = []
        for key in self.ndata.keys():
            if key == 'total':
                continue
            stat_rows.append(
                [key, f'{stat_type[key]}', f'{deviance[key]:.2f}', ndata[key]]
            )
        stat_rows.append(
            [
                'Total',
                'stat/dof',
                f'{deviance["total"]:.2f}/{self.dof}',
                ndata['total'],
            ]
        )
        stat_tab = make_pretty_table(
            ['Data', 'Statistic', 'Value', 'Channels'], stat_rows
        )

        ic_rows = [['AIC', f'{self.aic:.2f}'], ['BIC', f'{self.bic:.2f}']]
        ic_tab = make_pretty_table(['Method', 'Value'], ic_rows)

        return {'params': params_tab, 'stat': stat_tab, 'ic': ic_tab}

    @property
    def plot(self) -> MLEResultPlotter:
        """Result plotter, created on first access."""
        if self._plotter is None:
            self._plotter = MLEResultPlotter(self)
        return self._plotter
def boot(
    self,
    n: int = 10000,
    seed: int | None = None,
    parallel: bool = True,
    n_parallel: int | None = None,
    progress: bool = True,
    update_rate: int = 50,
):
    """Perform parametric bootstrap.

    The result is stored on the instance. If a previous bootstrap was
    run with the same number of simulations and the same seed, it is
    reused and nothing is recomputed.

    Parameters
    ----------
    n : int, optional
        Number of parametric bootstraps based on the MLE. When running
        in parallel, it is rounded up to a multiple of the number of
        parallel processes. The default is 10000.
    seed : int, optional
        The seed of random number generator used in parametric
        bootstrap. When not given, the prediction seed of the helper
        is used.
    parallel : bool, optional
        Whether to run simulation fit in parallel. The default is True.
    n_parallel : int, optional
        Number of parallel processes to use when `parallel` is ``True``.
        Defaults to ``jax.local_device_count()``.
    progress : bool, optional
        Whether to display progress bar. The default is True.
    update_rate : int, optional
        The update rate of progress bar. The default is 50.
    """
    n = int(n)
    n_parallel = get_parallel_number(n_parallel)
    if parallel and (n % n_parallel):
        # round up so the simulations split evenly
        n += n_parallel - n % n_parallel

    helper = self._helper

    # Resolve the seed BEFORE the cache check. The stored result holds
    # the resolved seed, so comparing it with a raw ``None`` would never
    # match and the default call would always recompute.
    seed = helper.seed['pred'] if seed is None else int(seed)

    # reuse the previous result if all setup is the same
    previous = self._boot
    if previous is not None and previous.n == n and previous.seed == seed:
        return

    params = {i: self._mle[i][0] for i in helper.params_names['free']}
    models = self._model_values

    # perform parametric bootstrap
    result = helper.simulate_and_fit(
        seed,
        params,
        models,
        n,
        parallel,
        n_parallel,
        progress,
        update_rate,
        'Bootstrap',
    )
    # keep only the simulations flagged as valid
    valid = result.pop('valid')
    result = jax.tree.map(lambda x: x[valid], result)

    self._boot = BootstrapResult(
        mle={k: v[0] for k, v in self.mle.items()},
        data=result['data'],
        models=result['models'],
        params=result['params'],
        deviance=result['deviance'],
        # fraction of simulated deviances at least as large as observed
        p_value=jax.tree.map(
            lambda obs, sim: np.sum(sim >= obs, axis=0) / len(sim),
            self._deviance,
            result['deviance'],
        ),
        n=n,
        n_valid=np.sum(valid),
        seed=seed,
    )
@property def _params_dist(self) -> dict[str, jax.Array] | None: """Bootstrapped parameter distribution.""" boot = self._boot if boot is None: return None n = boot.n_valid - boot.n_valid % jax.local_device_count() return {k: v[:n] for k, v in boot.params.items()}
def covar(
    self,
    params: str | Sequence[str] | None = None,
    fn: dict[str, Callable] | None = None,
    method: Literal['hess', 'boot'] = 'hess',
    parallel: bool = True,
) -> ParamsCovar:
    """Calculate covariance matrix.

    Parameters
    ----------
    params : str or sequence of str, optional
        Parameters to calculate covariance matrix. If not specified,
        calculate for parameters of interest.
    fn : dict
        A dict containing functions to calculate the covariance matrix.
        The keys are the names of the function results, and the values
        are the functions whose input is a dict of model parameters.
    method : {'hess', 'boot'}, optional
        Method used to calculate covariance. Available options are:

            * ``'hess'``: inverse of Hessian matrix from Minuit; when
              `fn` is given, it is propagated through the Jacobian of
              `fn` evaluated at the MLE
            * ``'boot'``: calculate covariance based on bootstrap
              samples, :meth:`MLEResult.boot` must be called before
              using this method.

        The default is ``'hess'``.
    parallel : bool, optional
        Whether to evaluate `fn` in parallel when `method` is ``'boot'``.
        The default is True.

    Returns
    -------
    ParamsCovar
        The covariance matrix.

    Raises
    ------
    ValueError
        If `method` is neither ``'hess'`` nor ``'boot'``.
    RuntimeError
        If `method` is ``'boot'`` and no bootstrap result is available.
    """
    params_mle = {k: v[0] for k, v in self.mle.items()}
    params = check_params(params, self._helper)
    fn = self._check_fn(fn)

    if method == 'hess':
        if not fn:
            covar = np.array(self._covar)
        else:

            @jax.jit
            @jax.jacobian
            @jax.jit
            def jacobian(params_arr):
                # maps the parameter vector to (parameters, fn values)
                params_dic = dict(
                    zip(params_mle.keys(), params_arr, strict=True)
                )
                fn_arr = jnp.array([f(params_dic) for f in fn.values()])
                return jnp.hstack([params_arr, fn_arr])

            jac = np.array(jacobian(jnp.array(list(params_mle.values()))))
            old_covar = np.array(self._covar)
            # first-order propagation: J C J^T
            covar = jac @ old_covar @ jac.T

    elif method == 'boot':
        self._raise_if_no_boot()

        @jax.jit
        @jax.vmap
        @jax.jit
        def eval_fn(params):
            params_arr = jnp.array([params[k] for k in params_mle])
            fn_arr = jnp.array([f(params) for f in fn.values()])
            return jnp.hstack([params_arr, fn_arr])

        if parallel:
            n_parallel = get_parallel_number(self._n_parallel)
            devices = create_device_mesh(
                mesh_shape=(n_parallel,),
                devices=jax.devices()[:n_parallel],
            )
            mesh = Mesh(devices, axis_names=('i',))
            pi = PartitionSpec('i')
            eval_fn = shard_map(
                f=eval_fn,
                mesh=mesh,
                in_specs=(pi,),
                out_specs=pi,
                check_rep=False,
            )

        samples = eval_fn(self._params_dist)
        # np.cov returns a 0-d array for a single variable; force 2-D
        # so the mask indexing below works in that case too.
        covar = np.atleast_2d(np.cov(samples, rowvar=False))

    else:
        raise ValueError("method must be either 'hess' or 'boot'")

    # NOTE(review): the rows selected below follow the order of
    # ``self.mle``, while ``names`` follows the order of ``params``;
    # this presumes check_params returns names in that same order.
    # Confirm against check_params.
    names = tuple(params) + tuple(fn.keys())
    var2pos = dict(zip(names, range(len(names)), strict=True))
    matrix = CovarMatrix(var2pos)
    # keep requested parameters and every fn result
    mask = np.array(
        [p in params for p in params_mle] + [True] * len(fn),
        dtype=bool,
    )
    matrix[:] = np.array(covar)[:, mask][mask]
    return ParamsCovar(names=names, matrix=matrix)
def ci(
    self,
    cl: float | int = 1,
    params: str | Iterable[str] | None = None,
    fn: dict[str, Callable] | None = None,
    method: Literal['profile', 'boot'] = 'profile',
    rtol: float | dict[str, float] = 1e-6,
    parallel: bool = True,
) -> ConfidenceInterval:
    """Calculate confidence intervals.

    Parameters
    ----------
    cl : float or int, optional
        Confidence level for the confidence interval. If 0 < `cl` < 1,
        the value is interpreted as the confidence level. If `cl` >= 1,
        it is interpreted as the number of standard deviations. For
        example, ``cl=1`` produces a 1-sigma or 68.3% confidence
        interval. The default is 1.
    params : str or sequence of str, optional
        Parameters to calculate confidence intervals. If not specified,
        calculate for parameters of interest.
    fn : dict, optional
        A dict containing functions to calculate the confidence
        intervals. The keys are the names of the function results, and
        the values are the functions whose input is a dict of model
        parameters.
    method : {'profile', 'boot'}, optional
        Method for calculating confidence intervals. Available options
        are:

            * ``'profile'``: use Minos algorithm of Minuit to find the
              confidence intervals based on the profile likelihood
            * ``'boot'``: use parametric bootstrap method to calculate
              the confidence intervals. :meth:`MLEResult.boot` must be
              called before using this method.

        The default is ``'profile'``.
    rtol : float, or dict of float, optional
        The relative tolerance in determining the value of composite
        parameters and `fn` when `method` is ``'profile'``. Entries
        missing from a dict default to 1e-6. The default is 1e-6.
    parallel : bool, optional
        Whether to evaluate `fn` in parallel when `method` is ``'boot'``.
        The default is True.

    Returns
    -------
    ConfidenceInterval
        The confidence intervals.

    Raises
    ------
    ValueError
        If any `rtol` exceeds 0.01, or `method` is not supported.
    """
    # From here on ``cl`` is a probability, never a number of sigmas.
    cl = self._to_unit_cl(cl)

    helper = self._helper
    params = check_params(params, helper)
    params_set = set(params)
    free = params_set.intersection(helper.params_names['free'])
    composite = params_set.intersection(
        helper.params_names['deterministic']
    )
    assert free | composite == params_set

    fn = self._check_fn(fn)

    # a tolerance is needed for everything profiled through _ci_fn
    rtol_keys = tuple(fn.keys()) + tuple(composite)
    if isinstance(rtol, float):
        rtol = dict.fromkeys(rtol_keys, rtol)
    else:
        rtol = jax.tree.map(float, dict(rtol))
        for k in rtol_keys:
            rtol.setdefault(k, 1e-6)
    if np.any([i > 0.01 for i in rtol.values()]):
        raise ValueError('rtol must be less than 0.01')

    if method == 'profile':
        self._warn_invalid_fit()

        if not self._minuit.valid:
            intervals, status = self._ci_invalid(params_set | fn.keys())
        else:
            res1 = self._ci_free(free, cl) if free else ({}, {})

            if composite:
                # a composite parameter is profiled as a function that
                # picks it out of the parameter dict
                def factory(k):
                    def _(p):
                        return p[k]

                    return _

                fn_composite = {k: factory(k) for k in composite}
                res2 = self._ci_fn(fn_composite, cl, rtol)
            else:
                res2 = ({}, {})

            res3 = self._ci_fn(fn, cl, rtol) if fn else ({}, {})

            intervals = res1[0] | res2[0] | res3[0]
            status = res1[1] | res2[1] | res3[1]

    elif method == 'boot':
        intervals, status = self._ci_boot(cl, params, fn, parallel)

    else:
        raise ValueError("method must be either 'profile' or 'boot'")

    params_mle = {k: v[0] for k, v in self._mle.items()}
    vars_names = params + list(fn.keys())
    vars_mle = {k: v for k, v in params_mle.items() if k in vars_names}
    vars_mle |= {k: v(params_mle) for k, v in fn.items()}

    # standard errors: stored errors for parameters, Hessian-based
    # propagation for the functions
    params_se = {k: v[1] for k, v in self._mle.items()}
    fn_covar = self.covar(params=(), fn=fn, method='hess')
    vars_se = {k: v for k, v in params_se.items() if k in vars_names}
    vars_se |= {k: np.sqrt(fn_covar.matrix[k, k]) for k in fn}

    errors = {
        k: (intervals[k][0] - vars_mle[k], intervals[k][1] - vars_mle[k])
        for k in vars_names
    }

    return ConfidenceInterval(
        mle=_format_result(vars_mle, vars_names),
        se=_format_result(vars_se, vars_names),
        intervals=_format_result(intervals, vars_names),
        errors=_format_result(errors, vars_names),
        # ``cl`` is already a probability. Converting it a second time
        # misreports the level when it rounds to exactly 1.0.
        cl=cl,
        method=method,
        status=status,
    )
def _ci_invalid(self, names: Iterable[str]):
    """Confidence interval of invalid fit.

    Returns NaN bounds and an all-``False`` status for every name, in
    the same ``(interval, status)`` layout as the other ``_ci_*``
    methods.
    """
    interval = {k: (float('nan'), float('nan')) for k in names}
    status = {
        k: {
            'valid': (False, False),
            'at_limit': (False, False),
            'at_max_fcn': (False, False),
            'new_min': (False, False),
        }
        for k in names
    }
    return interval, status

def _ci_free(self, names: Iterable[str], cl: float | int):
    """Confidence interval of free parameters.

    Runs Minos on the stored Minuit object for `names`, then maps the
    resulting bounds through ``unconstr_dic_to_params_dic`` of the
    helper.

    Returns
    -------
    tuple of dict
        ``(interval, status)``: ``interval[name]`` is ``(lower, upper)``
        and ``status[name]`` holds the Minos flags as (lower, upper)
        pairs.
    """
    self._minuit.minos(*names, cl=cl)
    mle_unconstr = self._minuit.values.to_dict()
    ci_unconstr = self._minuit.merrors
    # values of uninterested free parameters, in unconstrained space
    others = {k: v for k, v in mle_unconstr.items() if k not in names}
    # lower bound
    lower = self._helper.unconstr_dic_to_params_dic(
        {k: mle_unconstr[k] + ci_unconstr[k].lower for k in names} | others
    )
    # upper bound
    upper = self._helper.unconstr_dic_to_params_dic(
        {k: mle_unconstr[k] + ci_unconstr[k].upper for k in names} | others
    )
    interval = {k: (lower[k], upper[k]) for k in names}
    status = {
        k: {
            'valid': (v.lower_valid, v.upper_valid),
            'at_limit': (v.at_lower_limit, v.at_upper_limit),
            'at_max_fcn': (v.at_lower_max_fcn, v.at_upper_max_fcn),
            'new_min': (v.lower_new_min, v.upper_new_min),
        }
        for k, v in ci_unconstr.items()
        if k in names
    }
    return interval, status

def _ci_fn(
    self,
    fn: dict[str, Callable],
    cl: float | int,
    rtol: dict[str, float],
):
    """Confidence intervals of function of free parameters.

    For each function a new Minuit problem is built whose first
    variable stands for the function value (see ``_loss_factory``), and
    Minos is run on that variable.

    Parameters
    ----------
    fn : dict
        Functions of the model parameters, keyed by name.
    cl : float or int
        Confidence level passed to Minos.
    rtol : dict
        Relative tolerance of each function value, keyed like `fn`.

    Returns
    -------
    tuple of dict
        ``(interval, status)`` keyed by the names in `fn`.
    """
    params_mle = {k: v[0] for k, v in self._mle.items()}
    fn_mle = {k: v(params_mle) for k, v in fn.items()}

    def get_minuit(name, mle, r) -> Minuit:
        # minimize the joint loss, starting from the function value at
        # the MLE followed by the values of the stored Minuit object
        loss = self._loss_factory(fn[name], r)
        grad = jax.jit(jax.grad(loss))
        init = np.hstack([mle, self._minuit.values])
        minuit = Minuit(loss, init, grad=grad)
        minuit.strategy = 2
        minuit.migrad()
        return minuit

    def get_minuit_iter_rtol(name, mle) -> Minuit:
        # try the requested rtol first; if the minimum is not accurate,
        # retry with progressively looser tolerances up to rtol_max
        rtol_desired = rtol[name]
        minuit0 = get_minuit(name, mle, rtol_desired)
        if not minuit0.accurate:
            # When profiling the likelihood, deviance difference for
            # 1-sigma confidence interval is 1, thus the deviance
            # varies slow within 1-sigma confidence interval.
            # The rtol should be less than 1% of the variance.
            fn_var = self.covar(params=(), fn={'_': fn[name]}).matrix[0, 0]
            rel_err = np.sqrt(fn_var) / np.abs(mle)
            rtol_max = np.min([0.01, 0.01 * rel_err, 100 * rtol_desired])
            if rtol_desired < rtol_max:
                for r in np.geomspace(rtol_desired, rtol_max, num=15)[1:]:
                    minuit = get_minuit(name, mle, r)
                    if minuit.accurate:
                        return minuit
        # nothing better found: fall back on the first attempt
        return minuit0

    interval = {}
    status = {}
    for name, mle in fn_mle.items():
        minuit = get_minuit_iter_rtol(name, mle)
        # variable 0 of the joint problem is the function value
        minuit.minos(0, cl=cl)
        ci = minuit.merrors[0]
        interval[name] = (mle + ci.lower, mle + ci.upper)
        status[name] = {
            'valid': (ci.lower_valid, ci.upper_valid),
            'at_limit': (ci.at_lower_limit, ci.at_upper_limit),
            'at_max_fcn': (ci.at_lower_max_fcn, ci.at_upper_max_fcn),
            'new_min': (ci.lower_new_min, ci.upper_new_min),
        }
    return interval, status

def _loss_factory(self, fn: Callable, rtol: float):
    """Factory method to create joint loss of params and func of params.

    The returned loss takes a vector whose first entry is a candidate
    function value and whose remaining entries are the free parameters
    in unconstrained space. It is the total deviance plus a penalty
    ``(fn(params) / x[0] - 1) ** 2 / rtol``.

    Parameters
    ----------
    fn : Callable
        Function accepts model parameters and outputs a single value.
    rtol : float
        Relative tolerance of the function value.

    References
    ----------
    .. [1] Eq.24 of https://doi.org/10.1007/s11222-021-10012-y
    .. [2] https://github.com/vemomoto/vemomoto/blob/master/ci_rvm/ci_rvm/ci_rvm.py#L1455
    """
    helper = self._helper
    params_free = helper.params_names['free']

    @jax.jit
    def loss(x: np.ndarray):
        """Joint loss of params and func of params."""
        unconstr_dic = dict(zip(params_free, x[1:], strict=True))
        params = helper.unconstr_dic_to_params_dic(unconstr_dic)
        fn_value = fn(params)
        # relative mismatch between fn(params) and the candidate x[0]
        s1 = fn_value / x[0] - 1.0
        s2 = (s1 * s1) / rtol
        return helper.deviance_total(x[1:]) + s2

    return loss

def _ci_boot(
    self,
    cl: float | int,
    params: Iterable[str],
    fn: dict[str, Callable] | None = None,
    parallel: bool = True,
    params_setting: dict[str, JAXArray] | None = None,
):
    """Bootstrap confidence interval.

    Intervals are the central quantiles of the bootstrap distribution
    of each requested parameter and of each function in `fn`.

    Parameters
    ----------
    cl : float or int
        Confidence level.
    params : iterable of str
        Names of the parameters to report.
    fn : dict, optional
        Functions of the parameters, evaluated on every bootstrap
        sample.
    parallel : bool, optional
        Whether to evaluate `fn` with ``shard_map`` over devices.
    params_setting : dict, optional
        Values that overwrite the bootstrap parameters when evaluating
        `fn`.

    Returns
    -------
    tuple of dict
        ``(interval, status)``; `status` holds the number of samples,
        the seed, a copy of the distributions and `params_setting`.
    """
    self._raise_if_no_boot()
    if params_setting is not None:
        params_setting = dict(params_setting)
    else:
        params_setting = {}
    boot = self._boot
    boot_params = self._params_dist
    nboot = len(list(boot_params.values())[0])
    cl = self._to_unit_cl(cl)
    # central interval: equal tail probability on both sides
    q = (0.5 - 0.5 * cl, 0.5 + 0.5 * cl)
    interval = {
        k: np.quantile(v, q=q).tolist()
        for k, v in boot_params.items()
        if k in params
    }
    status = {
        'nboot': nboot,
        'seed': int(boot.seed),
        'dist': jax.tree.map(
            # get a copy of the distribution
            lambda x: x.copy(), jax.device_get(boot_params)
        ),
        'params_setting': params_setting,
    }
    if fn is not None and fn:
        # evaluate every function on every bootstrap sample
        eval_fn = jax.jit(
            jax.vmap(lambda p: jax.tree.map(lambda f: f(p), fn))
        )
        if parallel:
            n_parallel = get_parallel_number(self._n_parallel)
            devices = create_device_mesh(
                mesh_shape=(n_parallel,),
                devices=jax.devices()[:n_parallel],
            )
            mesh = Mesh(devices, axis_names=('i',))
            pi = PartitionSpec('i')
            eval_fn = shard_map(
                f=eval_fn,
                mesh=mesh,
                in_specs=(pi,),
                out_specs=pi,
                check_rep=False,
            )
        if params_setting:
            # broadcast each overriding value to one entry per sample
            params_setting = {
                k: jnp.full(nboot, v) for k, v in params_setting.items()
            }
        fn_values = jax.device_get(eval_fn(boot_params | params_setting))
        status['dist'] |= fn_values
        interval |= {
            k: np.quantile(v, q).tolist() for k, v in fn_values.items()
        }
    return interval, status

def _warn_invalid_fit(self):
    """Emit a warning when the Minuit fit is not valid."""
    if not self._minuit.valid:
        warnings.warn('fit must be valid to calculate confidence interval')

def _raise_if_no_boot(self):
    """Raise ``RuntimeError`` when no bootstrap result is stored."""
    if self._boot is None:
        raise RuntimeError(
            'before using the bootstrap method, '
            'MLEResult.boot(...) must be called'
        )

def _intensity_ci(
    self,
    egrid: JAXArray,
    energy: bool,
    cl: float | int,
    converter: Callable,
    method: Literal['profile', 'boot'],
    comps: bool,
    params: dict[str, JAXArray] | None,
) -> dict[str, Q | float]:
    """Calculate confidence interval of flux.

    Parameters
    ----------
    egrid : array-like
        Energy grid passed to the flux function.
    energy : bool
        Whether the intensity is based on energy flux.
    cl : float or int
        Confidence level.
    converter : callable
        Function to convert the flux into desired intensity.
    method : {'profile', 'boot'}
        Method used to calculate confidence interval.
    comps : bool
        Whether to return the result of each component.
    params : dict, optional
        Parameters dict to overwrite the MLE (and the bootstrap
        parameters); ignored, with a warning, by the profile method.

    Returns
    -------
    dict
        Keys ``'mle'``, ``'se'``, ``'intervals'``, ``'errors'``,
        ``'cl'`` and ``'status'``.

    Raises
    ------
    ValueError
        If `method` is neither ``'profile'`` nor ``'boot'``.
    """
    cl = self._to_unit_cl(cl)
    mle_params = {k: v[0] for k, v in self._mle.items()}
    if params is not None:
        mle_params |= params
    mle_flux = self._flux_fn(egrid, mle_params, energy, comps)
    mapping = self._model_mapping
    fn = jax.jit(lambda p: self._flux_fn(egrid, p, energy, comps))
    # one scalar function per distinct model (and per component when
    # `comps`); the factories bind the keys at definition time
    if comps:

        def factory(d, c):
            def _(p):
                return fn(p)[d][c]

            return _

        fn_dic = {
            f'{d}_{c}': factory(d, c)
            for d in mapping.values()
            for c in mle_flux[d].keys()
        }
    else:

        def factory(d):
            def _(p):
                return fn(p)[d]

            return _

        fn_dic = {d: factory(d) for d in mapping.values()}
    # standard errors come from the Hessian for both methods
    cov = self.covar(params=(), fn=fn_dic, method='hess')
    se = {k: np.sqrt(cov.matrix[k, k]) for k in fn_dic.keys()}
    if method == 'profile':
        if params is not None:
            warnings.warn('params is ignored when using profile method')

        # use log transform to stabilize the profile likelihood
        def transform(f):
            def _(p):
                return jnp.log(f(p))

            return _

        fn_dic = jax.tree.map(transform, fn_dic)
        rtol = dict.fromkeys(fn_dic.keys(), 1e-08)
        intervals, status = self._ci_fn(fn_dic, cl, rtol=rtol)
        # undo the log transform
        intervals = jax.tree.map(jnp.exp, intervals)
    elif method == 'boot':
        intervals, status = self._ci_boot(cl, [], fn_dic, True, params)
    else:
        raise ValueError("method must be either 'profile' or 'boot'")
    if comps:
        # regroup the flat '<data>_<component>' keys into nested dicts
        se = {
            k: {c: se[f'{k}_{c}'] for c in mle_flux[k].keys()}
            for k in mapping.values()
        }
        intervals = {
            k: {c: intervals[f'{k}_{c}'] for c in mle_flux[k].keys()}
            for k in mapping.values()
        }
        if method == 'profile':
            status = {
                k: {c: status[f'{k}_{c}'] for c in mle_flux[k].keys()}
                for k in mapping.values()
            }
        else:
            dist = status['dist']
            status['dist'] = {
                k: {c: dist[f'{k}_{c}'] for c in mle_flux[k].keys()}
                for k in mapping.values()
            }
    if energy:
        unit = u.Unit('erg cm^-2 s^-1')
    else:
        unit = u.Unit('ph cm^-2 s^-1')
    convert = lambda x: (x if x is None else converter(x * unit))
    # expand from distinct models back to every dataset name
    intervals = {k: intervals[v] for k, v in mapping.items()}
    se = {k: se[v] for k, v in mapping.items()}
    errors = jax.tree.map(
        lambda x, y: (y[0] - x, y[1] - x), mle_flux, intervals
    )
    if method == 'profile':
        status = {k: status[v] for k, v in mapping.items()}
    else:
        # NOTE(review): when `comps` is False, status['dist'] still
        # holds the parameter samples copied by _ci_boot, so they are
        # converted with the flux unit too — confirm this is intended.
        status['dist'] = jax.tree.map(convert, status['dist'])
    return {
        'mle': jax.tree.map(convert, mle_flux),
        'se': jax.tree.map(convert, se),
        'intervals': jax.tree.map(convert, intervals),
        'errors': jax.tree.map(convert, errors),
        'cl': cl,
        'status': status,
    }
def flux(
    self,
    emin: float | int,
    emax: float | int,
    energy: bool = True,
    cl: float | int = 1,
    method: Literal['profile', 'boot'] = 'profile',
    ngrid: int = 1000,
    comps: bool = False,
    log: bool = True,
    params: dict[str, float | int] | None = None,
) -> MLEFlux:
    r"""Calculate the flux of model.

    .. warning::
        The flux is obtained by numerical integration over an energy
        grid, and is accurate only if enough numbers of energy grids
        are used.

    Parameters
    ----------
    emin : float or int
        Minimum value of energy range, in units of keV.
    emax : float or int
        Maximum value of energy range, in units of keV.
    energy : bool, optional
        When True, calculate energy flux in units of erg cm⁻² s⁻¹;
        otherwise calculate photon flux in units of ph cm⁻² s⁻¹.
        The default is True.
    cl : float or int, optional
        Confidence level for the confidence interval. If 0 < `cl` < 1,
        the value is interpreted as the confidence level. If `cl` >= 1,
        it is interpreted as the number of standard deviations. For
        example, ``cl=1`` produces a 1-sigma or 68.3% confidence
        interval. The default is 1.
    method : {'profile', 'boot'}, optional
        Method for calculating confidence intervals. Available options
        are:

            * ``'profile'``: use Minos algorithm of Minuit to find the
              confidence intervals based on the profile likelihood
            * ``'boot'``: use parametric bootstrap method to calculate
              the confidence intervals. :meth:`MLEResult.boot` must be
              called before using this method.

        The default is ``'profile'``.
    ngrid : int, optional
        The energy grid number to use in integration. The default is
        1000.

    Other Parameters
    ----------------
    comps : bool, optional
        Whether to return the result of each component. The default is
        False.
    log : bool, optional
        Whether to use logarithmically regular energy grid. The default
        is True.
    params : dict, optional
        Parameters dict to overwrite the fitted parameters.

    Returns
    -------
    MLEFlux
        The flux of the model.
    """
    make_grid = jnp.geomspace if log else jnp.linspace
    egrid = make_grid(emin, emax, ngrid)

    def identity(x):
        # the flux needs no further conversion
        return x

    result = self._intensity_ci(
        egrid, energy, cl, identity, method, comps, params
    )
    return MLEFlux(emin, emax, bool(energy), **result, method=method)
def lumin(
    self,
    emin_rest: float | int,
    emax_rest: float | int,
    z: float | int,
    cl: float | int = 1,
    method: Literal['profile', 'boot'] = 'profile',
    ngrid: int = 1000,
    comps: bool = False,
    log: bool = True,
    params: dict[str, float | int] | None = None,
    cosmo: LambdaCDM = Planck18,
) -> MLELumin:
    """Calculate the luminosity of model.

    .. warning::
        The luminosity is obtained by numerical integration over an
        energy grid, and is accurate only if enough numbers of energy
        grids are used.

    Parameters
    ----------
    emin_rest : float or int
        Minimum value of rest-frame energy range, in units of keV.
    emax_rest : float or int
        Maximum value of rest-frame energy range, in units of keV.
    z : float or int
        Redshift of the source.
    cl : float or int, optional
        Confidence level for the confidence interval. If 0 < `cl` < 1,
        the value is interpreted as the confidence level. If `cl` >= 1,
        it is interpreted as the number of standard deviations. For
        example, ``cl=1`` produces a 1-sigma or 68.3% confidence
        interval. The default is 1.
    method : {'profile', 'boot'}, optional
        Method for calculating confidence intervals. Available options
        are:

            * ``'profile'``: use Minos algorithm of Minuit to find the
              confidence intervals based on the profile likelihood
            * ``'boot'``: use parametric bootstrap method to calculate
              the confidence intervals. :meth:`MLEResult.boot` must be
              called before using this method.

        The default is ``'profile'``.
    ngrid : int, optional
        The energy grid number to use in integration. The default is
        1000.

    Other Parameters
    ----------------
    comps : bool, optional
        Whether to return the result of each component. The default is
        False.
    log : bool, optional
        Whether to use logarithmically regular energy grid. The default
        is True.
    params : dict, optional
        Parameters dict to overwrite the fitted parameters.
    cosmo : LambdaCDM, optional
        Cosmology model used to calculate luminosity. The default is
        Planck18.

    Returns
    -------
    MLELumin
        The luminosity of the model.
    """
    make_grid = jnp.geomspace if log else jnp.linspace
    # shift the rest-frame band to the observer frame
    egrid = make_grid(emin_rest, emax_rest, ngrid) / (1.0 + z)

    factor = 4.0 * np.pi * cosmo.luminosity_distance(z) ** 2

    def to_lumin(x):
        return (x * factor).to('erg s^-1')

    result = self._intensity_ci(
        egrid, True, cl, to_lumin, method, comps, params
    )
    return MLELumin(emin_rest, emax_rest, z, cosmo, **result, method=method)
def eiso(
    self,
    emin_rest: float | int,
    emax_rest: float | int,
    z: float | int,
    duration: float | int,
    cl: float | int = 1,
    method: Literal['profile', 'boot'] = 'profile',
    ngrid: int = 1000,
    comps: bool = False,
    log: bool = True,
    params: dict[str, float | int] | None = None,
    cosmo: LambdaCDM = Planck18,
) -> MLEEiso:
    r"""Calculate the isotropic emission energy of model.

    .. warning::
        The :math:`E_\mathrm{iso}` is obtained by numerical integration
        over an energy grid, and is accurate only if enough numbers of
        energy grids are used.

    Parameters
    ----------
    emin_rest : float or int
        Minimum value of rest-frame energy range, in units of keV.
    emax_rest : float or int
        Maximum value of rest-frame energy range, in units of keV.
    z : float or int
        Redshift of the source.
    duration : float or int
        Observed duration of the source, in units of seconds.
    cl : float or int, optional
        Confidence level for the confidence interval. If 0 < `cl` < 1,
        the value is interpreted as the confidence level. If `cl` >= 1,
        it is interpreted as the number of standard deviations. For
        example, ``cl=1`` produces a 1-sigma or 68.3% confidence
        interval. The default is 1.
    method : {'profile', 'boot'}, optional
        Method for calculating confidence intervals. Available options
        are:

            * ``'profile'``: use Minos algorithm of Minuit to find the
              confidence intervals based on the profile likelihood
            * ``'boot'``: use parametric bootstrap method to calculate
              the confidence intervals. :meth:`MLEResult.boot` must be
              called before using this method.

        The default is ``'profile'``.
    ngrid : int, optional
        The energy grid number to use in integration. The default is
        1000.

    Other Parameters
    ----------------
    comps : bool, optional
        Whether to return the result of each component. The default is
        False.
    log : bool, optional
        Whether to use logarithmically regular energy grid. The default
        is True.
    params : dict, optional
        Parameters dict to overwrite the fitted parameters.
    cosmo : LambdaCDM, optional
        Cosmology model used to calculate luminosity. The default is
        Planck18.

    Returns
    -------
    MLEEiso
        The isotropic emission energy of the model.
    """
    make_grid = jnp.geomspace if log else jnp.linspace
    # shift the rest-frame band to the observer frame
    egrid = make_grid(emin_rest, emax_rest, ngrid) / (1.0 + z)

    # This includes correction for energy redshift and time dilation.
    factor = 4.0 * np.pi * cosmo.luminosity_distance(z) ** 2
    factor *= duration / (1 + z) * u.s

    def to_eiso(x):
        return (x * factor).to('erg')

    result = self._intensity_ci(
        egrid, True, cl, to_eiso, method, comps, params
    )
    return MLEEiso(
        emin_rest, emax_rest, z, duration, cosmo, **result, method=method
    )
@property
def gof(self) -> dict[str, float]:
    """Goodness-of-fit p-values taken from the bootstrap result.

    The per-group p-values are merged with the ``'total'`` p-value and
    returned as plain floats for every key of ``self.ndata``.

    Raises
    ------
    RuntimeError
        If no bootstrap result is stored on the instance.
    """
    boot = self._boot
    if boot is None:
        raise RuntimeError('MLEResult.boot() must be called to assess gof')

    p_value = boot.p_value
    merged = p_value['group'] | {'total': p_value['total']}
    return {name: float(merged[name]) for name in self.ndata.keys()}

@property
def mle(self) -> dict[str, tuple[float, float]]:
    """MLE and error of parameters."""
    all_names = self._helper.params_names['all']
    return _format_result(self._mle, all_names)

@property
def deviance(self) -> dict[str, float]:
    """Deviance of the model at MLE."""
    per_group = self._deviance['group']
    merged = per_group | {'total': self._deviance['total']}
    # Report the data sets in helper order, followed by the total.
    ordered_names = (*self._helper.data_names, 'total')
    return {name: float(merged[name]) for name in ordered_names}

@property
def aic(self) -> float:
    """Akaike information criterion with sample size correction."""
    return self._aic

@property
def bic(self) -> float:
    """Bayesian information criterion."""
    return self._bic

@property
def status(self) -> FMin:
    """Fit status of Minuit."""
    return self._minuit.fmin
class PosteriorResult(FitResult):
    """Result obtained from Bayesian fit.

    Wraps the ArviZ inference data of a fit and lazily computes (and
    caches on the instance) summary tables, credible intervals, derived
    intensities, information criteria and posterior predictive checks.
    """

    # Lazily filled caches; ``None`` means "not computed yet".
    _plotter: PosteriorResultPlotter | None = None
    _idata: az.InferenceData
    _deviance: dict | None = None
    _mle_result: dict | None = None
    _ppc: PPCResult | None = None
    _psislw_: DataArray | None = None
    _loo: az.stats.stats_utils.ELPDData | None = None
    _waic: az.stats.stats_utils.ELPDData | None = None
    _rhat: dict[str, float] | None = None
    _divergence: int | None = None
    _pit: dict[str, tuple] | None = None
    _params: dict[str, JAXArray] | None = None
    _info_tabs: dict | None = None

    def __init__(
        self,
        helper: Helper,
        idata: InferenceData,
        ml_optimize: Callable,
        sampler_state: Any = None,
    ):
        """Store the fit helper, inference data, optimizer and sampler state.

        Parameters
        ----------
        helper : Helper
            Helper object of the fit.
        idata : InferenceData
            ArviZ inference data of the fit.
        ml_optimize : callable
            Optimizer used to locate the MLE, called as
            ``ml_optimize(init, throw=False)``.
        sampler_state : Any, optional
            The sampler state at the end of the sampling phase.
        """
        super().__init__(helper)
        self._idata = idata
        self._ml_optimize = ml_optimize
        self._sampler_state = sampler_state

    def __repr__(self):
        tabs = self._tabs()
        return (
            f'Parameters\n{tabs["params"]}\n\n'
            f'Fit Statistics\n{tabs["stat"]}\n\n'
            f'Information Criterion\n{tabs["ic"]}\n\n'
            f'Pareto k Diagnostic\n{tabs["k"]}\n'
        )

    def _repr_html_(self):
        """The repr in Jupyter notebook environment."""
        tabs = self._tabs()
        params_tab = tabs['params'].get_html_string(format=True)
        stat_tab = tabs['stat'].get_html_string(format=True)
        ic_tab = tabs['ic'].get_html_string(format=True)
        k_tab = tabs['k'].get_html_string(format=True)
        return (
            '<details open><summary><b>Posterior Result</b></summary>'
            '<details open style="padding-left: 1em">'
            f'<summary><b>Parameters</b></summary>{params_tab}</details>'
            '<details open style="padding-left: 1em">'
            f'<summary><b>Statistics</b></summary>{stat_tab}</details>'
            '<details open style="padding-left: 1em">'
            '<summary><b>Information Criterion</b></summary>'
            f'{ic_tab}</details>'
            '<details open style="padding-left: 1em">'
            f'<summary><b>Pareto k Diagnostic</b></summary>{k_tab}</details>'
            '</details>'
        )

    def _tabs(self):
        """Build (once) the summary tables shown by the reprs.

        Returns
        -------
        dict
            Tables keyed by ``'params'``, ``'stat'``, ``'ic'`` and ``'k'``.
        """
        if self._info_tabs is not None:
            return self._info_tabs

        # Parameter summary table.
        params_name = self._helper.params_names['all']
        params = self.idata['posterior'][params_name]
        mean = params.mean()
        std = params.std(ddof=1)
        median = params.median()
        # 68.3% equal-tailed quantiles, expressed relative to the median.
        ci = params.quantile(0.5 + 0.683 * np.array([-0.5, 0.5])) - median
        ess = self.ess
        rhat = self.rhat
        rows = [
            [
                k,
                f'{mean[k]:.3g}',
                f'{std[k]:.3g}',
                f'{median[k]:.3g}',
                f'[{ci[k][0]:.3g}, {ci[k][1]:.3g}]',
                f'{ess[k]}',
                f'{rhat[k]:.2f}' if not np.isnan(rhat[k]) else 'N/A',
            ]
            for k in params_name
        ]
        names = [
            'Parameter',
            'Mean',
            'StdDev',
            'Median',
            '68.3% Quantile',
            'ESS',
            'Rhat',
        ]
        params_tab = make_pretty_table(names, rows)

        # Fit statistic table, one row per data set plus the total.
        stat_type = self._helper.statistic
        deviance = self.deviance
        rows = [
            [
                i,
                stat_type[i],
                f'{deviance[i]["mean"]:.2f}',
                f'{deviance[i]["median"]:.2f}',
                j,
            ]
            for i, j in self.ndata.items()
            if i != 'total'
        ]
        rows.append(
            [
                'Total',
                'stat/dof',
                f'{deviance["total"]["mean"]:.2f}/{self.dof}',
                f'{deviance["total"]["median"]:.2f}/{self.dof}',
                self.ndata['total'],
            ]
        )
        names = [
            'Data',
            'Statistic',
            'Mean',
            'Median',
            'Channels',
        ]
        stat_tab = make_pretty_table(names, rows)

        # Information criterion table.
        loo = self.loo
        waic = self.waic
        rows = [
            [
                'LOOIC',
                f'{loo.elpd_loo:.2f} ± {loo.se:.2f}',
                f'{loo.p_loo:.2f}',
            ],
            [
                'WAIC',
                f'{waic.elpd_waic:.2f} ± {waic.se:.2f}',
                f'{waic.p_waic:.2f}',
            ],
        ]
        names = ['Method', 'Deviance', 'p']
        ic_tab = make_pretty_table(names, rows)

        # Pareto k diagnostic table.
        good_k = loo.good_k
        ranges = [f'(-Inf, {good_k:.2f}]', f'({good_k:.2f}, 1]', '(1, Inf)']
        flags = ['good', 'bad', 'very bad']
        bins = np.asarray([-np.inf, good_k, 1, np.inf])
        counts, *_ = np.histogram(loo.pareto_k.values, bins)
        pct = [f'{i:.1%}' for i in counts / np.sum(counts)]
        rows = list(zip(ranges, flags, counts, pct, strict=True))
        names = ['Range', 'Flag', 'Count', 'Pct.']
        k_tab = make_pretty_table(names, rows)

        self._info_tabs = {
            'params': params_tab,
            'stat': stat_tab,
            'ic': ic_tab,
            'k': k_tab,
        }
        return self._info_tabs

    @property
    def plot(self) -> PosteriorResultPlotter:
        """Result plotter, created on first access."""
        if self._plotter is None:
            self._plotter = PosteriorResultPlotter(self)
        return self._plotter

    def covar(
        self,
        params: str | Iterable[str] | None = None,
        fn: dict[str, Callable] | None = None,
        parallel: bool = True,
    ) -> ParamsCovar:
        """Calculate the covariance matrix.

        Parameters
        ----------
        params : str or sequence of str, optional
            Parameters to calculate covariance matrix. If not specified,
            the selection is delegated to ``check_params``.
        fn : dict, optional
            A dict containing functions to calculate the covariance matrix.
            The keys are the names of the function results, and the values
            are the functions whose input is a dict of model parameters.
        parallel : bool, optional
            Whether to use parallel computation for `fn`.
            The default is True.

        Returns
        -------
        ParamsCovar
            The covariance matrix. Row/column ``i`` of the matrix
            corresponds to ``names[i]``.

        Raises
        ------
        ValueError
            If a requested parameter has no posterior samples available
            for the covariance calculation.
        """
        params = check_params(params, self._helper)
        fn = self._check_fn(fn)
        params_dist = self._params_dist

        missing = [p for p in params if p not in params_dist]
        if missing:
            raise ValueError(
                f'no posterior samples available for parameters: {missing}'
            )

        @jax.jit
        @jax.vmap
        @jax.jit
        def eval_fn(sample):
            params_arr = jnp.array([sample[k] for k in params_dist])
            fn_arr = jnp.array([f(sample) for f in fn.values()])
            return jnp.hstack([params_arr, fn_arr])

        if parallel:
            n_parallel = get_parallel_number(self._n_parallel)
            devices = create_device_mesh(
                mesh_shape=(n_parallel,),
                devices=jax.devices()[:n_parallel],
            )
            mesh = Mesh(devices, axis_names=('i',))
            pi = PartitionSpec('i')
            eval_fn = shard_map(
                f=eval_fn,
                mesh=mesh,
                in_specs=(pi,),
                out_specs=pi,
                check_rep=False,
            )

        samples = eval_fn(params_dist)
        # np.cov returns a 0-d array for a single variable; keep it 2-D so
        # the row/column selection below always works.
        covar = np.atleast_2d(np.cov(samples, rowvar=False))

        names = tuple(params) + tuple(fn.keys())
        var2pos = dict(zip(names, range(len(names)), strict=True))
        matrix = CovarMatrix(var2pos)

        # Columns of ``samples`` are ordered as ``params_dist`` followed by
        # ``fn``. Select rows/columns by explicit index so that the matrix
        # follows the order of ``names`` even when ``params`` is given in
        # a different order than the stored posterior samples.
        dist_names = list(params_dist)
        n_dist = len(dist_names)
        idx = [dist_names.index(p) for p in params]
        idx += list(range(n_dist, n_dist + len(fn)))
        matrix[:] = covar[np.ix_(idx, idx)]
        return ParamsCovar(names=names, matrix=matrix)

    def ci(
        self,
        cl: float | int = 1,
        params: str | Iterable[str] | None = None,
        fn: dict[str, Callable] | None = None,
        hdi: bool = False,
        parallel: bool = True,
    ) -> CredibleInterval:
        """Calculate credible intervals.

        Parameters
        ----------
        cl : float or int, optional
            The credible level of samples within the credible interval. If
            0 < `cl` < 1, the value is interpreted as the probability mass.
            If `cl` >= 1, it is interpreted as the number of standard
            deviations. For example, ``cl=1`` produces a 1-sigma or 68.3%
            credible interval. The default is 1.
        params : str or sequence of str, optional
            Parameters to calculate credible intervals. If not specified,
            calculate for parameters of interest.
        fn : dict, optional
            A dict containing functions to calculate the credible
            intervals. The keys are the names of the function results, and
            the values are the functions whose input is a dict of model
            parameters.
        hdi : bool, optional
            Whether to return the highest density interval.
            The default is False, which means an equal tailed interval is
            returned.
        parallel : bool, optional
            Whether to use parallel computation for `fn`.
            The default is True.

        Returns
        -------
        CredibleInterval
            The credible interval.
        """
        cl = self._to_unit_cl(cl)
        fn = self._check_fn(fn)
        params = check_params(params, self._helper)

        if hdi:
            median = self.idata['posterior'].median()
            median = {
                k: float(v) for k, v in median.data_vars.items() if k in params
            }
            interval = az.hdi(self.idata, cl, var_names=params)
            interval = {
                k: (float(v[0]), float(v[1]))
                for k, v in interval.data_vars.items()
            }
        else:
            # Index 0 is the median, 1 and 2 are the interval bounds.
            q = [0.5, 0.5 - cl / 2.0, 0.5 + cl / 2.0]
            quantile = self.idata['posterior'].quantile(q)
            quantile = {
                k: v for k, v in quantile.data_vars.items() if k in params
            }
            median = {k: float(v[0]) for k, v in quantile.items()}
            interval = {
                k: (float(v[1]), float(v[2])) for k, v in quantile.items()
            }

        dist = {
            k: v.data
            for k, v in self.idata['posterior'][params].data_vars.items()
        }

        if fn:
            # Only median, interval and samples are reported by
            # CredibleInterval; mean and std of `fn` are not needed here.
            _, _, median_, interval_, dist_ = self._ci_fn(
                fn, cl, hdi, parallel
            )
            median |= median_
            interval |= interval_
            dist |= dist_

        vars_names = tuple(params) + tuple(fn.keys())
        error = {
            k: (interval[k][0] - median[k], interval[k][1] - median[k])
            for k in vars_names
        }

        return CredibleInterval(
            median=_format_result(median, vars_names),
            intervals=_format_result(interval, vars_names),
            errors=_format_result(error, vars_names),
            cl=cl,
            method='HDI' if hdi else 'ETI',
            dist=dist,
        )

    def _ci_fn(
        self,
        fn: dict[str, Callable],
        cl: float | int,
        hdi: bool,
        parallel: bool = True,
        params_setting: dict[str, JAXArray] | None = None,
    ):
        """Evaluate `fn` over the posterior samples and summarize them.

        Parameters
        ----------
        fn : dict
            Functions of a parameters dict; a function may return a
            (nested) dict of values.
        cl : float or int
            Credible level.
        hdi : bool
            Whether to return the highest density interval instead of the
            equal tailed interval.
        parallel : bool, optional
            Whether to evaluate `fn` in parallel. The default is True.
        params_setting : dict, optional
            Parameter values overwriting the posterior samples.

        Returns
        -------
        tuple
            ``(mean, std, median, interval, dist)``, each with the same
            nesting as the output of `fn`; interval leaves are
            ``(lower, upper)`` tuples of floats.
        """
        if params_setting is not None:
            params_setting = dict(params_setting)
        else:
            params_setting = {}

        params = self._params_dist
        cl = self._to_unit_cl(cl)

        eval_fn = jax.jit(jax.vmap(lambda p: jax.tree.map(lambda f: f(p), fn)))

        if parallel:
            n_parallel = get_parallel_number(self._n_parallel)
            devices = create_device_mesh(
                mesh_shape=(n_parallel,),
                devices=jax.devices()[:n_parallel],
            )
            mesh = Mesh(devices, axis_names=('i',))
            pi = PartitionSpec('i')
            eval_fn = shard_map(
                f=eval_fn,
                mesh=mesh,
                in_specs=(pi,),
                out_specs=pi,
                check_rep=False,
            )

        if params_setting:
            # Broadcast each fixed value to the number of samples so it
            # can be mapped over together with the posterior samples.
            n = len(list(params.values())[0])
            params_setting = {
                k: jnp.full(n, v) for k, v in params_setting.items()
            }

        dist = jax.device_get(eval_fn(params | params_setting))

        if hdi:
            median = jax.tree.map(np.median, dist)

            def _hdi(x):
                # `dist` may be nested (e.g. per data set, per component),
                # so the HDI is computed for each leaf array separately.
                low, high = az.hdi(np.asarray(x), cl)
                return float(low), float(high)

            interval = jax.tree.map(_hdi, dist)
        else:
            q = [0.5, 0.5 - cl / 2.0, 0.5 + cl / 2.0]
            quantile = jax.tree.map(lambda x: np.quantile(x, q), dist)
            median = jax.tree.map(lambda x: float(x[0]), quantile)
            interval = jax.tree.map(
                lambda x: (float(x[1]), float(x[2])), quantile
            )

        mean = jax.tree.map(np.mean, dist)
        std = jax.tree.map(np.std, dist)

        return mean, std, median, interval, dist

    def _intensity_ci(
        self,
        egrid: JAXArray,
        energy: bool,
        cl: float | int,
        converter: Callable,
        hdi: bool,
        comps: bool,
        params: dict[str, JAXArray] | None,
    ) -> dict[str, Q | float]:
        """Calculate credible interval of flux.

        Parameters
        ----------
        egrid : array-like
            Energy grid used in the numerical integration.
        energy : bool
            Whether the intensity is based on energy flux.
        cl : float or int
            Credible level.
        converter : callable
            Function to convert the flux into desired intensity.
        hdi : bool
            Whether to return the highest density interval.
        comps : bool
            Whether to return the result of each component.
        params : dict, optional
            Parameters dict to overwrite the posterior parameters.

        Returns
        -------
        dict
            The credible interval of intensity.
        """
        cl = self._to_unit_cl(cl)

        fn = jax.jit(lambda p: self._flux_fn(egrid, p, energy, comps))
        mean, std, median, intervals, dist = self._ci_fn(
            {'intensity': fn}, cl, hdi, True, params
        )
        mean = mean['intensity']
        std = std['intensity']
        median = median['intensity']
        intervals = intervals['intensity']
        dist = dist['intensity']

        if energy:
            unit = u.Unit('erg cm^-2 s^-1')
        else:
            unit = u.Unit('ph cm^-2 s^-1')

        convert = lambda x: (x if x is None else converter(x * unit))
        # `intervals` holds a (lower, upper) tuple where `median` holds a
        # scalar, so `y` below is that tuple.
        errors = jax.tree.map(
            lambda x, y: (y[0] - x, y[1] - x), median, intervals
        )
        return {
            'mean': jax.tree.map(convert, mean),
            'std': jax.tree.map(convert, std),
            'median': jax.tree.map(convert, median),
            'intervals': jax.tree.map(convert, intervals),
            'errors': jax.tree.map(convert, errors),
            'cl': cl,
            'dist': jax.tree.map(convert, dist),
        }

    def flux(
        self,
        emin: float | int,
        emax: float | int,
        cl: float | int = 1,
        energy: bool = True,
        ngrid: int = 1000,
        hdi: bool = False,
        comps: bool = False,
        log: bool = True,
        params: dict[str, float | int] | None = None,
    ) -> PosteriorFlux:
        r"""Calculate the flux of model.

        .. warning::
            The flux is calculated by trapezoidal rule, and is accurate
            only if enough numbers of energy grids are used.

        Parameters
        ----------
        emin : float or int
            Minimum value of energy range, in units of keV.
        emax : float or int
            Maximum value of energy range, in units of keV.
        cl : float or int, optional
            The credible level of samples within the credible interval. If
            0 < `cl` < 1, the value is interpreted as the probability mass.
            If `cl` >= 1, it is interpreted as the number of standard
            deviations. For example, ``cl=1`` produces a 1-sigma or 68.3%
            credible interval. The default is 1.
        energy : bool, optional
            When True, calculate energy flux in units of erg cm⁻² s⁻¹;
            otherwise calculate photon flux in units of ph cm⁻² s⁻¹.
            The default is True.
        ngrid : int, optional
            The energy grid number to use in integration. The default is 1000.

        Other Parameters
        ----------------
        hdi : bool, optional
            Whether to return the highest density interval.
            The default is False, which means an equal tailed interval is
            returned.
        comps : bool, optional
            Whether to return the result of each component.
            The default is False.
        log : bool, optional
            Whether to use logarithmically regular energy grid.
            The default is True.
        params : dict, optional
            Parameters dict to overwrite the posterior parameters.

        Returns
        -------
        PosteriorFlux
            The flux of the model.
        """
        if log:
            egrid = jnp.geomspace(emin, emax, ngrid)
        else:
            egrid = jnp.linspace(emin, emax, ngrid)

        flux = self._intensity_ci(
            egrid, energy, cl, lambda x: x, hdi, comps, params
        )

        return PosteriorFlux(
            emin, emax, bool(energy), **flux, method='HDI' if hdi else 'ETI'
        )

    def lumin(
        self,
        emin_rest: float | int,
        emax_rest: float | int,
        z: float | int,
        cl: float | int = 1,
        ngrid: int = 1000,
        hdi: bool = False,
        comps: bool = False,
        log: bool = True,
        params: dict[str, float | int] | None = None,
        cosmo: LambdaCDM = Planck18,
    ) -> PosteriorLumin:
        """Calculate the luminosity of model.

        .. warning::
            The luminosity is calculated by trapezoidal rule, and is
            accurate only if enough numbers of energy grids are used.

        Parameters
        ----------
        emin_rest : float or int
            Minimum value of rest-frame energy range, in units of keV.
        emax_rest : float or int
            Maximum value of rest-frame energy range, in units of keV.
        z : float or int
            Redshift of the source.
        cl : float or int, optional
            The credible level of samples within the credible interval. If
            0 < `cl` < 1, the value is interpreted as the probability mass.
            If `cl` >= 1, it is interpreted as the number of standard
            deviations. For example, ``cl=1`` produces a 1-sigma or 68.3%
            credible interval. The default is 1.
        ngrid : int, optional
            The energy grid number to use in integration. The default is 1000.

        Other Parameters
        ----------------
        hdi : bool, optional
            Whether to return the highest density interval.
            The default is False, which means an equal tailed interval is
            returned.
        comps : bool, optional
            Whether to return the result of each component.
            The default is False.
        log : bool, optional
            Whether to use logarithmically regular energy grid.
            The default is True.
        params : dict, optional
            Parameters dict to overwrite the posterior parameters.
        cosmo : LambdaCDM, optional
            Cosmology model used to calculate luminosity.
            The default is Planck18.

        Returns
        -------
        PosteriorLumin
            The luminosity of the model.
        """
        # The rest-frame band is shifted into the observer frame.
        if log:
            egrid = jnp.geomspace(emin_rest, emax_rest, ngrid) / (1.0 + z)
        else:
            egrid = jnp.linspace(emin_rest, emax_rest, ngrid) / (1.0 + z)

        z = float(z)
        factor = 4.0 * np.pi * cosmo.luminosity_distance(z) ** 2
        to_lumin = lambda x: (x * factor).to('erg s^-1')

        lumin = self._intensity_ci(
            egrid, True, cl, to_lumin, hdi, comps, params
        )

        return PosteriorLumin(
            emin_rest,
            emax_rest,
            z,
            cosmo,
            **lumin,
            method='HDI' if hdi else 'ETI',
        )

    def eiso(
        self,
        emin_rest: float | int,
        emax_rest: float | int,
        z: float | int,
        duration: float | int,
        cl: float | int = 1,
        ngrid: int = 1000,
        hdi: bool = False,
        comps: bool = False,
        log: bool = True,
        params: dict[str, float | int] | None = None,
        cosmo: LambdaCDM = Planck18,
    ) -> PosteriorEiso:
        r"""Calculate the isotropic emission energy of model.

        .. warning::
            The :math:`E_\mathrm{iso}` is calculated by trapezoidal rule,
            and is accurate only if enough numbers of energy grids are used.

        Parameters
        ----------
        emin_rest : float or int
            Minimum value of rest-frame energy range, in units of keV.
        emax_rest : float or int
            Maximum value of rest-frame energy range, in units of keV.
        z : float or int
            Redshift of the source.
        duration : float or int
            Observed duration of the source, in units of seconds.
        cl : float or int, optional
            The credible level of samples within the credible interval. If
            0 < `cl` < 1, the value is interpreted as the probability mass.
            If `cl` >= 1, it is interpreted as the number of standard
            deviations. For example, ``cl=1`` produces a 1-sigma or 68.3%
            credible interval. The default is 1.
        ngrid : int, optional
            The energy grid number to use in integration. The default is 1000.

        Other Parameters
        ----------------
        hdi : bool, optional
            Whether to return the highest density interval.
            The default is False, which means an equal tailed interval is
            returned.
        comps : bool, optional
            Whether to return the result of each component.
            The default is False.
        log : bool, optional
            Whether to use logarithmically regular energy grid.
            The default is True.
        params : dict, optional
            Parameters dict to overwrite the posterior parameters.
        cosmo : LambdaCDM, optional
            Cosmology model used to calculate luminosity.
            The default is Planck18.

        Returns
        -------
        PosteriorEiso
            The isotropic emission energy of the model.
        """
        # The rest-frame band is shifted into the observer frame.
        if log:
            egrid = jnp.geomspace(emin_rest, emax_rest, ngrid) / (1.0 + z)
        else:
            egrid = jnp.linspace(emin_rest, emax_rest, ngrid) / (1.0 + z)

        # This includes correction for energy redshift and time dilation.
        z = float(z)
        factor = 4.0 * np.pi * cosmo.luminosity_distance(z) ** 2
        factor *= duration / (1 + z) * u.s
        to_eiso = lambda x: (x * factor).to('erg')

        eiso = self._intensity_ci(egrid, True, cl, to_eiso, hdi, comps, params)

        return PosteriorEiso(
            emin_rest,
            emax_rest,
            z,
            duration,
            cosmo,
            **eiso,
            method='HDI' if hdi else 'ETI',
        )

    def ppc(
        self,
        n: int = 10000,
        seed: int | None = None,
        parallel: bool = True,
        n_parallel: int | None = None,
        progress: bool = True,
        update_rate: int = 50,
    ):
        """Perform posterior predictive check.

        The result is stored on the instance. If a previous result with
        the same number of predictions and the same seed exists, it is
        reused and nothing is recomputed.

        Parameters
        ----------
        n : int, optional
            The number of posterior predictions. The default is 10000.
            When `parallel` is True, it is rounded up to a multiple of the
            number of parallel processes.
        seed : int, optional
            The seed of random number generator used in posterior
            predictions. If not given, the prediction seed of the fit
            helper is used.
        parallel : bool, optional
            Whether to run simulation fit in parallel. The default is True.
        n_parallel : int, optional
            Number of parallel processes to use when `parallel` is ``True``.
            Defaults to ``jax.local_device_count()``.
        progress : bool, optional
            Whether to display progress bar. The default is True.
        update_rate : int, optional
            The update rate of progress bar. The default is 50.
        """
        n = int(n)
        n_parallel = get_parallel_number(n_parallel)
        if parallel and (n % n_parallel):
            n += n_parallel - n % n_parallel

        helper = self._helper

        # Resolve the seed BEFORE the cache check: the cached result stores
        # the resolved seed, so comparing against a raw ``None`` would
        # never match and the default call would always recompute.
        seed = helper.seed['pred'] if seed is None else int(seed)

        # reuse the previous result if all setup is the same
        previous = self._ppc
        if previous is not None and previous.n == n and previous.seed == seed:
            return

        free_params = helper.params_names['free']

        # randomly select n samples from posterior
        rng = np.random.default_rng(seed)
        idata = self.idata
        i = rng.integers(0, idata['posterior'].chain.size, n)
        j = rng.integers(0, idata['posterior'].draw.size, n)
        params = {
            k: v.values[i, j]
            for k, v in idata['posterior'][free_params].data_vars.items()
        }
        models = {
            k: v.values[i, j]
            for k, v in helper.get_models(idata['posterior']).items()
        }

        # perform ppc
        result = helper.simulate_and_fit(
            seed,
            params,
            models,
            1,
            parallel,
            n_parallel,
            progress,
            update_rate,
            'PPC',
        )
        # Keep only the simulations flagged as valid.
        valid = result.pop('valid')
        result = jax.tree.map(lambda x: x[valid], result)

        self._ppc = PPCResult(
            params_rep=params,
            models_rep=models,
            data=result['data'],
            params_fit=result['params'],
            models_fit=result['models'],
            deviance=result['deviance'],
            # Fraction of simulated deviances at least as large as the
            # deviance of the observed data at MLE.
            p_value=jax.tree.map(
                lambda obs, sim: np.sum(sim >= obs, axis=0) / len(sim),
                self._mle['deviance'],
                result['deviance'],
            ),
            n=n,
            n_valid=np.sum(valid),
            seed=seed,
        )

    @property
    def gof(self) -> dict[str, float]:
        """Goodness-of-fit p-values from the posterior predictive check.

        Raises
        ------
        RuntimeError
            If :meth:`ppc` has not been called yet.
        """
        if self._ppc is None:
            raise RuntimeError(
                'PosteriorResult.ppc() must be called to assess gof'
            )

        p_value = self._ppc.p_value
        p_value = p_value['group'] | {'total': p_value['total']}
        return {k: float(p_value[k]) for k in self.ndata.keys()}

    @property
    def _mle(self):
        """MLE result."""
        if self._mle_result is None:
            mle_result = {}

            helper = self._helper

            # MLE information of the model: start the optimization from
            # the posterior sample with the highest total log-likelihood.
            free_params = helper.params_names['free']
            mle_idx = self.idata['log_likelihood']['total'].argmax(...)
            init = self.idata['posterior'][free_params].sel(**mle_idx)
            init = {k: v.values for k, v in init.data_vars.items()}
            init = helper.constr_dic_to_unconstr_arr(init)
            mle_unconstr = self._ml_optimize(init, throw=False)[0]

            # MLE of model params in constrained space
            mle, cov = jax.device_get(helper.get_mle(mle_unconstr))
            err = np.sqrt(np.diagonal(cov))
            params_names = helper.params_names['all']
            mle_result['params'] = dict(
                zip(params_names, zip(mle, err, strict=True), strict=True)
            )

            sites = jax.device_get(jax.jit(helper.get_sites)(mle_unconstr))

            # model deviance at MLE
            loglike = sites['loglike']
            # drop unnecessary terms
            loglike.pop('data')
            loglike.pop('channels')
            mle_result['deviance'] = jax.tree.map(lambda x: -2.0 * x, loglike)

            # model values at MLE
            mle_result['models'] = sites['models']

            self._mle_result = mle_result

        return self._mle_result

    @property
    def mle(self) -> dict[str, tuple[float, float]]:
        """MLE parameters."""
        return dict(self._mle['params'])

    @property
    def idata(self) -> az.InferenceData:
        """ArviZ InferenceData."""
        return self._idata

    @property
    def sampler_state(self) -> Any:
        """The sampler state at the end of the sampling phase."""
        return self._sampler_state

    def _compute_stat(
        self, cache_attr: str, stat_fn: Callable
    ) -> dict[str, float]:
        """Compute a per-parameter statistic once and cache it.

        Parameters
        ----------
        cache_attr : str
            Name of the instance attribute used as cache.
        stat_fn : callable
            Function applied to the posterior samples of all parameters.

        Returns
        -------
        dict
            The statistic of each parameter, as floats.
        """
        stat = getattr(self, cache_attr, None)
        if stat is None:
            params_name = self._helper.params_names['all']
            stat = stat_fn(self.idata['posterior'][params_name])
            stat = {k: float(v) for k, v in stat.items()}
            setattr(self, cache_attr, stat)
        return stat

    @property
    def mean(self) -> dict[str, float]:
        """Mean of parameter samples."""
        return self._compute_stat('_mean', lambda x: x.mean())

    @property
    def std(self) -> dict[str, float]:
        """Standard deviation of parameter samples."""
        return self._compute_stat('_std', lambda x: x.std(ddof=1))

    @property
    def median(self) -> dict[str, float]:
        """Median of parameter samples."""
        return self._compute_stat('_median', lambda x: x.median())

    @property
    def _params_dist(self) -> dict[str, JAXArray]:
        """Posterior samples used for further analysis, the sample size is
        truncated to be less than or equal to n_max=10000."""
        if self._params is not None:
            return self._params

        n_max = 10000
        post = self.idata['posterior'][self._helper.params_names['free']]
        n = post.chain.size * post.draw.size
        if n > n_max:
            # Use a multiple of the local device count, then draw random
            # (chain, draw) index pairs.
            n = n_max - n_max % jax.local_device_count()
            rng = np.random.default_rng(self._helper.seed['mcmc'])
            i = rng.integers(0, post.chain.size, n)
            j = rng.integers(0, post.draw.size, n)
            post = {k: v.values[i, j] for k, v in post.items()}
        else:
            # Concatenate the chains into one flat sample array.
            post = {k: np.hstack(v.values) for k, v in post.data_vars.items()}
        self._params = post
        return self._params

    @property
    def deviance(self) -> dict:
        """Mean and median of model deviance."""
        if self._deviance is None:
            stat_keys = {
                i: f'{i}_total' if i != 'total' else i
                for i in self.ndata.keys()
            }
            keys = list(stat_keys.values())
            deviance = -2.0 * self.idata['log_likelihood'][keys]
            deviance_mean = deviance.mean()
            deviance_median = deviance.median()
            self._deviance = {
                k: {
                    'mean': float(deviance_mean[v]),
                    'median': float(deviance_median[v]),
                }
                for k, v in stat_keys.items()
            }
        return self._deviance

    @property
    def reff(self) -> float:
        """Relative MCMC efficiency."""
        return float(self._idata['ess']['reff'])

    @property
    def ess(self) -> dict[str, int]:
        """Effective MCMC sample size."""
        return {
            k: int(v)
            for k, v in self._idata['ess'].data_vars.items()
            if k != 'reff'
        }

    @property
    def rhat(self) -> dict[str, float]:
        """Computes split R-hat over MCMC chains.

        In general, only fully trust the sample if R-hat is less than 1.01.
        In the early workflow, R-hat below 1.1 is often sufficient. See [1]_
        for more information.

        The value is NaN for every parameter when there is a single chain.

        References
        ----------
        .. [1] https://arxiv.org/abs/1903.08008
        """
        if self._rhat is None:
            params_names = self._helper.params_names['all']
            posterior = self.idata['posterior'][params_names]

            if len(posterior['chain']) == 1:
                rhat = {k: float('nan') for k in posterior.data_vars.keys()}
            else:
                rhat = {
                    k: float(v.values)
                    for k, v in az.rhat(posterior).data_vars.items()
                }

            self._rhat = rhat

        return self._rhat

    @property
    def divergence(self) -> int:
        """Number of divergent samples."""
        if self._divergence is None:
            if 'sample_stats' in self.idata:
                n = int(self.idata['sample_stats']['diverging'].sum())
            else:
                n = 0

            self._divergence = n

        return self._divergence

    @property
    def waic(self) -> ELPDData:
        """The widely applicable information criterion (WAIC).

        Estimates the expected log point-wise predictive density (elpd) using
        WAIC. Also calculates the WAIC's standard error and the effective
        number of parameters. See [1]_ and [2]_ for more information.

        References
        ----------
        .. [1] https://arxiv.org/abs/1507.04544
        .. [2] https://arxiv.org/abs/1004.2316
        """
        if self._waic is None:
            self._waic = az.waic(
                self.idata, var_name='channels', scale='deviance'
            )

        return self._waic

    @property
    def loo(self) -> ELPDData:
        """Pareto-smoothed importance sampling leave-one-out cross-validation
        (PSIS-LOO-CV).

        Estimates the expected log point-wise predictive density (elpd) using
        PSIS-LOO-CV. Also calculates LOO's standard error and the effective
        number of parameters. For more information, see [1]_, [2]_ and [3]_.

        References
        ----------
        .. [1] https://avehtari.github.io/modelselection/CV-FAQ.html
        .. [2] https://arxiv.org/abs/1507.04544
        .. [3] https://arxiv.org/abs/1507.02646
        """
        if self._loo is None:
            self._loo = az.loo(
                self.idata,
                var_name='channels',
                reff=self.reff,
                scale='deviance',
            )

        return self._loo

    @property
    def lnZ(self) -> tuple[float, float]:
        """Log model evidence and error."""
        lnZ = float(self._idata['evidence']['lnZ'])
        lnZ_error = float(self._idata['evidence']['lnZ_error'])
        return lnZ, lnZ_error

    @property
    def _psislw(self) -> DataArray:
        """PSIS log weights of the channel log-likelihood, cached."""
        if self._psislw_ is None:
            idata = self.idata
            reff = self.reff
            stack_kwargs = {'__sample__': ('chain', 'draw')}
            log_weights, kss = az.psislw(
                -idata['log_likelihood']['channels'].stack(**stack_kwargs),
                reff,
            )
            self._psislw_ = log_weights
        return self._psislw_

    def _loo_expectation(self, values: DataArray, data: str) -> DataArray:
        """Computes weighted expectations using the PSIS weights.

        Notes
        -----
        The expectations estimated assume that the PSIS approximation is
        working well. A small Pareto k estimate is necessary, but not
        sufficient to give reliable estimates.

        Parameters
        ----------
        values : DataArray
            Values to compute the expectation.
        data : str
            The data name.

        Returns
        -------
        DataArray
            The expectation of the values.
        """
        assert data in self._helper.data_names

        channel = self._helper.channels[f'{data}_channel']
        log_weights = self._psislw.sel(channel=channel)
        log_weights = log_weights.rename({'channel': f'{data}_channel'})

        # Work with log|values| and restore the sign afterwards.
        log_expectation = log_weights + np.log(np.abs(values))
        weighted = np.sign(values) * np.exp(log_expectation)
        return weighted.sum(dim='__sample__')

    @property
    def _loo_pit(self) -> dict[str, tuple]:
        """Leave-one-out probability integral transform."""
        if self._pit is not None:
            return self._pit

        idata = self.idata
        helper = self._helper
        stack_kwargs = {'__sample__': ('chain', 'draw')}
        y_hat = idata['posterior_predictive']['channels'].stack(**stack_kwargs)
        loo_pit = az.loo_pit(
            y=idata['observed_data']['channels'],
            y_hat=y_hat,
            log_weights=self._psislw,
        )
        loo_pit = {
            name: loo_pit.sel(channel=data.channel).values
            for name, data in helper.data.items()
        }

        discrete_stats = {'cstat', 'pstat', 'wstat'}
        data_stats = helper.statistic
        has_discrete = discrete_stats.intersection(data_stats.values())

        if has_discrete:
            # For count-based statistics a second PIT is evaluated at the
            # next smaller observable value of each channel.
            data_minus = {}
            for k, d in helper.data.items():
                unit = 1.0 / (d.channel_width * d.spec_exposure)
                if data_stats[k] in {'cstat', 'pstat'}:
                    data_minus[k] = (d.spec_counts - 1.0) * unit
                elif data_stats[k] == 'wstat':
                    # Get the next small net spectrum values
                    data_minus[k] = (
                        np.maximum(
                            (d.spec_counts - 1.0)
                            - d.back_ratio * d.back_counts,
                            d.spec_counts
                            - d.back_ratio * (d.back_counts + 1.0),
                        )
                        * unit
                    )
                else:  # chi2, pgstat
                    data_minus[k] = d.ce

            y_minus = idata['observed_data']['channels'].copy()
            y_minus.data = np.hstack(list(data_minus.values()))
            loo_pit_minus = az.loo_pit(
                y=y_minus,
                y_hat=y_hat,
                log_weights=self._psislw,
            )
            loo_pit_minus = {
                name: loo_pit_minus.sel(channel=data.channel).values
                for name, data in helper.data.items()
            }
        else:
            loo_pit_minus = loo_pit

        self._pit = {
            name: (loo_pit_minus[name], loo_pit[name])
            for name in loo_pit.keys()
        }
        return self._pit
class ParamsCovar(NamedTuple):
    """Covariance matrix of the model parameters.

    Returned by :meth:`PosteriorResult.covar`. ``matrix`` is built with
    ``names`` as its row/column labels.
    """

    names: tuple[str, ...]
    """Parameter names."""

    matrix: CovarMatrix
    """Covariance matrix."""
class ConfidenceInterval(NamedTuple):
    """Confidence interval result.

    All dict fields are keyed by name; ``intervals`` and ``errors`` hold
    ``(lower, upper)`` pairs.
    """

    mle: dict[str, float]
    """The maximum likelihood estimation."""

    se: dict[str, float]
    """The standard errors of the MLE, calculated from Hessian matrix."""

    intervals: dict[str, tuple[float, float]]
    """The confidence intervals."""

    errors: dict[str, tuple[float, float]]
    """The confidence intervals in error form."""

    cl: float
    """The confidence level."""

    method: str
    """Method used to calculate the confidence interval."""

    status: dict
    """Status of the calculation progress."""
class MLEFlux(NamedTuple):
    """The flux of the MLE model.

    The quantity-valued fields are either one level deep or, per their
    annotations, nested one level further (dict of dicts).
    """

    emin: float
    """Minimum value of energy range."""

    emax: float
    """Maximum value of energy range."""

    energy: bool
    """Whether the flux is in energy flux. False for photon flux."""

    mle: dict[str, Q] | dict[str, dict[str, Q]]
    """The maximum likelihood estimation of flux."""

    se: dict[str, Q] | dict[str, dict[str, Q]]
    """The standard errors of MLE, calculated from Hessian matrix."""

    intervals: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """The confidence intervals of the model flux."""

    errors: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """The confidence intervals of the model flux in error form."""

    cl: float
    """The confidence level."""

    method: str
    """Method used to calculate the confidence interval."""

    status: dict
    """Status of the calculation progress."""
class MLELumin(NamedTuple):
    """The luminosity of the MLE model.

    The quantity-valued fields are either one level deep or, per their
    annotations, nested one level further (dict of dicts).
    """

    emin_rest: float
    """Minimum value of rest-frame energy range."""

    emax_rest: float
    """Maximum value of rest-frame energy range."""

    z: float
    """Redshift of the source."""

    cosmo: LambdaCDM
    """Cosmology model used to calculate luminosity."""

    mle: dict[str, Q] | dict[str, dict[str, Q]]
    """The maximum likelihood estimation of luminosity."""

    se: dict[str, Q] | dict[str, dict[str, Q]]
    """The standard errors of MLE, calculated from Hessian matrix."""

    intervals: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """The confidence intervals of the model luminosity."""

    errors: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """The confidence intervals of the model luminosity in error form."""

    cl: float
    """The confidence level."""

    method: str
    """Method used to calculate the confidence interval."""

    # NOTE(review): annotated ``str`` here but ``dict`` on MLEFlux and
    # ConfidenceInterval for the same field -- confirm the intended type.
    status: str
    """Status of the calculation progress."""
class MLEEiso(NamedTuple):
    """The isotropic emission energy of the MLE model.

    The first five fields echo the arguments given to the ``eiso``
    calculation; the remaining fields hold its result.
    """

    emin_rest: float
    """Minimum value of rest-frame energy range."""

    emax_rest: float
    """Maximum value of rest-frame energy range."""

    z: float
    """Redshift of the source."""

    duration: float
    """Observed duration of the source."""

    cosmo: LambdaCDM
    """Cosmology model used to calculate Eiso."""

    mle: dict[str, Q] | dict[str, dict[str, Q]]
    """The maximum likelihood estimation of Eiso."""

    se: dict[str, Q] | dict[str, dict[str, Q]]
    """The standard errors of MLE, calculated from Hessian matrix."""

    intervals: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """The confidence intervals of the model Eiso."""

    errors: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """The confidence intervals of the model Eiso in error form."""

    cl: float
    """The confidence level."""

    method: str
    """Method used to calculate the confidence interval."""

    # NOTE(review): annotated ``str`` here but ``dict`` on MLEFlux and
    # ConfidenceInterval for the same field -- confirm the intended type.
    status: str
    """Status of the calculation progress."""
class BootstrapResult(NamedTuple):
    """Container for the outcome of a parametric bootstrap."""

    mle: dict
    """Maximum likelihood estimation the bootstrap is based on."""

    data: dict
    """Data simulated from the MLE."""

    models: dict
    """Models of the bootstrap."""

    params: dict
    """Parameters of the bootstrap."""

    deviance: dict
    """Deviance of the bootstrap."""

    p_value: dict
    """:math:`p`-value of the model fitness."""

    n: int
    """Number of bootstraps."""

    n_valid: int
    """Number of valid bootstraps."""

    seed: int
    """Random number generator seed used in the simulation."""
class CredibleInterval(NamedTuple):
    """Container for credible intervals of a posterior distribution."""

    median: dict[str, float]
    """Posterior median, keyed by name."""

    intervals: dict[str, tuple[float, float]]
    """Credible intervals, keyed by name."""

    errors: dict[str, tuple[float, float]]
    """Credible intervals given in error form, keyed by name."""

    cl: float
    """Credible level of the intervals."""

    method: str
    """Highest Density Interval (HDI), or equal tailed interval (ETI)."""

    dist: dict[str, JAXArray]
    """Posterior distribution, keyed by name."""
class PosteriorFlux(NamedTuple):
    """Container for the posterior of the model flux.

    The summary fields and ``dist`` are either a flat mapping
    (name -> value) or a two-level mapping (name -> name -> value), as
    expressed by their annotations.
    """

    emin: float
    """Lower bound of the energy range."""

    emax: float
    """Upper bound of the energy range."""

    energy: bool
    """True if this is an energy flux, False if it is a photon flux."""

    mean: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior mean of the flux."""

    std: dict[str, Q] | dict[str, dict[str, Q]]
    """Standard deviation of the posterior distribution of the flux."""

    median: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior median of the flux."""

    intervals: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """Credible intervals of the flux."""

    errors: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """Credible intervals of the flux, given in error form."""

    cl: float
    """Credible level of the intervals."""

    method: str
    """Highest Density Interval (HDI), or equal tailed interval (ETI)."""

    dist: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior distribution of the flux."""
class PosteriorLumin(NamedTuple):
    """Container for the posterior of the model luminosity.

    The summary fields and ``dist`` are either a flat mapping
    (name -> value) or a two-level mapping (name -> name -> value), as
    expressed by their annotations.
    """

    emin_rest: float
    """Lower bound of the rest-frame energy range."""

    emax_rest: float
    """Upper bound of the rest-frame energy range."""

    z: float
    """Redshift of the source."""

    cosmo: LambdaCDM
    """Cosmology model used to calculate luminosity."""

    mean: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior mean of the luminosity."""

    std: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior standard deviation of the luminosity."""

    median: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior median of the luminosity."""

    intervals: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """Credible intervals of the luminosity."""

    errors: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """Credible intervals of the luminosity, given in error form."""

    cl: float
    """Credible level of the intervals."""

    method: str
    """Highest Density Interval (HDI), or equal tailed interval (ETI)."""

    dist: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior distribution of the luminosity."""
class PosteriorEiso(NamedTuple):
    """Container for the posterior of the isotropic emission energy (Eiso).

    The summary fields and ``dist`` are either a flat mapping
    (name -> value) or a two-level mapping (name -> name -> value), as
    expressed by their annotations.
    """

    emin_rest: float
    """Lower bound of the rest-frame energy range."""

    emax_rest: float
    """Upper bound of the rest-frame energy range."""

    z: float
    """Redshift of the source."""

    duration: float
    """Observed duration of the source."""

    cosmo: LambdaCDM
    """Cosmology model used to calculate Eiso."""

    mean: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior mean of the Eiso."""

    std: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior standard deviation of the Eiso."""

    median: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior median of the Eiso."""

    intervals: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """Credible intervals of the Eiso."""

    errors: dict[str, tuple[Q, Q]] | dict[str, dict[str, tuple[Q, Q]]]
    """Credible intervals of the Eiso, given in error form."""

    cl: float
    """Credible level of the intervals."""

    method: str
    """Highest Density Interval (HDI), or equal tailed interval (ETI)."""

    dist: dict[str, Q] | dict[str, dict[str, Q]]
    """Posterior distribution of the Eiso."""
class PPCResult(NamedTuple):
    """Container for the outcome of a posterior predictive check (PPC)."""

    params_rep: dict
    """Posterior of the free parameters used to perform the ppc."""

    models_rep: dict
    """Values of the models corresponding to `params_rep`."""

    data: dict
    """Posterior predictive data."""

    params_fit: dict
    """Best fit parameters of the posterior predictive data."""

    deviance: dict
    """Deviance of the posterior predictive data and best fit models."""

    models_fit: dict
    """Values of the best fit models of the posterior predictive data."""

    p_value: dict
    """Posterior predictive :math:`p`-value."""

    n: int
    """Number of posterior predictions."""

    n_valid: int
    """Number of valid ppc."""

    seed: int
    """Random number generator seed used in the simulation."""
def _format_result(result: dict, order: Sequence[str]) -> dict: """Sort the result and use float type.""" formatted = jax.tree.map(float, result) return {k: formatted[k] for k in order}