"""Likelihood functions."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, get_args
import jax
import jax.numpy as jnp
import numpyro
from jax import lax
from jax.experimental.sparse import BCSR
from jax.scipy.special import xlogy
from numpyro.distributions import Normal, Poisson
from numpyro.distributions.util import validate_sample
if TYPE_CHECKING:
from collections.abc import Callable
from elisa.data.base import FixedData
from elisa.util.typing import (
ArrayLike,
JAXArray,
ModelCompiledFn,
ParamNameValMapping,
)
# TODO:
# It should be noted that 'lstat' does not have long run coverage property
# for source estimation, which is probably due to the choice of conjugate
# prior of Poisson background data.
# 'lstat' will be included here with a proper prior at some point.
Statistic = Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat']
_STATISTIC_OPTIONS: frozenset[str] = frozenset(get_args(Statistic))
_STATISTIC_SPEC_NORMAL: frozenset[str] = frozenset({'chi2'})
_STATISTIC_BACK_NORMAL: frozenset[str] = frozenset({'pgstat'})
_STATISTIC_WITH_BACK: frozenset[str] = frozenset({'pgstat', 'wstat'})
[docs]
def pgstat_background(
s: ArrayLike,
n: ArrayLike,
b_est: ArrayLike,
b_err: ArrayLike,
a: ArrayLike,
) -> JAXArray:
"""Optimized background for PG-statistics given estimate of source counts.
.. note::
The optimized background here is always non-negative, which differs
from XSPEC [1]_.
Parameters
----------
s : array_like
Estimate of source counts.
n : array_like
Observed counts (source and background).
b_est : array_like
Estimate of background counts.
b_err : array_like
Uncertainty of background counts.
a : float or array_like
Exposure ratio between source and background observations.
Returns
-------
JAXArray
The profile background.
References
----------
.. [1] `XSPEC Manual Appendix B: Statistics in XSPEC <https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/XSappendixStatistics.html>`__.
"""
variance = b_err * b_err
e = jnp.array(b_est - a * variance)
f = a * variance * n + e * s
c = a * e - s
d = jnp.sqrt(c * c + 4.0 * a * f)
b = jnp.where(
jnp.bitwise_or(jnp.greater_equal(e, 0.0), jnp.greater_equal(f, 0.0)),
jnp.where(jnp.greater(n, 0.0), (c + d) / (2 * a), e),
0.0,
)
return b
[docs]
def wstat_background(
s: ArrayLike,
n_on: ArrayLike,
n_off: ArrayLike,
a: ArrayLike,
) -> JAXArray:
"""Optimized background for W-statistics [1]_ given the estimate of source.
Parameters
----------
s : array_like
Estimate of source counts.
n_on : array_like
Observed source and background counts in "on" observation.
n_off : array_like
Observed background counts in "off" observation.
a : array_like
Exposure ratio between "on" and "off" observations.
Returns
-------
JAXArray
The profile background.
References
----------
.. [1] Wachter, K., Leach, R., & Kellogg, E. (1979). Parameter estimation
in X-ray astronomy using maximum likelihood. ApJ, 230, 274–287.
"""
c = a * (n_on + n_off) - (a + 1) * s
d = jnp.sqrt(c * c + 4 * a * (a + 1) * n_off * s)
b = jnp.where(
jnp.equal(n_on, 0),
n_off / (1 + a),
jnp.where(
jnp.equal(n_off, 0),
jnp.where(
jnp.less_equal(s, a / (a + 1) * n_on),
n_on / (1 + a) - s / a,
0.0,
),
(c + d) / (2 * a * (a + 1)),
),
)
return b
[docs]
class BetterNormal(Normal):
[docs]
@validate_sample
def log_prob(self, value):
value_scaled = (value - self.loc) / self.scale
return -0.5 * value_scaled * value_scaled
[docs]
class BetterPoisson(Poisson):
[docs]
@validate_sample
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
if (
self.is_sparse
and not isinstance(value, jax.core.Tracer)
and jnp.size(value) > 1
):
broadcast = jnp.broadcast_to
shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
rate = broadcast(self.rate, shape).reshape(-1)
nonzero = broadcast(jax.device_get(value) > 0, shape).reshape(-1)
value = broadcast(value, shape).reshape(-1)
sparse_value = value[nonzero]
sparse_rate = rate[nonzero]
tmp = xlogy(sparse_value, sparse_rate)
gof = xlogy(sparse_value, sparse_value) - sparse_value
return jnp.clip(
jnp.asarray(-rate, dtype=jnp.result_type(float))
.at[nonzero]
.add(tmp - gof)
.reshape(shape),
max=0.0,
)
else:
logp = xlogy(value, self.rate) - self.rate
gof = xlogy(value, value) - value
return jnp.clip(logp - gof, max=0.0)
def _get_resp_matrix(data: FixedData) -> JAXArray | BCSR:
if data.response_sparse:
return BCSR.from_scipy_sparse(data.sparse_matrix.T)
else:
return jnp.array(data.response_matrix.T, float)
[docs]
def chi2(
data: FixedData,
model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
"""S^2 statistic, Gaussian likelihood."""
name = str(data.name)
spec = jnp.array(data.net_counts, float)
error = jnp.array(data.net_errors, float)
photon_egrid = jnp.array(data.photon_egrid, float)
channel_width = jnp.array(data.channel_width, float)
resp_matrix = _get_resp_matrix(data)
area_scale = jnp.array(data.area_scale, float)
exposure = jnp.array(data.spec_exposure, float)
def likelihood(
params: ParamNameValMapping,
predictive: bool = False,
) -> None:
"""Gaussian likelihood defined via numpyro primitives."""
unfold = model(photon_egrid, params)
unfold = jnp.clip(unfold, min=1e-300, max=1e300)
source_rate = resp_matrix @ unfold * area_scale
numpyro.deterministic(name, source_rate / channel_width)
source_counts = source_rate * exposure
source_counts = jnp.clip(source_counts, min=1e-30, max=1e15)
spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec)
spec_model = numpyro.deterministic(f'{name}_Non_model', source_counts)
with numpyro.plate(f'{name}_plate', len(spec)):
dist_on = BetterNormal(spec_model, error)
numpyro.sample(
name=f'{name}_Non',
fn=dist_on,
obs=None if predictive else spec_data,
)
# record log likelihood into chains to avoid re-computation
if not predictive:
loglike_on = numpyro.deterministic(
name=f'{name}_Non_loglike', value=dist_on.log_prob(spec_data)
)
numpyro.deterministic(name=f'{name}_loglike', value=loglike_on)
return likelihood
[docs]
def cstat(
data: FixedData,
model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
"""C-statistic, Poisson likelihood."""
name = str(data.name)
spec = jnp.array(data.spec_counts, float)
photon_egrid = jnp.array(data.photon_egrid, float)
channel_width = jnp.array(data.channel_width, float)
resp_matrix = _get_resp_matrix(data)
area_scale = jnp.array(data.area_scale, float)
exposure = jnp.array(data.spec_exposure, float)
def likelihood(
params: ParamNameValMapping,
predictive: bool = False,
) -> None:
"""Poisson likelihood defined via numpyro primitives."""
unfold = model(photon_egrid, params)
unfold = jnp.clip(unfold, min=1e-300, max=1e300)
source_rate = resp_matrix @ unfold * area_scale
numpyro.deterministic(name, source_rate / channel_width)
source_counts = source_rate * exposure
source_counts = jnp.clip(source_counts, min=1e-30, max=1e15)
spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec)
spec_model = numpyro.deterministic(f'{name}_Non_model', source_counts)
with numpyro.plate(f'{name}_plate', len(spec)):
dist_on = BetterPoisson(spec_model)
numpyro.sample(
name=f'{name}_Non',
fn=dist_on,
obs=None if predictive else spec_data,
)
# record log likelihood into chains to avoid re-computation
if not predictive:
loglike_on = numpyro.deterministic(
name=f'{name}_Non_loglike', value=dist_on.log_prob(spec_data)
)
numpyro.deterministic(name=f'{name}_loglike', value=loglike_on)
return likelihood
[docs]
def pstat(
data: FixedData,
model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
"""P-statistic, Poisson likelihood for data with a known background."""
assert data.has_back, 'Data must have background'
name = str(data.name)
spec = jnp.array(data.spec_counts, float)
back = jnp.array(data.back_counts, float)
photon_egrid = jnp.array(data.photon_egrid, float)
channel_width = jnp.array(data.channel_width, float)
resp_matrix = _get_resp_matrix(data)
area_scale = jnp.array(data.area_scale, float)
exposure = jnp.array(data.spec_exposure, float)
back_ratio = jnp.array(data.back_ratio, float)
def likelihood(
params: ParamNameValMapping,
predictive: bool = False,
) -> None:
"""Poisson likelihood defined via numpyro primitives."""
unfold = model(photon_egrid, params)
unfold = jnp.clip(unfold, min=1e-300, max=1e300)
source_rate = resp_matrix @ unfold * area_scale
numpyro.deterministic(name, source_rate / channel_width)
model_counts = source_rate * exposure + back_ratio * back
model_counts = jnp.clip(model_counts, min=1e-30, max=1e15)
spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec)
spec_model = numpyro.deterministic(f'{name}_Non_model', model_counts)
with numpyro.plate(f'{name}_plate', len(spec_data)):
dist_on = BetterPoisson(spec_model)
numpyro.sample(
name=f'{name}_Non',
fn=dist_on,
obs=None if predictive else spec_data,
)
# record log likelihood into chains to avoid re-computation
if not predictive:
loglike_on = numpyro.deterministic(
name=f'{name}_Non_loglike', value=dist_on.log_prob(spec_data)
)
numpyro.deterministic(name=f'{name}_loglike', value=loglike_on)
return likelihood
[docs]
def pgstat(
data: FixedData,
model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
"""PG-statistic, Poisson likelihood for data and profile Gaussian
likelihood for background.
"""
assert data.has_back, 'Data must have background'
name = str(data.name)
spec = jnp.array(data.spec_counts, float)
back = jnp.array(data.back_counts, float)
back_error = jnp.array(data.back_errors, float)
photon_egrid = jnp.array(data.photon_egrid, float)
channel_width = jnp.array(data.channel_width, float)
resp_matrix = _get_resp_matrix(data)
area_scale = jnp.array(data.area_scale, float)
exposure = jnp.array(data.spec_exposure, float)
back_ratio = jnp.array(data.back_ratio, float)
def likelihood(params: ParamNameValMapping, predictive: bool = False):
"""Poisson and Gaussian likelihood defined via numpyro primitives."""
unfold = model(photon_egrid, params)
unfold = jnp.clip(unfold, min=1e-300, max=1e300)
source_rate = resp_matrix @ unfold * area_scale
numpyro.deterministic(name, source_rate / channel_width)
spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec)
back_data = numpyro.primitives.mutable(f'{name}_Noff_data', back)
source_counts = source_rate * exposure
source_counts = jnp.clip(source_counts, min=1e-30, max=1e15)
b = pgstat_background(
source_counts, spec_data, back_data, back_error, back_ratio
)
spec_model = source_counts + back_ratio * b
spec_model = numpyro.deterministic(f'{name}_Non_model', spec_model)
back_model = numpyro.deterministic(f'{name}_Noff_model', b)
with numpyro.plate(f'{name}_plate', len(spec_data)):
dist_on = BetterPoisson(spec_model)
dist_off = BetterNormal(back_model, back_error)
numpyro.sample(
name=f'{name}_Non',
fn=dist_on,
obs=None if predictive else spec_data,
)
numpyro.sample(
name=f'{name}_Noff',
fn=dist_off,
obs=None if predictive else back_data,
)
# record log likelihood into chains to avoid re-computation
if not predictive:
loglike_on = numpyro.deterministic(
name=f'{name}_Non_loglike', value=dist_on.log_prob(spec_data)
)
loglike_off = numpyro.deterministic(
name=f'{name}_Noff_loglike', value=dist_off.log_prob(back_data)
)
numpyro.deterministic(
name=f'{name}_loglike', value=loglike_on + loglike_off
)
return likelihood
[docs]
def wstat(
data: FixedData,
model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
"""W-statistic, i.e. Poisson likelihood for data and profile Poisson
likelihood for background.
"""
assert data.has_back, 'Data must have background'
name = str(data.name)
spec = jnp.array(data.spec_counts, float)
back = jnp.array(data.back_counts, float)
photon_egrid = jnp.array(data.photon_egrid, float)
channel_width = jnp.array(data.channel_width, float)
resp_matrix = _get_resp_matrix(data)
area_scale = jnp.array(data.area_scale, float)
exposure = jnp.array(data.spec_exposure, float)
back_ratio = jnp.array(data.back_ratio, float)
def likelihood(params: ParamNameValMapping, predictive: bool = False):
"""Poisson and Poisson likelihood defined via numpyro primitives."""
unfold = model(photon_egrid, params)
unfold = jnp.clip(unfold, min=1e-300, max=1e300)
source_rate = resp_matrix @ unfold * area_scale
numpyro.deterministic(name, source_rate / channel_width)
spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec)
back_data = numpyro.primitives.mutable(f'{name}_Noff_data', back)
source_counts = source_rate * exposure
source_counts = jnp.clip(source_counts, min=1e-30, max=1e15)
b = wstat_background(source_counts, spec_data, back_data, back_ratio)
model_counts = source_counts + back_ratio * b
spec_model = numpyro.deterministic(f'{name}_Non_model', model_counts)
back_model = numpyro.deterministic(f'{name}_Noff_model', b)
with numpyro.plate(f'{name}_plate', len(spec_data)):
dist_on = BetterPoisson(spec_model)
dist_off = BetterPoisson(back_model)
numpyro.sample(
name=f'{name}_Non',
fn=dist_on,
obs=None if predictive else spec_data,
)
numpyro.sample(
name=f'{name}_Noff',
fn=dist_off,
obs=None if predictive else back_data,
)
# record log likelihood into chains to avoid re-computation
if not predictive:
loglike_on = numpyro.deterministic(
name=f'{name}_Non_loglike', value=dist_on.log_prob(spec_data)
)
loglike_off = numpyro.deterministic(
name=f'{name}_Noff_loglike', value=dist_off.log_prob(back_data)
)
numpyro.deterministic(
name=f'{name}_loglike', value=loglike_on + loglike_off
)
return likelihood