"""Data classes for plotting."""
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import cache, wraps
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
import numpy as np
import scipy.stats as stats
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec
from elisa.infer.likelihood import (
_STATISTIC_BACK_NORMAL,
_STATISTIC_SPEC_NORMAL,
_STATISTIC_WITH_BACK,
)
from elisa.plot.residuals import (
pearson_residuals,
pit_poisson,
pit_poisson_normal,
pit_poisson_poisson,
quantile_residuals_poisson,
)
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Any, Literal
from xarray import DataArray
from elisa.infer.results import (
BootstrapResult,
FitResult,
MLEResult,
PosteriorResult,
PPCResult,
)
from elisa.util.typing import Array, NumPyArray
def _cache_method(bound_method: Callable) -> Callable:
"""Cache instance method."""
return cache(bound_method)
def _cache_method_with_check(
instance: Any, bound_method: Callable, check_fields: Sequence[str]
) -> Callable:
"""Cache instance method with computation dependency check."""
def get_id():
return {field: id(getattr(instance, field)) for field in check_fields}
cached_method = cache(bound_method)
old_id = get_id()
@wraps(bound_method)
def _(*args, **kwargs):
if (new_id := get_id()) != old_id:
cached_method.cache_clear()
old_id.update(new_id)
return cached_method(*args, **kwargs)
return _
def _get_cached_method_decorator(storage: list):
def decorator(method: Callable):
storage.append(method.__name__)
return method
return decorator
def _get_cached_method_with_check_decorator(
storage: list, check_fields: str | Sequence[str]
):
if isinstance(check_fields, str):
check_fields = [check_fields]
else:
check_fields = list(check_fields)
def decorator(method: Callable):
name = method.__name__
storage.append((name, check_fields))
return method
return decorator
class PlotData(ABC):
    """Base class for data used in plotting.

    An instance wraps the data of the single dataset ``name`` from a fit
    ``result`` and exposes what the plotting code needs: data, folded and
    unfolded models, intervals, and residuals.
    """

    # Names of methods that __init__ replaces by per-instance cached versions.
    _cached_method: list[str]
    # (method name, attribute names) pairs; the per-instance cache of such a
    # method is cleared whenever one of the named attributes is replaced.
    _cached_method_with_check: list[tuple[str, list[str]]]
    # Jitted unfolded-model functions keyed by 'ne', 'ene', 'eene' and their
    # '_comps' (per-component) variants.
    _unfolded_model_fn: dict[str, Callable]
    # Lazily computed result of the ``photon_egrid`` property.
    _ph_egrid: NumPyArray | None = None

    def __init__(self, name: str, result: FitResult, seed: int):
        self.name = str(name)
        self.result = result
        self.seed = seed
        self.data = result._helper.data[self.name]
        self.statistic = result._helper.statistic[self.name]

        # Install per-instance caches for the methods registered by the
        # subclass decorators.
        for f in self._cached_method:
            method = getattr(self, f)
            setattr(self, f, _cache_method(method))
        for f, fields in self._cached_method_with_check:
            method = getattr(self, f)
            setattr(self, f, _cache_method_with_check(self, method, fields))

        model = self.result._helper.model[self.name]
        self._unfolded_model_fn = {
            'ne': jax.jit(lambda e, p: model.ne(e, p, comps=False)),
            'ene': jax.jit(lambda e, p: model.ene(e, p, comps=False)),
            'eene': jax.jit(lambda e, p: model.eene(e, p, comps=False)),
            'ne_comps': jax.jit(lambda e, p: model.ne(e, p, comps=True)),
            'ene_comps': jax.jit(lambda e, p: model.ene(e, p, comps=True)),
            'eene_comps': jax.jit(lambda e, p: model.eene(e, p, comps=True)),
        }

    # Read-only views of the corresponding attributes of ``self.data``.

    @property
    def channel(self) -> NumPyArray:
        return self.data.channel

    @property
    def channel_emin(self) -> NumPyArray:
        return self.data.channel_emin

    @property
    def channel_emax(self) -> NumPyArray:
        return self.data.channel_emax

    @property
    def channel_emid(self) -> NumPyArray:
        return self.data.channel_emid

    @property
    def channel_width(self) -> NumPyArray:
        return self.data.channel_width

    @property
    def channel_emean(self) -> NumPyArray:
        return self.data.channel_emean

    @property
    def channel_errors(self) -> NumPyArray:
        return self.data.channel_errors

    @property
    def photon_egrid(self) -> NumPyArray:
        """Photon energy grid restricted to the channel energy range.

        Grid points outside ``[channel_emin[0], channel_emax[-1]]`` are
        dropped. The result is computed once and then reused.
        """
        if self._ph_egrid is None:
            ph_egrid = self.data.photon_egrid
            mask = np.bitwise_and(
                self.channel_emin[0] <= ph_egrid,
                ph_egrid <= self.channel_emax[-1],
            )
            self._ph_egrid = ph_egrid[mask]
        return self._ph_egrid

    @property
    def spec_counts(self) -> Array:
        return self.data.spec_counts

    @property
    def spec_errors(self) -> Array:
        return self.data.spec_errors

    @property
    def back_ratio(self) -> float | Array:
        return self.data.back_ratio

    @property
    def back_counts(self) -> Array | None:
        return self.data.back_counts

    @property
    def back_errors(self) -> Array | None:
        return self.data.back_errors

    @property
    def net_counts(self) -> Array:
        return self.data.net_counts

    @property
    def net_errors(self) -> Array:
        return self.data.net_errors

    @property
    def ndata(self) -> int:
        """Number of channels."""
        return len(self.data.channel)

    @property
    def ce_data(self) -> Array:
        return self.data.ce

    @property
    def ce_errors(self) -> Array:
        return self.data.ce_errors

    @property
    @abstractmethod
    def ce_model(self) -> Array:
        """Point estimate of the folded source model."""

    @abstractmethod
    def ce_model_ci(self, cl: float = 0.683) -> Array | None:
        """Confidence/Credible intervals of the folded source model."""

    @property
    def has_comps(self) -> bool:
        """Whether the model of this dataset has components."""
        return self.result._helper.model[self.name].has_comps

    @property
    def params_dist(self) -> dict[str, Array] | None:
        """Sampled parameter distribution stored on the fit result."""
        return self.result._params_dist

    def _unfolded_model(
        self,
        mtype: Literal['ne', 'ene', 'eene'],
        egrid: Array,
        params: dict,
        comps: bool,
    ) -> Array | dict:
        """Evaluate the unfolded model on ``egrid`` and fetch it to host.

        If the parameter values are batched (non-scalar), the batch is split
        along its leading axis over the local devices with ``shard_map``.
        All parameters are then assumed to share that leading axis.

        Parameters
        ----------
        mtype : {'ne', 'ene', 'eene'}
            Model type to evaluate.
        egrid : array
            Photon energy grid.
        params : dict
            Parameter values, either all scalars or all batched.
        comps : bool
            Whether to evaluate the model components separately.

        Returns
        -------
        array or dict
            Model values; a dict when ``comps`` is True.
        """
        assert mtype in {'ne', 'ene', 'eene'}
        fn = self._unfolded_model_fn[f'{mtype}_comps' if comps else mtype]

        first = next(iter(params.values()), None)
        if first is None or len(np.shape(first)) == 0:
            return jax.device_get(fn(egrid, params))

        ndevice = jax.local_device_count()
        nbatch = np.shape(first)[0]
        # shard_map needs the sharded axis to be divisible by the number of
        # devices: pad by repeating the last sample, trim the output below.
        npad = -nbatch % ndevice
        if npad:
            params = jax.tree_util.tree_map(
                lambda x: jnp.concatenate(
                    [x, jnp.repeat(x[-1:], npad, axis=0)], axis=0
                ),
                params,
            )

        devices = create_device_mesh((ndevice,))
        mesh = Mesh(devices, axis_names=('i',))
        p = PartitionSpec()
        pi = PartitionSpec('i')
        fn = shard_map(
            f=fn,
            mesh=mesh,
            in_specs=(p, pi),
            out_specs=pi,
            check_rep=False,
        )
        out = fn(egrid, params)
        if npad:
            out = jax.tree_util.tree_map(lambda x: x[:nbatch], out)
        return jax.device_get(out)

    @abstractmethod
    def unfolded_model(
        self,
        mtype: Literal['ne', 'ene', 'eene'],
        egrid: Array,
        params: dict,
        comps: bool,
        cl: float | Array | None = None,
    ) -> Array | dict:
        """Point estimate (and optional intervals) of the unfolded model."""

    @abstractmethod
    def pit(self) -> tuple:
        """Probability integral transform."""

    @abstractmethod
    def residuals(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        seed: int | None,
        random_quantile: bool,
        mle: bool,
    ) -> Array | tuple[Array, bool | Array, bool | Array]:
        """Residuals between the data and the fitted models."""

    @abstractmethod
    def residuals_sim(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        seed: int | None,
        random_quantile: bool,
    ) -> Array | None:
        """Residuals bootstrap/ppc samples."""

    @abstractmethod
    def residuals_ci(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        cl: float,
        seed: int | None,
        random_quantile: bool,
        with_sign: bool,
    ) -> Array | None:
        """Confidence/Credible intervals of the residuals."""
# Method registries for MLEPlotData. The decorators only record method names;
# PlotData.__init__ installs the caches. Methods registered "with check" have
# their cache cleared whenever the ``boot`` attribute changes.
_cached_method: list[str] = []
_cached_method_with_check: list[tuple[str, list[str]]] = []
_to_cached_method = _get_cached_method_decorator(_cached_method)
_to_cached_method_with_check = _get_cached_method_with_check_decorator(
    _cached_method_with_check, check_fields='boot'
)
class MLEPlotData(PlotData):
    """Plotting data for a maximum likelihood fit.

    Point estimates are taken at the MLE. Simulated quantities (intervals and
    residual distributions) come from ``self.boot`` when it is available.
    """

    result: MLEResult
    _cached_method = _cached_method
    _cached_method_with_check = _cached_method_with_check

    @property
    def boot(self) -> BootstrapResult | None:
        """Bootstrap result of the fit, or ``None`` if unavailable."""
        return self.result._boot

    @property
    def params_mle(self) -> dict[str, Array]:
        """MLE parameter values (first entry of each ``result._mle`` item)."""
        return {k: v[0] for k, v in self.result._mle.items()}

    def get_model_mle(self, name: str) -> Array:
        """Model values stored under ``name`` at the MLE."""
        return self.result._model_values[name]

    def get_model_boot(self, name: str) -> Array | None:
        """Bootstrap model values under ``name``, or None without bootstrap."""
        boot = self.boot
        return None if boot is None else boot.models[name]

    def get_data_boot(self, name: str) -> Array | None:
        """Bootstrap data under ``name``, or None without bootstrap."""
        boot = self.boot
        return None if boot is None else boot.data[name]

    @property
    def ce_model(self) -> Array:
        return self.get_model_mle(self.name)

    @_to_cached_method_with_check
    def ce_model_ci(self, cl: float = 0.683) -> Array | None:
        """Central ``cl`` quantile interval of the bootstrap folded models.

        Returns None when no bootstrap result is available; otherwise an
        array whose first axis holds the lower and upper bounds.
        """
        if self.boot is None:
            return None
        assert 0.0 < cl < 1.0
        return np.quantile(
            self.get_model_boot(self.name),
            q=0.5 + cl * np.array([-0.5, 0.5]),
            axis=0,
        )

    def unfolded_model(
        self,
        mtype: Literal['ne', 'ene', 'eene'],
        egrid: Array | None,
        params: dict | None,
        comps: bool,
        cl: float | Array | None = None,
    ) -> tuple[Array | dict, Array | dict | None]:
        """Unfolded model at the MLE and optional quantile intervals.

        Parameters
        ----------
        mtype : {'ne', 'ene', 'eene'}
            Model type to evaluate.
        egrid : array_like or None
            Photon energy grid; :attr:`photon_egrid` is used when None.
        params : dict or None
            Parameter values overriding both the MLE and the sampled
            parameter distribution.
        comps : bool
            Whether to evaluate components separately. Ignored when the
            model has no components.
        cl : float or array_like, optional
            Confidence level(s), each in (0, 1). No intervals if None.

        Returns
        -------
        model_mle : array or dict
            Model evaluated at the MLE.
        ci : array, dict, or None
            Quantiles of the models evaluated over :attr:`params_dist`, with
            leading shape ``(len(cl), 2)``. None if ``cl`` is None or no
            parameter distribution is available.
        """
        assert mtype in {'ne', 'ene', 'eene'}
        if cl is not None:
            cl = np.atleast_1d(cl).astype(float)
            assert np.all(0.0 < cl) and np.all(cl < 1.0)
        params = {} if params is None else dict(params)
        comps = comps and self.has_comps
        if egrid is None:
            egrid = self.photon_egrid
        egrid = jnp.asarray(egrid, float)

        params_mle = self.params_mle | params
        model_mle = self._unfolded_model(mtype, egrid, params_mle, comps)

        params_boot = self.params_dist
        if cl is None or params_boot is None:
            return model_mle, None

        if params:
            # Broadcast the fixed overrides to the number of samples.
            n = next(iter(params_boot.values())).size
            params = {k: jnp.full(n, v) for k, v in params.items()}
        params_boot = params_boot | params
        model_boot = self._unfolded_model(mtype, egrid, params_boot, comps)
        q = 0.5 + cl[:, None] * np.array([-0.5, 0.5])
        if comps:
            ci = {k: np.quantile(v, q, axis=0) for k, v in model_boot.items()}
        else:
            ci = np.quantile(model_boot, q, axis=0)
        return model_mle, ci

    @property
    def sign(self) -> dict[str, Array | None]:
        """Sign of the difference between the data and the fitted models."""
        return {'mle': self._sign_mle(), 'boot': self._sign_boot()}

    def _sign(self, rtype: Literal['mle', 'boot']) -> Array | None:
        """Equivalent to ``self.sign[rtype]`` without computing other types."""
        return getattr(self, f'_sign_{rtype}')()

    @_to_cached_method
    def _sign_mle(self) -> Array:
        return np.where(self.ce_data >= self.ce_model, 1.0, -1.0)

    @_to_cached_method_with_check
    def _sign_boot(self) -> Array | None:
        boot = self.get_model_boot(self.name)
        if boot is not None:
            boot = np.where(self.get_data_boot(self.name) >= boot, 1.0, -1.0)
        return boot

    def model(
        self,
        on_off: Literal['on', 'off'],
        mtype: Literal['mle', 'boot'],
    ) -> Array | None:
        """Point estimate or bootstrap models of the on/off measurement.

        Returns None for the 'off' measurement when the statistic does not
        use a background.
        """
        assert on_off in {'on', 'off'}
        assert mtype in {'mle', 'boot'}
        if (on_off == 'off') and (self.statistic not in _STATISTIC_WITH_BACK):
            return None
        name = f'{self.name}_N{on_off}_model'
        return getattr(self, f'get_model_{mtype}')(name)

    def deviance(self, rtype: Literal['mle', 'boot']) -> Array | None:
        """MLE and bootstrap deviance (the ``'point'`` entries).

        Returns None for 'boot' when no bootstrap result is available.

        Raises
        ------
        ValueError
            If ``rtype`` is not 'mle' or 'boot'.
        """
        if rtype == 'mle':
            return self.result._deviance['point'][self.name]
        elif rtype == 'boot':
            if self.boot is not None:
                return self.boot.deviance['point'][self.name]
            else:
                return None
        else:
            raise ValueError(f'unknown deviance type: {rtype}')

    @property
    def _nsim(self) -> int:
        # Simulation count for the Monte Carlo PIT of pgstat/wstat; it also
        # sets the clipping level of their quantile residuals.
        return 10000

    @_to_cached_method
    def pit(self) -> tuple[Array, Array]:
        """Probability integral transform of the data under the MLE model.

        Returns
        -------
        tuple
            ``(pit_minus, pit)``; randomized quantile residuals draw
            uniformly between the two. Both entries are the same array for
            chi2 and background-normal (pgstat) statistics.
        """
        stat = self.statistic
        on_model = self.model('on', 'mle')

        if stat in _STATISTIC_SPEC_NORMAL:  # chi2
            pit = stats.norm.cdf((self.net_counts - on_model) / self.net_errors)
            return pit, pit

        on_data = self.spec_counts
        if stat not in _STATISTIC_WITH_BACK:  # cstat, or pstat
            return pit_poisson(k=on_data, lam=on_model, minus=True)

        off_data = self.back_counts
        off_model = self.model('off', 'mle')
        if stat in _STATISTIC_BACK_NORMAL:  # pgstat
            pit = pit_poisson_normal(
                k=on_data,
                lam=on_model,
                v=off_data,
                mu=off_model,
                sigma=self.back_errors,
                ratio=self.back_ratio,
                seed=self.seed + 1,
                nsim=self._nsim,
            )
            return pit, pit

        # wstat
        return pit_poisson_poisson(
            k1=on_data,
            k2=off_data,
            lam1=on_model,
            lam2=off_model,
            ratio=self.back_ratio,
            minus=True,
            seed=self.seed + 1,
            nsim=self._nsim,
        )

    def residuals(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        seed: int | None = None,
        random_quantile: bool = True,
        mle: bool = True,
    ) -> Array | tuple[Array, bool | Array, bool | Array]:
        """Residuals between the data and the MLE model.

        ``rtype`` selects deviance ('rd'), Pearson ('rp'), or quantile
        ('rq') residuals. ``seed`` (default: ``self.seed``) and
        ``random_quantile`` only affect 'rq'. ``mle`` is accepted for
        interface compatibility; MLE residuals are always returned.

        Raises
        ------
        NotImplementedError
            If ``rtype`` is unknown.
        """
        if rtype == 'rd':
            return self.deviance_residuals_mle()
        elif rtype == 'rp':
            return self.pearson_residuals_mle()
        elif rtype == 'rq':
            seed = self.seed if seed is None else int(seed)
            return self.quantile_residuals_mle(seed, random_quantile)
        else:
            raise NotImplementedError(f'{rtype} residual')

    def residuals_sim(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        seed: int | None = None,
        random_quantile: bool = True,
    ) -> Array | None:
        """Residuals of the bootstrap samples.

        Returns None without bootstrap or for quantile residuals ('rq').
        """
        if self.boot is None or rtype == 'rq':
            return None
        if rtype == 'rd':
            return self.deviance_residuals_boot()
        elif rtype == 'rp':
            return self.pearson_residuals_boot()
        else:
            raise NotImplementedError(f'{rtype} residual')

    def residuals_ci(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        cl: float = 0.683,
        seed: int | None = None,
        random_quantile: bool = True,
        with_sign: bool = False,
    ) -> Array | None:
        """Bootstrap interval of the residuals at level ``cl``.

        With ``with_sign``, the central quantile interval of the signed
        residuals is returned; otherwise a symmetric interval ``[-q, q]``
        where ``q`` is the ``cl`` quantile of the absolute residuals. Returns
        None without bootstrap or for quantile residuals ('rq').
        """
        if self.boot is None or rtype == 'rq':
            return None
        assert 0 < cl < 1
        r = self.residuals_sim(rtype, seed, random_quantile)
        if with_sign:
            return np.quantile(r, q=0.5 + cl * np.array([-0.5, 0.5]), axis=0)
        else:
            q = np.quantile(np.abs(r), q=cl, axis=0)
            # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
            return np.vstack([-q, q])

    @_to_cached_method
    def deviance_residuals_mle(self) -> Array:
        return self._deviance_residuals('mle')

    @_to_cached_method_with_check
    def deviance_residuals_boot(self) -> Array | None:
        return self._deviance_residuals('boot')

    def _deviance_residuals(
        self, rtype: Literal['mle', 'boot']
    ) -> Array | None:
        """Signed square root of the deviance of the given type."""
        if rtype == 'boot' and self.boot is None:
            return None
        # NB: if background is present, then this assumes the background is
        # being profiled out, so that each src & bkg data pair has ~1 dof
        return self._sign(rtype) * np.sqrt(self.deviance(rtype))

    @_to_cached_method
    def pearson_residuals_mle(self) -> Array:
        return self._pearson_residuals('mle')

    @_to_cached_method_with_check
    def pearson_residuals_boot(self) -> Array | None:
        return self._pearson_residuals('boot')

    def _pearson_residuals(
        self, rtype: Literal['mle', 'boot']
    ) -> Array | None:
        """Pearson residuals of the on (and, if present, off) measurement.

        When a background is used, the on and off residuals are combined in
        quadrature and given the sign of the data-model difference.
        """
        if rtype == 'boot' and self.boot is None:
            return None

        stat = self.statistic
        if rtype == 'mle':
            if stat in _STATISTIC_SPEC_NORMAL:
                on_data = self.net_counts
            else:
                on_data = self.spec_counts
        else:
            on_data = self.get_data_boot(f'{self.name}_Non')
        std = self.net_errors if stat in _STATISTIC_SPEC_NORMAL else None
        r = pearson_residuals(on_data, self.model('on', rtype), std)

        if stat in _STATISTIC_WITH_BACK:
            if rtype == 'mle':
                off_data = self.back_counts
            else:
                off_data = self.get_data_boot(f'{self.name}_Noff')
            std = self.back_errors if stat in _STATISTIC_BACK_NORMAL else None
            r_b = pearson_residuals(off_data, self.model('off', rtype), std)
            # NB: this assumes the background is being profiled out,
            # so that each src & bkg data pair has ~1 dof
            r = self._sign(rtype) * np.sqrt(r * r + r_b * r_b)
        return r

    def quantile_residuals_mle(
        self, seed: int, random: bool
    ) -> tuple[Array, Array | bool, Array | bool]:
        """Quantile residuals of the data under the MLE model.

        Parameters
        ----------
        seed : int
            Seed of the uniform draw used when ``random`` is True.
        random : bool
            Whether to randomize the PIT uniformly within
            ``[pit_minus, pit]``.

        Returns
        -------
        r : array
            Quantile residuals.
        lower, upper : bool or array
            False, or boolean masks (pgstat/wstat only) flagging channels
            whose PIT was exactly 1 or 0. Those residuals are clipped to
            ``norm.ppf(1 - 1/nsim)`` or ``norm.ppf(1/nsim)``.
        """
        pit_minus, pit = self.pit()
        if random:
            pit = np.random.default_rng(seed).uniform(pit_minus, pit)
        r = stats.norm.ppf(pit)
        lower = upper = False
        stat = self.statistic
        if stat == 'chi2':
            # Infinite quantiles: fall back to the normalized difference.
            mask = (pit == 0.0) | (pit == 1.0)
            if np.any(mask):
                on_data = self.net_counts[mask]
                on_model = self.model('on', 'mle')[mask]
                error = self.net_errors[mask]
                r[mask] = (on_data - on_model) / error
        elif stat in {'cstat', 'pstat'}:
            mask = (pit == 0.0) | (pit == 1.0)
            if np.any(mask):
                on_data = self.spec_counts[mask]
                on_model = self.model('on', 'mle')[mask]
                r[mask] = quantile_residuals_poisson(
                    on_data,
                    on_model,
                    keep_sign=not random,
                    random=random,
                    seed=seed,
                )
        elif stat in {'pgstat', 'wstat'}:
            # The PIT is a Monte Carlo estimate from _nsim simulations, so a
            # PIT of 0 or 1 only bounds the residual.
            upper_mask = pit == 0.0
            if np.any(upper_mask):
                r[upper_mask] = stats.norm.ppf(1.0 / self._nsim)
                upper = np.full(r.shape, False)
                upper[upper_mask] = True
            lower_mask = pit == 1.0
            if np.any(lower_mask):
                r[lower_mask] = stats.norm.ppf(1.0 - 1.0 / self._nsim)
                lower = np.full(r.shape, False)
                lower[lower_mask] = True
        return r, lower, upper
# Remove the MLE registries and decorators from the module namespace; the
# MLEPlotData class attributes still reference the lists.
del _cached_method, _cached_method_with_check
del _to_cached_method, _to_cached_method_with_check

# Fresh registries for PosteriorPlotData; methods registered "with check"
# have their cache cleared whenever the ``ppc`` attribute changes.
_cached_method: list[str] = []
_cached_method_with_check: list[tuple[str, list[str]]] = []
_to_cached_method = _get_cached_method_decorator(_cached_method)
_to_cached_method_with_check = _get_cached_method_with_check_decorator(
    _cached_method_with_check, check_fields='ppc'
)
class PosteriorPlotData(PlotData):
    """Plotting data for a posterior sampling result.

    Point estimates are posterior medians, LOO expectations, or the MLE.
    Simulated quantities come from the posterior predictive check
    ``self.ppc`` when it is available.
    """

    result: PosteriorResult
    _cached_method = _cached_method
    _cached_method_with_check = _cached_method_with_check

    @property
    def params(self) -> dict[str, Array]:
        """Posterior parameter samples (``result._params_dist``)."""
        return self.result._params_dist

    @property
    def ppc(self) -> PPCResult | None:
        """Posterior predictive check result, or ``None`` if unavailable."""
        return self.result._ppc

    @_to_cached_method
    def get_model_median(self, name: str) -> Array:
        """Posterior median of ``name`` over chains and draws."""
        posterior = self.result.idata['posterior'][name]
        return posterior.median(dim=('chain', 'draw')).values

    @_to_cached_method
    def get_model_loo(self, name: str) -> Array:
        """LOO expectation of the posterior values of ``name``."""
        posterior = self.result.idata['posterior'][name]
        posterior = posterior.stack(__sample__=('chain', 'draw'))
        return self.result._loo_expectation(posterior, self.name).values

    @_to_cached_method
    def get_model_posterior(self, name: str) -> DataArray:
        """Posterior samples of ``name``, chains and draws stacked."""
        posterior = self.result.idata['posterior'][name]
        # return shape (n_samples, n_channel)
        return posterior.stack(__sample__=('chain', 'draw')).T

    def get_model_ppc(self, name: str) -> Array | None:
        """Models fitted to the ppc data, or None without ppc."""
        if self.ppc is None:
            return None
        return self.ppc.models_fit[name]

    def get_model_mle(self, name: str) -> Array | None:
        """MLE model values of ``name``, or None if no MLE is stored."""
        mle = self.result._mle
        if mle is None:
            return None
        return mle['models'][name]

    @property
    def ce_model(self) -> Array:
        return self.get_model_median(self.name)

    @_to_cached_method
    def ce_model_ci(self, cl: float = 0.683) -> Array:
        """Central ``cl`` quantile interval of the posterior folded models."""
        assert 0.0 < cl < 1.0
        return np.quantile(
            self.get_model_posterior(self.name).values,
            q=0.5 + cl * np.array([-0.5, 0.5]),
            axis=0,
        )

    def unfolded_model(
        self,
        mtype: Literal['ne', 'ene', 'eene'],
        egrid: Array | None,
        params: dict | None,
        comps: bool,
        cl: float | Array | None = None,
    ) -> tuple[Array | dict, Array | dict | None]:
        """Posterior median of the unfolded model and optional intervals.

        Parameters
        ----------
        mtype : {'ne', 'ene', 'eene'}
            Model type to evaluate.
        egrid : array_like or None
            Photon energy grid; :attr:`photon_egrid` is used when None.
        params : dict or None
            Parameter values overriding the posterior samples.
        comps : bool
            Whether to evaluate components separately. Ignored when the
            model has no components.
        cl : float or array_like, optional
            Credible level(s), each in (0, 1). No intervals if None.

        Returns
        -------
        model : array or dict
            Median over the posterior samples.
        ci : array, dict, or None
            Quantiles with leading shape ``(len(cl), 2)``, or None if ``cl``
            is None.
        """
        assert mtype in {'ne', 'ene', 'eene'}
        if cl is not None:
            cl = np.atleast_1d(cl).astype(float)
            assert np.all(0.0 < cl) and np.all(cl < 1.0)
        params = {} if params is None else dict(params)
        comps = comps and self.has_comps
        if egrid is None:
            egrid = self.photon_egrid
        egrid = jnp.asarray(egrid, float)

        post_params = self.params
        if params:
            # Broadcast the fixed overrides to the number of samples.
            n = next(iter(post_params.values())).size
            params = {k: jnp.full(n, v) for k, v in params.items()}
        post_params = post_params | params
        models = self._unfolded_model(mtype, egrid, post_params, comps)

        if comps:
            model = {k: np.median(v, axis=0) for k, v in models.items()}
        else:
            model = np.median(models, axis=0)

        # Previously the non-component branch indexed ``cl`` unconditionally
        # and crashed when cl was None.
        if cl is None:
            return model, None

        q = 0.5 + cl[:, None] * np.array([-0.5, 0.5])
        if comps:
            ci = {k: np.quantile(v, q, axis=0) for k, v in models.items()}
        else:
            ci = np.quantile(models, q, axis=0)
        return model, ci

    @property
    def sign(self) -> dict[str, Array | None]:
        """Sign of the difference between the data and the fitted models."""
        return {
            'posterior': self._sign_posterior(),
            'loo': self._sign_loo(),
            'median': self._sign_median(),
            'mle': self._sign_mle(),
            'ppc': self._sign_ppc(),
        }

    def _sign(
        self, rtype: Literal['posterior', 'loo', 'median', 'mle', 'ppc']
    ) -> Array | None:
        """Equivalent to ``self.sign[rtype]`` without computing other types."""
        return getattr(self, f'_sign_{rtype}')()

    @_to_cached_method
    def _sign_posterior(self) -> Array:
        ce_posterior = self.get_model_posterior(self.name)
        return np.where(self.ce_data >= ce_posterior, 1.0, -1.0)

    @_to_cached_method
    def _sign_loo(self) -> Array:
        ce_loo = self.get_model_loo(self.name)
        return np.where(self.ce_data >= ce_loo, 1.0, -1.0)

    @_to_cached_method
    def _sign_median(self) -> Array:
        ce_median = self.get_model_median(self.name)
        return np.where(self.ce_data >= ce_median, 1.0, -1.0)

    @_to_cached_method_with_check
    def _sign_mle(self) -> Array | None:
        if self.ppc is None:
            return None
        ce_mle = self.get_model_mle(self.name)
        if ce_mle is None:  # no MLE stored on the result
            return None
        return np.where(self.ce_data >= ce_mle, 1.0, -1.0)

    @_to_cached_method_with_check
    def _sign_ppc(self) -> Array | None:
        if self.ppc is None:
            return None
        ce_ppc = self.get_model_ppc(self.name)
        return np.where(self.ppc.data[self.name] >= ce_ppc, 1.0, -1.0)

    def model(
        self,
        on_off: Literal['on', 'off'],
        mtype: Literal['posterior', 'loo', 'median', 'mle', 'ppc'],
    ) -> Array | None:
        """Models of the on/off measurement of the given type.

        Returns None for the 'off' measurement when the statistic does not
        use a background.
        """
        assert on_off in {'on', 'off'}
        assert mtype in {'posterior', 'loo', 'median', 'mle', 'ppc'}
        if (on_off == 'off') and (self.statistic not in _STATISTIC_WITH_BACK):
            return None
        name = f'{self.name}_N{on_off}_model'
        return getattr(self, f'get_model_{mtype}')(name)

    def deviance(
        self,
        rtype: Literal['posterior', 'loo', 'mle', 'ppc'],
    ) -> DataArray | None:
        """Posterior, LOO, MLE, and ppc deviance.

        'posterior' is ``-2 * log_likelihood`` per posterior sample, and
        'loo' is its LOO expectation. 'mle' and 'ppc' return None when the
        corresponding result is unavailable.

        Raises
        ------
        ValueError
            If ``rtype`` is unknown.
        """
        if rtype == 'posterior':
            loglike = self.result.idata['log_likelihood'][self.name]
            return -2.0 * loglike.stack(__sample__=('chain', 'draw')).T
        elif rtype == 'loo':
            loglike = self.result.idata['log_likelihood'][self.name]
            deviance = -2.0 * loglike.stack(__sample__=('chain', 'draw')).T
            return self.result._loo_expectation(deviance, self.name)
        elif rtype == 'mle':
            if self.result._mle is not None:
                return self.result._mle['deviance']['point'][self.name]
            return None
        elif rtype == 'ppc':
            if self.ppc is not None:
                return self.ppc.deviance['point'][self.name]
            return None
        else:
            raise ValueError(f'unknown deviance type: {rtype}')

    def pit(self) -> tuple:
        """LOO-PIT pair ``(pit_minus, pit)`` stored on the result."""
        return self.result._loo_pit[self.name]

    def residuals(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        seed: int | None = None,
        random_quantile: bool = True,
        mle: bool = False,
    ) -> Array | tuple[Array, bool | Array, bool | Array]:
        """Residuals between the data and the fitted models.

        'rd' and 'rp' give deviance and Pearson residuals at the MLE when
        ``mle`` is True, otherwise their LOO versions. 'rq' gives quantile
        residuals (``seed`` defaults to ``self.seed``).
        """
        assert rtype in {'rd', 'rp', 'rq'}
        if rtype == 'rq':
            seed = self.seed if seed is None else int(seed)
            return self.quantile_residuals(seed, random_quantile)
        point_type = 'mle' if mle else 'loo'
        rname = 'deviance' if rtype == 'rd' else 'pearson'
        return getattr(self, f'{rname}_residuals_{point_type}')()

    def residuals_sim(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        seed: int | None = None,
        random_quantile: bool = True,
    ) -> Array | None:
        """Residuals of the ppc samples.

        Returns None without ppc or for quantile residuals ('rq').
        """
        if self.ppc is None or rtype == 'rq':
            return None
        if rtype == 'rd':
            return self.deviance_residuals_ppc()
        elif rtype == 'rp':
            return self.pearson_residuals_ppc()
        else:
            raise NotImplementedError(f'{rtype} residual')

    def residuals_ci(
        self,
        rtype: Literal['rd', 'rp', 'rq'],
        cl: float = 0.683,
        seed: int | None = None,
        random_quantile: bool = True,
        with_sign: bool = False,
    ) -> Array | None:
        """Ppc interval of the residuals at level ``cl``.

        With ``with_sign``, the central quantile interval of the signed
        residuals is returned; otherwise a symmetric interval ``[-q, q]``
        where ``q`` is the ``cl`` quantile of the absolute residuals. Returns
        None without ppc or for quantile residuals ('rq').
        """
        if self.ppc is None or rtype == 'rq':
            return None
        assert 0 < cl < 1
        r = self.residuals_sim(rtype, seed, random_quantile)
        if with_sign:
            return np.quantile(r, q=0.5 + cl * np.array([-0.5, 0.5]), axis=0)
        else:
            q = np.quantile(np.abs(r), q=cl, axis=0)
            # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
            return np.vstack([-q, q])

    @_to_cached_method
    def deviance_residuals_loo(self) -> Array:
        return self._deviance_residuals('loo')

    @_to_cached_method
    def deviance_residuals_median(self) -> Array:
        return np.median(self._deviance_residuals('posterior'), axis=0)

    @_to_cached_method_with_check
    def deviance_residuals_mle(self) -> Array:
        return self._deviance_residuals('mle')

    @_to_cached_method_with_check
    def deviance_residuals_ppc(self) -> Array | None:
        if self.ppc is None:
            return None
        return self._deviance_residuals('ppc')

    def _deviance_residuals(
        self, rtype: Literal['loo', 'posterior', 'mle', 'ppc']
    ) -> Array | None:
        """Signed square root of the deviance of the given type."""
        if rtype in ['mle', 'ppc'] and self.ppc is None:
            return None
        # NB: if background is present, then this assumes the background is
        # being profiled out, so that each src & bkg data pair has ~1 dof
        return self._sign(rtype) * np.sqrt(self.deviance(rtype))

    @_to_cached_method
    def pearson_residuals_loo(self) -> Array:
        return self._pearson_residuals('loo')

    @_to_cached_method
    def pearson_residuals_median(self) -> Array:
        return np.median(self._pearson_residuals('posterior'), axis=0)

    @_to_cached_method_with_check
    def pearson_residuals_mle(self) -> Array:
        return self._pearson_residuals('mle')

    @_to_cached_method_with_check
    def pearson_residuals_ppc(self) -> Array | None:
        if self.ppc is None:
            return None
        return self._pearson_residuals('ppc')

    def _pearson_residuals(
        self, rtype: Literal['posterior', 'loo', 'mle', 'ppc']
    ) -> Array | None:
        """Pearson residuals of the on (and, if present, off) measurement.

        With a background, on and off residuals are combined in quadrature
        and signed. For 'loo', the LOO expectation of the absolute residual
        is taken, then re-signed.
        """
        if rtype in ['mle', 'ppc'] and self.ppc is None:
            return None

        stat = self.statistic
        # LOO residuals are built from the full posterior models.
        mtype = 'posterior' if rtype == 'loo' else rtype
        observed = rtype in {'posterior', 'loo', 'mle'}

        if observed:
            if stat in _STATISTIC_SPEC_NORMAL:
                on_data = self.net_counts
            else:
                on_data = self.spec_counts
        else:
            on_data = self.ppc.data[f'{self.name}_Non']
        on_model = self.model('on', mtype)
        std = self.net_errors if stat in _STATISTIC_SPEC_NORMAL else None
        r = pearson_residuals(on_data, on_model, std)

        if stat in _STATISTIC_WITH_BACK:
            if observed:
                off_data = self.back_counts
            else:
                off_data = self.ppc.data[f'{self.name}_Noff']
            off_model = self.model('off', mtype)
            std = self.back_errors if stat in _STATISTIC_BACK_NORMAL else None
            r_b = pearson_residuals(off_data, off_model, std)
            # NB: this assumes the background is being profiled out,
            # so that each src & bkg data pair has ~1 dof
            r = self._sign(rtype) * np.sqrt(r * r + r_b * r_b)

        if rtype == 'loo':
            r = self.result._loo_expectation(np.abs(r), self.name)
            r *= self._sign(rtype)
        return r

    def quantile_residuals(
        self, seed: int, random: bool
    ) -> tuple[Array, Array | bool, Array | bool]:
        """Quantile residuals computed from the LOO-PIT.

        Parameters
        ----------
        seed : int
            Seed of the uniform draw used when ``random`` is True.
        random : bool
            Whether to randomize the PIT uniformly within
            ``[pit_minus, pit]``.

        Returns
        -------
        r : array
            Quantile residuals.
        lower, upper : bool or array
            False, or boolean masks flagging channels whose PIT was exactly
            1 or 0. Those residuals are clipped to ``norm.ppf(1 - 1/nsim)``
            or ``norm.ppf(1/nsim)``.
        """
        pit_minus, pit = self.pit()
        if random:
            pit = np.random.default_rng(seed).uniform(pit_minus, pit)
        r = stats.norm.ppf(pit)
        # Assume the LOO-PIT is estimated from nchain * ndraw posterior
        # predictive samples, which sets the clipping level below.
        nchain = len(self.result.idata['posterior']['chain'])
        ndraw = len(self.result.idata['posterior']['draw'])
        nsim = nchain * ndraw
        lower = upper = False
        upper_mask = pit == 0.0
        if np.any(upper_mask):
            r[upper_mask] = stats.norm.ppf(1.0 / nsim)
            upper = np.full(r.shape, False)
            upper[upper_mask] = True
        lower_mask = pit == 1.0
        if np.any(lower_mask):
            r[lower_mask] = stats.norm.ppf(1.0 - 1.0 / nsim)
            lower = np.full(r.shape, False)
            lower[lower_mask] = True
        return r, lower, upper
# Remove the PosteriorPlotData registries and decorators from the module
# namespace; the class attributes still reference the lists.
del _to_cached_method_with_check, _to_cached_method
del _cached_method_with_check, _cached_method