"""Visualize fit and analysis results."""
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import arviz as az
import jax
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from elisa.infer.helper import check_params
from elisa.plot.data import MLEPlotData, PosteriorPlotData
from elisa.plot.misc import plot_corner, plot_trace
from elisa.plot.scale import LinLogScale, get_scale
from elisa.plot.util import get_colors, get_markers
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from matplotlib.pyplot import Axes, Figure
from elisa.infer.results import FitResult, MLEResult, PosteriorResult
from elisa.plot.data import PlotData
from elisa.util.typing import Array, NumPyArray
def _plot_step(
    ax: Axes, x_left: Array, x_right: Array, y: Array, **step_kwargs
) -> None:
    """Draw a binned quantity as step curves, one curve per contiguous run.

    Bins are described by their left and right edges. Wherever the right
    edge of a bin differs from the left edge of the next bin, the curve is
    broken and a separate ``ax.step`` call is issued for the next run, so
    that gaps between bins are not bridged.

    Parameters
    ----------
    ax : Axes
        Object providing a ``step`` method.
    x_left, x_right : array
        Left and right edges of the bins; same length as `y`.
    y : array
        Value in each bin.
    **step_kwargs
        Forwarded to ``ax.step``; ``where`` is always forced to ``'post'``.
    """
    assert len(y) == len(x_left) == len(x_right)
    step_kwargs['where'] = 'post'

    # Index of the first bin of every run after the first one.
    breaks = np.flatnonzero(x_left[1:] != x_right[:-1]) + 1
    bounds = np.concatenate(([0], breaks, [len(y)]))

    for start, stop in zip(bounds[:-1], bounds[1:]):
        run = slice(start, stop)
        y_run = y[run]
        # Close the last bin of the run: append its right edge and repeat
        # its value so the 'post' step spans the whole bin.
        xs = np.append(x_left[run], x_right[run][-1])
        ys = np.append(y_run, y_run[-1])
        ax.step(xs, ys, **step_kwargs)
def _plot_ribbon(
    ax,
    x_left: Array,
    x_right: Array,
    y_ribbons: Sequence[Array],
    **ribbon_kwargs,
) -> None:
    """Fill one or more bands over binned data, split at gaps between bins.

    Parameters
    ----------
    ax
        Object providing a ``fill_between`` method.
    x_left, x_right : array
        Left and right edges of the bins.
    y_ribbons : sequence of array
        Bands to fill. Each entry must have shape ``(2, nbins)``, the first
        row being the lower bound and the second row the upper bound.
    **ribbon_kwargs
        Forwarded to ``ax.fill_between``; ``step`` is always forced to
        ``'post'``.
    """
    bands = [np.asarray(band) for band in y_ribbons]
    band_shape = bands[0].shape
    assert len(band_shape) == 2 and band_shape[0] == 2
    assert band_shape[1] == len(x_left) == len(x_right)
    assert all(band.shape == band_shape for band in bands)
    ribbon_kwargs['step'] = 'post'

    # Index of the first bin of every contiguous run after the first one.
    breaks = np.flatnonzero(x_left[1:] != x_right[:-1]) + 1
    bounds = np.concatenate(([0], breaks, [band_shape[1]]))

    for start, stop in zip(bounds[:-1], bounds[1:]):
        run = slice(start, stop)
        xs = np.append(x_left[run], x_right[run][-1])
        for lower, upper in bands:
            # Repeat the last value so the 'post' fill covers the last bin.
            lo = np.append(lower[run], lower[run][-1])
            hi = np.append(upper[run], upper[run][-1])
            ax.fill_between(xs, lo, hi, **ribbon_kwargs)
def _adjust_log_range(
    ax: Axes,
    axis: Literal['x', 'y'] = 'y',
    octave: int = 4,
) -> None:
    """Limit the displayed range of a log-scaled axis.

    The data interval of the chosen axis is read from ``ax.dataLim``. When
    it spans more than `octave` orders of magnitude (powers of ten), the
    lower limit is raised to ``vmax / 10**octave``. The limits are then set
    with a small margin in log space: 0.05 dex below and 0.1 dex above.

    Parameters
    ----------
    ax : Axes
        The Axes object to adjust.
    axis : {'x', 'y'}, optional
        Which axis to adjust; any value other than ``'y'`` selects the
        x-axis. The default is ``'y'``.
    octave : int, optional
        Maximum number of orders of magnitude to show. It is rounded to the
        nearest integer and must be positive. The default is 4.
    """
    octave = round(octave)
    assert octave > 0

    use_y = axis == 'y'
    if use_y:
        vmin, vmax = ax.dataLim.intervaly
    else:
        vmin, vmax = ax.dataLim.intervalx
    apply_limits = ax.set_ylim if use_y else ax.set_xlim

    # Non-positive lower bounds are clamped to 1e-30 for the span estimate.
    decades = np.log10(vmax / max(1e-30, vmin))
    if decades > int(octave):
        vmin = vmax / 10**octave

    lower = np.power(10, np.log10(vmin) - 0.05)
    upper = np.power(10, np.log10(vmax) + 0.1)
    apply_limits(lower, upper)
# Multiplicative factor applied to F_nu values when they are shown in Jy.
# Numerically this is the Planck constant in keV s (4.135667696923858e-18)
# times 1e23.
_FV_KEV_TO_JY = 413566.7696923858


def _scale_fv_to_jy(values):
    """Multiply model values (and nested containers of them) by the Jy factor.

    Parameters
    ----------
    values : None, array-like, dict, list or tuple
        Quantity to rescale. ``None`` is passed through unchanged. Dicts,
        lists and tuples are rescaled element by element (recursively), so
        a collection of credible-interval bands or a per-component mapping
        can be passed as is. Anything else must support multiplication by
        a float.

    Returns
    -------
    Same container type as the input, with every leaf multiplied by
    ``_FV_KEV_TO_JY``.
    """
    if values is None:
        return None
    if isinstance(values, dict):
        return {key: _scale_fv_to_jy(val) for key, val in values.items()}
    # A plain list/tuple cannot be multiplied by a float, so scale its
    # members individually and keep the container type.
    if isinstance(values, list):
        return [_scale_fv_to_jy(val) for val in values]
    if isinstance(values, tuple):
        return tuple(_scale_fv_to_jy(val) for val in values)
    return values * _FV_KEV_TO_JY
def _get_qq(
    q: NumPyArray,
    detrend: bool,
    cl: float,
    qsim: NumPyArray | None = None,
) -> tuple[NumPyArray, ...]:
    """Get the Q-Q and pointwise confidence/credible interval.

    Parameters
    ----------
    q : ndarray
        Observed residuals; they are sorted internally and the input is not
        modified.
    detrend : bool
        Whether to subtract the theoretical normal quantiles from the
        returned residuals, reference line and interval bounds.
    cl : float
        Level of the pointwise interval.
    qsim : ndarray, optional
        Simulated residuals. They are sorted along axis 1 and the reference
        line and interval are taken as quantiles along axis 0. When absent,
        the reference line is the theoretical quantiles and the interval is
        derived from the Beta distribution of uniform order statistics.

    Returns
    -------
    theor, q, line, lower, upper : ndarray
        Theoretical normal quantiles, sorted residuals, reference line, and
        lower/upper bounds of the pointwise interval.

    References
    ----------
    .. [1] doi:10.1080/00031305.2013.847865
    """
    # https://stats.stackexchange.com/a/9007
    # https://stats.stackexchange.com/a/152834
    alpha = np.pi / 8  # 3/8 is also ok
    n = len(q)
    ranks = np.arange(1, n + 1)
    theor = stats.norm.ppf((ranks - alpha) / (n - 2 * alpha + 1))
    q = np.sort(q)

    if qsim is not None:
        line, lower, upper = np.quantile(
            np.sort(qsim, axis=1),
            q=[0.5, 0.5 - 0.5 * cl, 0.5 + 0.5 * cl],
            axis=0,
        )
    else:
        line = np.array(theor)
        # The i-th of n uniform order statistics follows Beta(i, n + 1 - i);
        # map its quantiles to the normal scale.
        lower = stats.norm.ppf(
            stats.beta.ppf(0.5 - cl * 0.5, ranks, n + 1 - ranks)
        )
        upper = stats.norm.ppf(
            stats.beta.ppf(0.5 + cl * 0.5, ranks, n + 1 - ranks)
        )

    if detrend:
        # Out-of-place subtraction: an in-place ``-=`` would fail when the
        # residuals have an integer dtype.
        q = q - theor
        line = line - theor
        lower = lower - theor
        upper = upper - theor

    return theor, q, line, lower, upper
def _get_pit_ecdf(
    pit: NumPyArray,
    cl: float,
    detrend: bool,
) -> tuple[NumPyArray, ...]:
    """Get the empirical CDF of PIT and pointwise confidence/credible interval.

    Parameters
    ----------
    pit : ndarray
        One-dimensional array of PIT values.
    cl : float
        Level of the pointwise interval.
    detrend : bool
        Whether to subtract the evaluation grid from the ECDF and the
        interval bounds, in which case the reference line is zero.

    Returns
    -------
    scaled_rank, pit_ecdf, line, lower, upper : ndarray
        Evaluation grid of ``n + 1`` evenly spaced points on [0, 1], the
        ECDF evaluated on that grid, the reference line, and the
        lower/upper bounds of the binomial pointwise interval.

    References
    ----------
    .. [1] doi:10.1007/s11222-022-10090-6
    """
    n = len(pit)

    # See ref [1] for the following
    scaled_rank = np.linspace(0.0, 1.0, n + 1)
    # Number of PIT values <= each grid point. Searching the sorted sample
    # gives the same counts as comparing every value against every grid
    # point, without building an (n + 1) x n boolean matrix.
    counts = np.searchsorted(np.sort(pit), scaled_rank, side='right')
    pit_ecdf = counts / n

    lower, upper = stats.binom.interval(cl, n, scaled_rank)
    lower = lower / n
    upper = upper / n

    if detrend:
        line = np.zeros_like(scaled_rank)
        lower -= scaled_rank
        upper -= scaled_rank
        pit_ecdf -= scaled_rank
    else:
        line = scaled_rank

    return scaled_rank, pit_ecdf, line, lower, upper
# x = np.hstack([0.0, np.sort(pit), 1.0])
# pit_ecdf = np.hstack([0.0, np.arange(n) / n, 1.0])
# lower, upper = stats.binom.interval(cl, n, x)
# lower = lower / n
# upper = upper / n
#
# if detrend:
# pit_ecdf -= x
# lower -= x
# upper -= x
# line = np.zeros_like(x)
# else:
# line = x
#
# return x, pit_ecdf, line, lower, upper
# def _get_pit_pdf(pit_intervals: NumPyArray) -> NumPyArray:
# """Get the pdf of PIT.
#
# References
# ----------
# .. [1] doi:10.1111/j.1541-0420.2009.01191.x
# """
# assert len(pit_intervals.shape) == 2 and pit_intervals.shape[1] == 2
#
# grid = np.unique(pit_intervals)
# if grid[0] > 0.0:
# grid = np.insert(grid, 0, 0)
# if grid[-1] < 1.0:
# grid = np.append(grid, 1.0)
#
# n = len(pit_intervals)
# mask = pit_intervals[:, 0] != pit_intervals[:, 1]
# cover_mask = np.bitwise_and(
# pit_intervals[:, :1] <= grid[:-1],
# grid[1:] <= pit_intervals[:, 1:],
# )
# pdf = np.zeros((n, len(grid) - 1))
# pdf[cover_mask] = np.repeat(
# 1.0 / (pit_intervals[mask, 1] - pit_intervals[mask, 0]),
# np.count_nonzero(cover_mask[mask], axis=1),
# )
# idx = np.clip(grid.searchsorted(pit_intervals[~mask, 0]) - 1, 0, None)
# pdf[~mask, idx] = 1.0 / (grid[idx + 1] - grid[idx])
# return pdf.mean(0)
[docs]
class PlotConfig:
    """Plotting configuration.

    Every option is exposed as a validated property; the constructor simply
    assigns each argument through the corresponding setter.
    """

    # y-axis labels keyed by plot type.
    _YLABLES = {
        'ce': r'$C_E\ \mathrm{[counts\ s^{-1}\ keV^{-1}]}$',
        'ne': r'$N_E\ \mathrm{[ph\ cm^{-2}\ s^{-1}\ keV^{-1}]}$',
        'ene': r'$E N_E\ \mathrm{[erg\ cm^{-2}\ s^{-1}\ keV^{-1}]}$',
        'Fv': r'$F_{\nu}\ \mathrm{[erg\ cm^{-2}\ s^{-1}\ keV^{-1}]}$',
        'FvJy': r'$F_{\nu}\ \mathrm{[Jy]}$',
        'eene': r'$E^2 N_E\ \mathrm{[erg\ cm^{-2}\ s^{-1}]}$',
        'vFv': r'$\nu F_{\nu}\ \mathrm{[erg\ cm^{-2}\ s^{-1}]}$',
        'rd': r'$r_D\ [\mathrm{\sigma}]$',
        'rp': r'$r_\mathrm{P}\ [\mathrm{\sigma}]$',
        'rq': r'$r_Q\ [\mathrm{\sigma}]$',
    }

    # Accepted values of the enumerated options.
    _XSCALES = frozenset({'linear', 'log'})
    _YSCALES = frozenset({'linear', 'log', 'linlog'})
    _RESIDUALS = frozenset({'rd', 'rp', 'rq'})

    def __init__(
        self,
        alpha: float = 0.8,
        palette: Any = 'colorblind',
        xscale: Literal['linear', 'log'] = 'log',
        yscale: Literal['linear', 'log', 'linlog'] = 'linlog',
        lin_frac: float = 0.15,
        cl: tuple[float, ...] = (0.683, 0.95),
        residuals: Literal['rd', 'rp', 'rq'] = 'rq',
        random_quantile: bool = False,
        mark_outlier_residuals: bool = False,
        residuals_ci_with_sign: bool = True,
        fill_residuals_ci: bool = True,
        plot_comps: bool = False,
        seed: int | None = None,
    ):
        # Appearance
        self.alpha = alpha
        self.palette = palette
        # Axis scaling
        self.xscale = xscale
        self.yscale = yscale
        self.lin_frac = lin_frac
        # Intervals and residuals
        self.cl = cl
        self.residuals = residuals
        self.random_quantile = random_quantile
        self.mark_outlier_residuals = mark_outlier_residuals
        self.residuals_ci_with_sign = residuals_ci_with_sign
        self.fill_residuals_ci = fill_residuals_ci
        # Miscellaneous
        self.plot_comps = plot_comps
        self.seed = seed

    @property
    def alpha(self) -> float:
        """Transparency of colors, in (0, 1]."""
        return self._alpha

    @alpha.setter
    def alpha(self, alpha: float):
        value = float(alpha)
        if not (0.0 < value <= 1.0):
            raise ValueError('alpha must be in (0, 1]')
        self._alpha = value

    @property
    def palette(self) -> Any:
        """Color palettes, see [1]_ for details.

        References
        ----------
        .. [1] `seaborn tutorial: Choosing color palettes
            <https://seaborn.pydata.org/tutorial/color_palettes.html>`__
        """
        return self._palette

    @palette.setter
    def palette(self, palette: Any):
        # Stored as given; no validation is performed here.
        self._palette = palette

    @property
    def xscale(self) -> Literal['linear', 'log']:
        """X-axis scale of spectral plot.

        Should be ``'linear'``, or ``'log'``.
        """
        return self._xscale

    @xscale.setter
    def xscale(self, xscale: Literal['linear', 'log']):
        if xscale not in self._XSCALES:
            raise ValueError('xscale must be "linear" or "log"')
        self._xscale = xscale

    @property
    def yscale(self) -> Literal['linear', 'log', 'linlog']:
        """Y-axis scale of spectral plot.

        Should be ``'linear'``, ``'log'``, or ``'linlog'``.
        """
        return self._yscale

    @yscale.setter
    def yscale(self, yscale: Literal['linear', 'log', 'linlog']):
        if yscale not in self._YSCALES:
            raise ValueError('yscale must be "linear", "log", or "linlog"')
        self._yscale = yscale

    @property
    def lin_frac(self) -> float:
        """Linear fraction of the ``linlog`` plot, in (0, 0.5]."""
        return self._lin_frac

    @lin_frac.setter
    def lin_frac(self, lin_frac: float):
        value = float(lin_frac)
        if not (0.0 < value <= 0.5):
            raise ValueError('lin_frac must be in (0, 0.5]')
        self._lin_frac = value

    @property
    def cl(self) -> NumPyArray:
        """Confidence/Credible level(s), stored as a sorted float array."""
        return self._cl

    @cl.setter
    def cl(self, cl: float | Sequence[float]):
        levels = np.sort(np.atleast_1d(cl)).astype(float)
        if any(not 0.0 < level < 1.0 for level in levels):
            raise ValueError('cl must be in (0, 1)')
        self._cl = levels

    @property
    def residuals(self) -> Literal['rd', 'rp', 'rq']:
        """Default type of residual plot."""
        return self._residuals

    @residuals.setter
    def residuals(self, residuals: Literal['rd', 'rp', 'rq']):
        if residuals not in self._RESIDUALS:
            raise ValueError(
                'residuals type must be "rd" (deviance), "rp" (pearson), or '
                '"rq" (quantile)'
            )
        self._residuals = residuals

    @property
    def random_quantile(self) -> bool:
        """Whether to randomize the quantile residual."""
        return self._random_quantile

    @random_quantile.setter
    def random_quantile(self, random_quantile: bool):
        self._random_quantile = bool(random_quantile)

    @property
    def mark_outlier_residuals(self) -> bool:
        """Whether to mark outlier residuals with red crosses."""
        return self._mark_outlier_residuals

    @mark_outlier_residuals.setter
    def mark_outlier_residuals(self, mark_outlier_residuals: bool):
        self._mark_outlier_residuals = bool(mark_outlier_residuals)

    @property
    def residuals_ci_with_sign(self) -> bool:
        """Whether to account for residuals' sign when calculating CI bands."""
        return self._residuals_ci_with_sign

    @residuals_ci_with_sign.setter
    def residuals_ci_with_sign(self, residuals_ci_with_sign: bool):
        self._residuals_ci_with_sign = bool(residuals_ci_with_sign)

    @property
    def fill_residuals_ci(self) -> bool:
        """Whether to fill residuals' CI bands."""
        return self._fill_residuals_ci

    @fill_residuals_ci.setter
    def fill_residuals_ci(self, fill_residuals_ci: bool):
        self._fill_residuals_ci = bool(fill_residuals_ci)

    @property
    def plot_comps(self) -> bool:
        """Whether to plot additive components in spectral plot."""
        return self._plot_comps

    @plot_comps.setter
    def plot_comps(self, plot_comps: bool):
        self._plot_comps = bool(plot_comps)

    @property
    def seed(self) -> int | None:
        """Random seed used in calculation."""
        return self._seed

    @seed.setter
    def seed(self, seed: int | None):
        self._seed = None if seed is None else int(seed)
[docs]
class Plotter(ABC):
    """Plotter to visualize fit results.

    Concrete subclasses implement :meth:`__call__`, :meth:`plot_corner` and
    :meth:`get_plot_data`; the spectral, residual and diagnostic plots are
    implemented here on top of the per-dataset plot data.
    """

    # Cached per-dataset colors and the palette they were generated for.
    _colors: dict | None = None
    _palette: Any | None = None
    # Lazily built LaTeX strings of components and parameters.
    _comps_latex: dict[str, str] | None = None
    _params_latex: dict[str, str] | None = None
    # Names of the plots accepted by ``__call__``; set by subclasses.
    _supported: tuple[str, ...]
    # Plot data of each dataset, keyed by dataset name.
    data: dict[str, PlotData] | None = None

    def __init__(self, result: FitResult, config: PlotConfig | None = None):
        """Create a plotter for a fit result.

        Parameters
        ----------
        result : FitResult
            The fit result to visualize.
        config : PlotConfig, optional
            Plotting configuration. A default :class:`PlotConfig` is used
            when ``None``.
        """
        self._result = result
        self.data = self.get_plot_data(result)
        self.config = config
        markers = get_markers(len(self.data))
        self._markers = dict(zip(self.data.keys(), markers, strict=True))

    @abstractmethod
    def __call__(self, plots: str = 'data ne r') -> dict[str, Figure]:
        pass

    @abstractmethod
    def plot_corner(self, *args, **kwargs) -> Figure:
        """Corner plot of bootstrap/posterior parameters."""
        pass

    @staticmethod
    @abstractmethod
    def get_plot_data(result: FitResult) -> dict[str, PlotData]:
        """Get PlotData from FitResult."""
        pass

    @property
    def config(self) -> PlotConfig:
        """Plotting configuration."""
        return self._config

    @config.setter
    def config(self, config: PlotConfig):
        if config is None:
            config = PlotConfig()
        elif not isinstance(config, PlotConfig):
            raise TypeError('config must be a PlotConfig instance')
        self._config = config

    def set_colors(self, colors: dict[str, Any] | None = None):
        """Specify the colors of data points used in the plots.

        Parameters
        ----------
        colors : dict or None, optional
            If a dict is provided, will use the provided values as colors.
            If None, the default colors will be used. The default is None.

        Raises
        ------
        ValueError
            If the dict lacks a color for any dataset.
        """
        if colors is None:
            self._palette = None
            self._colors = None
        else:
            colors = dict(colors)
            if not set(self.data.keys()).issubset(colors.keys()):
                missing = ', '.join(i for i in self.data if i not in colors)
                raise ValueError(f'missing colors for those data: {missing}')
            self._palette = 'customized'
            self._colors = colors

    @property
    def colors(self):
        """Plotting color for each data."""
        # Regenerate the colors unless they were customized by the user or
        # were already generated for the currently configured palette.
        if self._palette not in ('customized', self.config.palette):
            colors = get_colors(len(self.data), palette=self.config.palette)
            self._palette = self.config.palette
            self._colors = dict(zip(self.data.keys(), colors, strict=True))
        return self._colors

    @property
    def ndata(self):
        """Data points number of each dataset, plus their ``'total'``."""
        ndata = {name: data.ndata for name, data in self.data.items()}
        ndata['total'] = sum(ndata.values())
        return ndata

    @property
    def comps_latex(self) -> dict[str, str]:
        """LaTeX representation of components."""
        if self._comps_latex is None:
            self._comps_latex = {
                k: f'${v}$ ' if v else ''
                for k, v in self._result._helper.params_comp_latex.items()
            }
        return self._comps_latex

    @property
    def params_latex(self) -> dict[str, str]:
        """LaTeX representation of parameters."""
        if self._params_latex is None:
            self._params_latex = {
                k: f'${v}$'
                for k, v in self._result._helper.params_latex.items()
            }
        return self._params_latex

    @property
    def params_unit(self) -> dict[str, str]:
        """Unit of parameters."""
        return self._result._helper.params_unit

    @property
    def params_titles(self) -> dict[str, str]:
        """Title of parameters."""
        comps_latex = self.comps_latex
        params_latex = self.params_latex
        params = self._result._helper.params_names['all']
        return {p: comps_latex[p] + params_latex[p] for p in params}

    @property
    def params_labels(self) -> dict[str, str]:
        """Label of parameters, with the unit (if any) on a second line."""
        comps_latex = self.comps_latex
        params_latex = self.params_latex
        params_unit = {
            k: f'\n[{v}]' if v else v for k, v in self.params_unit.items()
        }
        params = self._result._helper.params_names['all']
        return {
            p: comps_latex[p] + params_latex[p] + params_unit[p]
            for p in params
        }

    def set_xlabel(self, ax: Axes):
        """Set the energy label on the x-axis of `ax`."""
        ax.set_xlabel(r'$\mathrm{Energy\ [keV]}$')

    def _diagnostic_axes(
        self, xlabel: str, ylabel: str
    ) -> tuple[Figure, dict[str, Axes]]:
        """Create the figure layout shared by the Q-Q, PIT and GOF plots.

        The figure has one large panel on the left for all data combined
        (key ``'total'``) and a two-row grid of small panels on the right,
        one per dataset. When the number of datasets is odd, the unused
        last cell of the grid is hidden.

        Parameters
        ----------
        xlabel, ylabel : str
            Axis labels of the ``'total'`` panel.

        Returns
        -------
        fig : Figure
            The created figure.
        axs : dict
            Axes keyed by ``'total'`` and by dataset name.
        """
        n_subplots = len(self.data)
        # Two rows of panels: number of columns is ceil(n_subplots / 2).
        ncols = (n_subplots + 1) // 2

        fig = plt.figure(figsize=(4 + ncols * 2.25, 4), tight_layout=True)
        gs1 = fig.add_gridspec(1, 2, width_ratios=[4, ncols * 2.25])
        gs2 = gs1[0, 1].subgridspec(2, ncols, wspace=0.35)
        ax1 = fig.add_subplot(gs1[0, 0])
        ax2 = gs2.subplots(squeeze=False).ravel()
        if n_subplots % 2:
            ax2[-1].set_visible(False)
            ax2 = ax2[:n_subplots]

        ax1.set_xlabel(xlabel)
        ax1.set_ylabel(ylabel)

        axs = {'total': ax1} | dict(zip(self.data.keys(), ax2, strict=True))
        return fig, axs

    def plot_spec(
        self,
        data: bool = True,
        ne: bool = True,
        ene: bool = False,
        eene: bool = False,
        residuals: bool | Literal['rd', 'rp', 'rq'] = True,
        *,
        egrid: Mapping[str, NumPyArray] | None = None,
        params: Mapping[str, float | int | Array] | None = None,
        label_Fv: bool = False,
        label_vFv: bool = False,
        label_FvJy: bool = False,
    ) -> Figure:
        r"""Spectral plot.

        Parameters
        ----------
        data : bool, optional
            Whether to plot folded model and data. The default is ``True``.
        ne : bool, optional
            Whether to plot :math:`N(E)`. The default is ``True``.
        ene : bool, optional
            Whether to plot :math:`E N(E)`. The default is ``False``.
        eene : bool, optional
            Whether to plot :math:`E^2 N(E)`. The default is ``False``.
        residuals : bool or {'rd', 'rp', 'rq'}, optional
            Whether to plot residuals. Available options are:

                * ``True``: plot default residuals
                * ``False``: do not plot residuals
                * ``'rd'``: plot deviance residuals
                * ``'rp'``: plot Pearson residuals
                * ``'rq'``: plot quantile residuals

            The default is ``True``.
        egrid : dict, optional
            Overwrite the photon energy grid when plotting unfolded model.
        params : dict, optional
            Overwrite the parameters when plotting unfolded model.
        label_Fv : bool, optional
            Whether to label the y-axis of :math:`E N(E)` plot as
            :math:`F_{\nu}`. The default is ``False``.
        label_vFv : bool, optional
            Whether to label the y-axis of :math:`E^2 N(E)` plot as
            :math:`\nu F_{\nu}`. The default is ``False``.
        label_FvJy : bool, optional
            Whether to label the y-axis of :math:`E N(E)` plot as
            :math:`F_{\nu}` in Jy and convert values to Jy. The default is
            ``False``.

        Returns
        -------
        Figure
            The Figure object containing the spectral plot.

        Raises
        ------
        ValueError
            If no panel is requested at all.
        """
        nrows = data + ne + ene + eene + bool(residuals)
        if nrows == 0:
            raise ValueError(
                'at least one of data, ne, ene, eene, or residuals must be '
                'requested'
            )

        # Spectral panels share one height; the residual panel is shorter.
        height_ratios = [1.618] * nrows
        if residuals:
            height_ratios[-1] = 1.0

        config = self.config

        fig, axs = plt.subplots(
            nrows=nrows,
            ncols=1,
            sharex='all',
            height_ratios=height_ratios,
            gridspec_kw={'bottom': 0.07, 'top': 0.97, 'hspace': 0.03},
            figsize=(8, 4 + nrows),
            squeeze=False,
        )
        axs = axs.ravel()
        fig.align_ylabels(axs)

        for ax in axs:
            ax.tick_params(
                axis='both',
                which='both',
                direction='in',
                bottom=True,
                top=True,
                left=True,
                right=True,
            )

        # NOTE: this changes matplotlib's global rcParams.
        plt.rcParams['axes.formatter.min_exponent'] = 3

        self.set_xlabel(axs[-1])

        # Panels from top to bottom, in the order the axes were created.
        requested = ((data, 'ce'), (ne, 'ne'), (ene, 'ene'), (eene, 'eene'))
        plots = [panel for flag, panel in requested if flag]

        # Resolve the residual type: ``True`` means the configured default.
        rtype: Literal['rd', 'rp', 'rq'] | None
        if residuals:
            plots.append('residuals')
            rtype = config.residuals if residuals is True else residuals
        else:
            rtype = None

        axs_dict = dict(zip(plots, axs, strict=True))
        yscale = config.yscale

        if data:
            ax = axs_dict['ce']
            self.plot_ce(ax)
            self.plot_folded(ax)
            if yscale == 'linear':
                ax.set_yscale('linear')
            else:
                ax.set_yscale('log')
                dmin, dmax = ax.get_yaxis().get_data_interval()
                vmin = ax.get_ylim()[0]
                if yscale == 'linlog' and dmin <= 0.0:
                    # Non-positive data: linear below ``vmin`` and
                    # logarithmic above it.
                    lin_frac = config.lin_frac
                    if np.log10(dmax / vmin) > 7:
                        vmin = 1e-7 * dmax
                    scale = LinLogScale(
                        axis=None,
                        base=10.0,
                        lin_thresh=vmin,
                        lin_scale=get_scale(10.0, vmin, dmin, dmax, lin_frac),
                    )
                    ax.set_yscale(scale)
                    ax.axhline(vmin, c='k', lw=0.15, ls=':', zorder=-1)
                else:
                    _adjust_log_range(ax, 'y', 7)

        # Unfolded model panels, with the label options relevant to each.
        unfolded_kwargs = {
            'ne': {},
            'ene': {'label_Fv': label_Fv, 'label_FvJy': label_FvJy},
            'eene': {'label_vFv': label_vFv},
        }
        for mtype, label_kwargs in unfolded_kwargs.items():
            ax = axs_dict.get(mtype)
            if ax is None:
                continue
            self.plot_unfolded(ax, mtype, params, egrid, **label_kwargs)
            if yscale != 'linear':
                ax.set_yscale('log')
                _adjust_log_range(ax, 'y')

        if rtype:
            self.plot_residuals(axs_dict['residuals'], rtype)

        # The x-axis is shared, so configuring the first Axes is enough.
        axs[0].set_xscale(config.xscale)
        intervalx = np.array([ax.dataLim.intervalx for ax in axs])
        xmin = intervalx[:, 0].min()
        xmax = intervalx[:, 1].max()
        axs[0].set_xlim(xmin * 0.97, xmax * 1.06)

        for ax in axs:
            ax.relim()
            ax.autoscale_view()

        return fig

    def plot_unfolded(
        self,
        ax: Axes,
        mtype: Literal['ne', 'ene', 'eene'],
        params: Mapping[str, float | int | Array] | None = None,
        egrid: Mapping[str, NumPyArray] | None = None,
        label_Fv: bool = False,
        label_vFv: bool = False,
        label_FvJy: bool = False,
    ):
        r"""Plot unfolded model.

        Parameters
        ----------
        ax : Axes
            The Axes object to plot.
        mtype : {'ne', 'ene', 'eene'}
            The type of unfolded model, available options are:

                * ``'ne'``: plot :math:`N(E)`
                * ``'ene'``: plot :math:`E N(E)`
                * ``'eene'``: plot :math:`E^2 N(E)`

        params : dict, optional
            Overwrite the parameters when plotting unfolded model.
        egrid : dict, optional
            Overwrite the photon energy grid when plotting unfolded model.
        label_Fv : bool, optional
            Whether to label the y-axis of :math:`E N(E)` plot as
            :math:`F_{\nu}`. The default is ``False``.
        label_vFv : bool, optional
            Whether to label the y-axis of :math:`E^2 N(E)` plot as
            :math:`\nu F_{\nu}`. The default is ``False``.
        label_FvJy : bool, optional
            Whether to label the y-axis of :math:`E N(E)` plot as
            :math:`F_{\nu}` in Jy and convert values to Jy. The default is
            ``False``.

        Raises
        ------
        ValueError
            If any value in `params` is not a scalar, if `label_FvJy` is
            used with an `mtype` other than ``'ene'``, or if `mtype` is
            not one of the supported types.
        """
        params = dict(params) if params is not None else {}
        if params:
            if any(np.shape(v) != () for v in params.values()):
                raise ValueError('params must be scalars')

        egrid = dict(egrid) if egrid is not None else {}

        config = self.config
        colors = self.colors
        cl = config.cl
        comps = config.plot_comps
        step_kwargs = {'lw': 1.618, 'alpha': config.alpha}
        ribbon_kwargs = {'lw': 0.618, 'alpha': 0.2 * config.alpha}

        if label_FvJy and mtype != 'ene':
            raise ValueError('label_FvJy requires mtype="ene"')

        if mtype == 'ne':
            label_type = 'ne'
        elif mtype == 'ene':
            if label_FvJy:
                label_type = 'FvJy'
            else:
                label_type = 'Fv' if label_Fv else 'ene'
        elif mtype == 'eene':
            label_type = 'vFv' if label_vFv else 'eene'
        else:
            raise ValueError("mtype must be 'ne', 'ene', or 'eene'")
        ax.set_ylabel(config._YLABLES[label_type])

        for name, data in self.data.items():
            color = colors[name]
            egrid_ = egrid.get(name, data.photon_egrid)
            e_left = egrid_[:-1]
            e_right = egrid_[1:]

            # Total model and its interval(s).
            ne, ci = data.unfolded_model(mtype, egrid_, params, False, cl)
            if label_FvJy:
                ne = _scale_fv_to_jy(ne)
                ci = _scale_fv_to_jy(ci)

            _plot_step(ax, e_left, e_right, ne, color=color, **step_kwargs)

            if ci is not None:
                _plot_ribbon(
                    ax, e_left, e_right, ci, color=color, **ribbon_kwargs
                )

            if comps:
                if not data.has_comps:
                    continue

                # Individual components, drawn with dotted lines.
                ne, ci = data.unfolded_model(mtype, egrid_, params, True)
                if label_FvJy:
                    ne = _scale_fv_to_jy(ne)
                    ci = _scale_fv_to_jy(ci)

                for ne_ in ne.values():
                    _plot_step(
                        ax,
                        e_left,
                        e_right,
                        ne_,
                        color=color,
                        **(step_kwargs | {'ls': ':'}),
                    )

                if ci is not None:
                    for ci_ in ci.values():
                        _plot_ribbon(
                            ax,
                            e_left,
                            e_right,
                            ci_,
                            color=color,
                            **ribbon_kwargs,
                        )

    def plot_folded(self, ax: Axes):
        """Plot folded model.

        Parameters
        ----------
        ax : Axes
            The Axes object to plot.
        """
        config = self.config
        colors = self.colors
        cl = config.cl
        step_kwargs = {'lw': 1.618, 'alpha': config.alpha}
        ribbon_kwargs = {'lw': 0.618, 'alpha': 0.2 * config.alpha}

        ax.set_ylabel(config._YLABLES['ce'])

        for name, data in self.data.items():
            color = colors[name]
            _plot_step(
                ax,
                data.channel_emin,
                data.channel_emax,
                data.ce_model,
                color=color,
                **step_kwargs,
            )

            # One band per level; levels without an interval are skipped.
            quantiles = []
            for i_cl in cl:
                if (q := data.ce_model_ci(i_cl)) is not None:
                    quantiles.append(q)

            if quantiles:
                _plot_ribbon(
                    ax,
                    data.channel_emin,
                    data.channel_emax,
                    quantiles,
                    color=color,
                    **ribbon_kwargs,
                )

    def plot_ce(self, ax: Axes):
        """Plot data.

        Parameters
        ----------
        ax : Axes
            The Axes object to plot.
        """
        config = self.config
        colors = self.colors
        alpha = config.alpha
        xlog = config.xscale == 'log'

        ax.set_ylabel(config._YLABLES['ce'])

        for name, data in self.data.items():
            color = colors[name]
            marker = self._markers[name]
            # The channel position and x error bars depend on the x scale.
            ax.errorbar(
                x=data.channel_emean if xlog else data.channel_emid,
                xerr=data.channel_errors if xlog else 0.5 * data.channel_width,
                y=data.ce_data,
                yerr=data.ce_errors,
                alpha=alpha,
                color=color,
                fmt=f'{marker} ',
                label=name,
                lw=0.75,
                ms=2.4,
                mec=color,
                mfc='#FFFFFFCC',
            )

        # Use several legend columns when there are many datasets.
        if len(self.data) > 5:
            ncols = int(np.ceil(len(self.data) / 4))
        else:
            ncols = 1
        ax.legend(ncols=ncols)

    def plot_residuals(
        self,
        ax: Axes,
        rtype: Literal['rd', 'rp', 'rq'] | None = None,
    ):
        """Plot residuals.

        Parameters
        ----------
        ax : Axes
            The Axes object to plot.
        rtype : {'rd', 'rp', 'rq'}, optional
            The type of residuals, available options are:

                * ``'rd'``: deviance residuals
                * ``'rp'``: Pearson residuals
                * ``'rq'``: quantile residuals

            The configured default type is used when ``None``.

        Raises
        ------
        ValueError
            If `rtype` is not one of the supported values.
        """
        if rtype not in {'rd', 'rp', 'rq', None}:
            raise ValueError(
                'residuals type must be "rd" (deviance), "rp" (pearson), '
                '"rq" (quantile), or None (use default residuals)'
            )

        config = self.config
        colors = self.colors
        cl = config.cl
        random_quantile = config.random_quantile
        with_sign = config.residuals_ci_with_sign
        mark_outlier = config.mark_outlier_residuals
        seed = config.seed
        ribbon_kwargs = {'lw': 0.618, 'alpha': 0.15 * config.alpha}

        if rtype is None:
            rtype = config.residuals

        alpha = config.alpha
        xlog = config.xscale == 'log'

        # Two-sided standard normal quantile of each level.
        normal_q = stats.norm.isf(0.5 * (1.0 - cl))

        ax.set_ylabel(config._YLABLES[rtype])

        for name, data in self.data.items():
            color = colors[name]
            marker = self._markers[name]
            x = data.channel_emean if xlog else data.channel_emid
            xerr = data.channel_errors if xlog else 0.5 * data.channel_width

            quantiles = []
            for i_cl in cl:
                q = data.residuals_ci(
                    rtype, i_cl, seed, random_quantile, with_sign
                )
                if q is not None:
                    quantiles.append(q)

            if config.fill_residuals_ci:
                if quantiles:
                    _plot_ribbon(
                        ax,
                        data.channel_emin,
                        data.channel_emax,
                        quantiles,
                        color=color,
                        **ribbon_kwargs,
                    )
                else:
                    # No interval from the data object: fall back to
                    # constant standard normal bands.
                    for q in normal_q:
                        ax.fill_between(
                            [data.channel_emin[0], data.channel_emax[-1]],
                            -q,
                            q,
                            color=color,
                            **ribbon_kwargs,
                        )

            use_mle = bool(quantiles)
            r = data.residuals(rtype, seed, config.random_quantile, use_mle)
            if rtype == 'rq':
                # Quantile residuals come with lower/upper limit flags.
                r, lower, upper = r
            else:
                lower = upper = False
            ax.errorbar(
                x=x,
                y=r,
                yerr=1.0,
                xerr=xerr,
                color=color,
                alpha=alpha,
                linewidth=0.75,
                linestyle='',
                marker=marker,
                markersize=2.4,
                markeredgecolor=color,
                markerfacecolor='#FFFFFFCC',
                lolims=lower,
                uplims=upper,
            )

            if mark_outlier:
                # Outliers are judged against the widest band.
                if quantiles:
                    q = quantiles[-1]
                else:
                    q = [-normal_q[-1], normal_q[-1]]
                mask = (r < q[0]) | (r > q[1])
                ax.scatter(x[mask], r[mask], marker='x', c='r')

        for q in normal_q:
            ax.axhline(q, ls=':', lw=1, c='gray', zorder=0)
            ax.axhline(-q, ls=':', lw=1, c='gray', zorder=0)

        ax.axhline(0, ls='--', lw=1, c='gray', zorder=0)

        # Make the y range symmetric about zero.
        yabs_max = abs(max(ax.get_ylim(), key=abs))
        ax.set_ylim(ymin=-yabs_max, ymax=yabs_max)

    def plot_qq(
        self,
        rtype: Literal['rd', 'rp', 'rq'] | None = None,
        seed: int | None = None,
        detrend: bool = True,
    ) -> Figure:
        """Quantile-Quantile plot.

        Parameters
        ----------
        rtype : {'rd', 'rp', 'rq'}, optional
            The type of residuals, available options are:

                * ``'rd'``: deviance residuals
                * ``'rp'``: Pearson residuals
                * ``'rq'``: quantile residuals
                * ``None``: use the default residuals type

            The default is ``None``.
        seed : int, optional
            Random seed used in calculation. The default is ``None``.
        detrend : bool, optional
            Whether to detrend the Q-Q plot. The default is ``True``.

        Returns
        -------
        Figure
            The Figure object containing Q-Q plot.
        """
        config = self.config
        random_quantile = config.random_quantile

        if rtype is None:
            rtype = config.residuals

        # Simulated residuals; the combined set is only available when
        # every dataset provides simulations.
        rsim = {
            name: data.residuals_sim(rtype, seed, random_quantile)
            for name, data in self.data.items()
        }
        if any(i is None for i in rsim.values()):
            rsim['total'] = None
        else:
            rsim['total'] = np.hstack(list(rsim.values()))

        use_mle = rsim['total'] is not None
        r = {
            name: data.residuals(rtype, seed, random_quantile, use_mle)
            for name, data in self.data.items()
        }
        if rtype == 'rq':
            # Drop the limit flags returned with quantile residuals.
            r = {k: v[0] for k, v in r.items()}
        r['total'] = np.hstack(list(r.values()))

        fig, axs = self._diagnostic_axes(
            'Normal Theoretical Quantiles', 'Residuals'
        )

        alpha = config.alpha
        ha = 'center' if detrend else 'left'
        text_x = 0.5 if detrend else 0.03
        colors = {'total': 'k'} | self.colors

        for name, ax in axs.items():
            color = colors[name]
            theor, q, line, lo, up = _get_qq(
                r[name], detrend, 0.95, rsim[name]
            )
            ax.scatter(theor, q, s=5, color=color, alpha=alpha)
            ax.plot(theor, line, ls='--', color=color, alpha=alpha)
            ax.plot(theor, lo, ls=':', color=color, alpha=alpha)
            ax.plot(theor, up, ls=':', color=color, alpha=alpha)
            ax.fill_between(
                theor, lo, up, alpha=0.2 * alpha, color=color, lw=0.0
            )
            ax.annotate(
                name,
                xy=(text_x, 0.97),
                xycoords='axes fraction',
                ha=ha,
                va='top',
                color=color,
            )

        return fig

    def plot_pit(self, detrend: bool = True) -> Figure:
        """Probability integral transformation empirical CDF plot.

        Parameters
        ----------
        detrend : bool, optional
            Whether to detrend the PIT ECDF plot. The default is ``True``.

        Returns
        -------
        Figure
            The Figure object containing PIT ECDF plot.
        """
        config = self.config

        pit = {name: data.pit()[1] for name, data in self.data.items()}
        pit['total'] = np.hstack(list(pit.values()))

        fig, axs = self._diagnostic_axes('PIT Value', 'ECDF')

        alpha = config.alpha
        ha = 'right' if detrend else 'left'
        text_x = 0.97 if detrend else 0.03
        colors = {'total': 'k'} | self.colors

        for name, ax in axs.items():
            color = colors[name]
            x, y, line, lower, upper = _get_pit_ecdf(pit[name], 0.95, detrend)
            ax.plot(x, line, ls='--', color=color, alpha=alpha)
            ax.fill_between(
                x, lower, upper, alpha=0.2 * alpha, color=color, step='mid'
            )
            ax.step(x, y, alpha=alpha, color=color, where='mid')
            ax.annotate(
                text=name,
                xy=(text_x, 0.97),
                xycoords='axes fraction',
                ha=ha,
                va='top',
                color=color,
            )

        return fig

    def plot_gof(self) -> Figure:
        """Plot distribution of GOF statistics and p-value.

        Returns
        -------
        Figure
            The Figure object containing GOF statistics plot.

        Raises
        ------
        RuntimeError
            If the bootstrap (MLE) or posterior predictive check
            (posterior) has not been run yet.
        NotImplementedError
            If the plotter is of neither of the two known plotter types.
        """
        if isinstance(self, MLEResultPlotter):
            if self._result._boot is None:
                raise RuntimeError(
                    'MLEResult.boot() must be called to assess gof'
                )
            n = int(self._result._boot.n_valid)
            dev_obs = self._result.deviance
            dev_sim = self._result._boot.deviance
            dev_sim = dev_sim['group'] | {'total': dev_sim['total']}
            p_value = self._result._boot.p_value
        elif isinstance(self, PosteriorResultPlotter):
            if self._result._ppc is None:
                raise RuntimeError(
                    'PosteriorResult.ppc() must be called to assess gof'
                )
            n = int(self._result._ppc.n_valid)
            dev_obs = self._result._mle['deviance']
            dev_sim = self._result._ppc.deviance
            dev_obs = dev_obs['group'] | {'total': dev_obs['total']}
            dev_sim = dev_sim['group'] | {'total': dev_sim['total']}
            p_value = self._result._ppc.p_value
        else:
            raise NotImplementedError
        p_value = p_value['group'] | {'total': p_value['total']}

        fig, axs = self._diagnostic_axes('$D$', r'$P(\mathcal{D} \geq D)$')

        colors = {'total': 'k'} | self.colors
        # Empirical survival function levels; identical for every panel.
        sf = 1.0 - np.arange(1.0, n + 1.0) / n

        for name, ax in axs.items():
            color = colors[name]
            d_obs = dev_obs[name]
            d_sim = np.sort(dev_sim[name])
            ax.plot(d_sim, sf, color=color)
            ax.axvline(d_obs, color=color, ls=':')
            p = p_value[name]
            if p > 0.0:
                pstr = f'{name} $p = {p:.2g}$'
            else:
                # A zero p-value is reported as an upper bound of 1/n.
                pstr = f'{name} $p < {1}/{n}$'
            ax.annotate(
                text=pstr,
                xy=(0.97, 0.97),
                xycoords='axes fraction',
                ha='right',
                va='top',
                color=color,
            )
            ax.set_yscale('log')

        return fig
[docs]
class MLEResultPlotter(Plotter):
data: dict[str, MLEPlotData]
_result: MLEResult
_supported = (
'data',
'r',
'rd',
'rp',
'rq',
'ne',
'ene',
'eene',
'Fv',
'FvJy',
'vFv',
'corner',
'gof',
'qq',
'pit',
)
def __call__(self, plots: str = 'data ne r') -> dict[str, Figure]:
    r"""Plot MLE fit results.

    Parameters
    ----------
    plots : str, optional
        Plots to show, available plots are:

            * ``'data'``: data and folded model plot
            * ``'ne'``: :math:`N(E)` model plot
            * ``'ene'``: :math:`E N(E)` model plot
            * ``'eene'``: :math:`E^2 N(E)` model plot
            * ``'Fv'``: :math:`F_{\nu}` model plot
            * ``'FvJy'``: :math:`F_{\nu}` model plot in Jy
            * ``'vFv'``: :math:`\nu F_{\nu}` model plot
            * ``'r'``: default residuals plot
            * ``'rd'``: deviance residuals plot
            * ``'rp'``: Pearson residuals plot
            * ``'rq'``: quantile residuals plot
            * ``'corner'``: corner plot
            * ``'gof'``: goodness-of-fit statistics plot
            * ``'qq'``: quantiles-quantiles plot of residuals
            * ``'pit'``: probability integral transform plot of spectral
              data

        Multiple plots can be combined by separating them with whitespace;
        leading and trailing whitespace is ignored.
        The default is ``'data ne r'``.

    Returns
    -------
    dict
        Dictionary containing Figure object for each plot. The data, model
        and residuals plots share a single figure stored under ``'spec'``.

    Raises
    ------
    ValueError
        If no plot is requested, or a requested plot is not supported.
    """
    # str.split() without a separator ignores leading/trailing whitespace,
    # so ' data ne ' does not produce empty plot names that would be
    # rejected as unsupported.
    plots = str(plots).split()
    unsupported = [p for p in plots if p not in self._supported]
    if unsupported or not plots:
        supported = ', '.join(self._supported)
        err = ', '.join(unsupported)
        raise ValueError(f'supported plots are: {supported}; got {err}')

    plots_set = set(plots)
    dic = {}

    residual_kinds = ('r', 'rd', 'rp', 'rq')
    spec = {'data', 'ne', 'ene', 'eene', 'Fv', 'FvJy', 'vFv'}
    spec.update(residual_kinds)
    if spec & plots_set:
        # When several residual kinds are requested, the last one wins.
        requested = [p for p in plots if p in residual_kinds]
        if not requested:
            residuals = False
        elif requested[-1] == 'r':
            # 'r' asks for the default residuals, signalled by True.
            residuals = True
        else:
            residuals = requested[-1]

        dic['spec'] = self.plot_spec(
            data='data' in plots_set,
            ne='ne' in plots_set,
            # Fv and FvJy are drawn on the E N(E) panel, vFv on the
            # E^2 N(E) panel; the label_* flags switch the axis labels.
            ene=bool({'ene', 'Fv', 'FvJy'} & plots_set),
            eene=bool({'eene', 'vFv'} & plots_set),
            residuals=residuals,
            label_Fv='Fv' in plots_set,
            label_FvJy='FvJy' in plots_set,
            label_vFv='vFv' in plots_set,
        )

    # Each remaining plot maps onto the plot_<name> method of the same name.
    for name in ('corner', 'gof', 'qq', 'pit'):
        if name in plots_set:
            dic[name] = getattr(self, f'plot_{name}')()

    return dic
[docs]
@staticmethod
def get_plot_data(result: MLEResult) -> dict[str, MLEPlotData]:
    """Create the plot data of every dataset of an MLE result.

    Parameters
    ----------
    result : MLEResult
        The MLE result to extract plot data from.

    Returns
    -------
    dict
        Mapping from data name to its :class:`MLEPlotData`.
    """
    helper = result._helper
    names = helper.data_names

    # Derive one sub-key per dataset from the 'resd' seed of the helper.
    root_key = jax.random.PRNGKey(helper.seed['resd'])
    sub_keys = jax.random.split(root_key, len(names))

    plot_data = {}
    for name, sub_key in zip(names, sub_keys, strict=True):
        # The first element of the sub-key, cast to int, is handed to
        # MLEPlotData as its third argument.
        plot_data[name] = MLEPlotData(name, result, int(sub_key[0]))
    return plot_data
[docs]
def plot_corner(
    self,
    params: str | Sequence[str] | None = None,
    color: str | None = None,
    bins: int | Sequence[int] = 40,
    hist_bin_factor: float | Sequence[float] = 1.5,
    fig_path: str | None = None,
) -> Figure:
    """Draw a corner plot of the bootstrap parameter samples.

    Parameters
    ----------
    params : str or sequence of str, optional
        Parameters to plot. The default is all spectral parameters.
    color : str, optional
        Color of the plot. The default is ``None``.
    bins : int or list of int, optional
        Number of histogram bins, given either as one value used for every
        dimension or as one integer per dimension. The default is 40.
    hist_bin_factor : float or list of float, optional
        Factor (or one factor per dimension) multiplying the bin
        specification of the 1-D histograms, generally used to give the
        1-D plots more resolution. The default is 1.5.
    fig_path : str, optional
        Path to save the figure. The default is ``None``.

    Returns
    -------
    Figure
        The Figure object containing corner plot.

    Raises
    ------
    ValueError
        If the result holds no bootstrap output.
    """
    boot = self._result._boot
    if boot is None:
        raise ValueError('MLEResult.boot() must be called first')

    helper = self._result._helper
    params = check_params(params, helper)

    # Parameters flagged in ``helper.params_log`` get a log-scaled axis.
    is_log = helper.params_log
    scales = []
    for p in params:
        scales.append('log' if is_log[p] else 'linear')

    titles_of = self.params_titles
    labels_of = self.params_labels

    corner_kwargs = dict(
        idata=az.from_dict(boot.params),
        bins=bins,
        hist_bin_factor=hist_bin_factor,
        params=params,
        axes_scale=scales,
        levels=self.config.cl,
        titles=[titles_of[p] for p in params],
        labels=[labels_of[p] for p in params],
        color=color,
    )
    figure = plot_corner(**corner_kwargs)

    if fig_path:
        figure.savefig(fig_path, bbox_inches='tight')
    return figure
[docs]
# Plotter for Bayesian (posterior) fit results.
class PosteriorResultPlotter(Plotter):
# Plot data of each dataset, keyed by data name (built by get_plot_data).
data: dict[str, PosteriorPlotData]
# The posterior result whose fit is visualized by this plotter.
_result: PosteriorResult
# Plot names accepted by __call__; any other name is rejected there with
# a ValueError that lists these names.
_supported = (
'data',
'r',
'rd',
'rp',
'rq',
'ne',
'ene',
'eene',
'Fv',
'FvJy',
'vFv',
'corner',
'gof',
'qq',
'pit',
'trace',
'khat',
)
def __call__(self, plots: str = 'data ne r') -> dict[str, Figure]:
    r"""Plot Bayesian fit results.

    Parameters
    ----------
    plots : str, optional
        Plots to show, available plots are:

            * ``'data'``: data and folded model plot
            * ``'ne'``: :math:`N(E)` model plot
            * ``'ene'``: :math:`E N(E)` model plot
            * ``'eene'``: :math:`E^2 N(E)` model plot
            * ``'Fv'``: :math:`F_{\nu}` model plot
            * ``'FvJy'``: :math:`F_{\nu}` model plot in Jy
            * ``'vFv'``: :math:`\nu F_{\nu}` model plot
            * ``'r'``: default residuals plot
            * ``'rd'``: deviance residuals plot
            * ``'rp'``: Pearson residuals plot
            * ``'rq'``: PSIS-LOO quantile residuals plot
            * ``'corner'``: corner plot
            * ``'gof'``: goodness-of-fit statistics plot
            * ``'qq'``: quantiles-quantiles plot of residuals
            * ``'pit'``: PSIS-LOO probability integral transform plot of
              spectral data
            * ``'trace'``: trace plot of posterior samples
            * ``'khat'``: k-hat plot for Bayesian PSIS-LOO diagnostics

        Multiple plots can be combined by separating them with whitespace;
        leading and trailing whitespace is ignored.
        The default is ``'data ne r'``.

    Returns
    -------
    dict
        Dictionary containing Figure object for each plot. The data, model
        and residuals plots share a single figure stored under ``'spec'``.

    Raises
    ------
    ValueError
        If no plot is requested, or a requested plot is not supported.
    """
    # str.split() without a separator ignores leading/trailing whitespace,
    # so ' data ne ' does not produce empty plot names that would be
    # rejected as unsupported.
    plots = str(plots).split()
    unsupported = [p for p in plots if p not in self._supported]
    if unsupported or not plots:
        supported = ', '.join(self._supported)
        err = ', '.join(unsupported)
        raise ValueError(f'supported plots are: {supported}; got {err}')

    plots_set = set(plots)
    dic = {}

    residual_kinds = ('r', 'rd', 'rp', 'rq')
    spec = {'data', 'ne', 'ene', 'eene', 'Fv', 'FvJy', 'vFv'}
    spec.update(residual_kinds)
    if spec & plots_set:
        # When several residual kinds are requested, the last one wins.
        requested = [p for p in plots if p in residual_kinds]
        if not requested:
            residuals = False
        elif requested[-1] == 'r':
            # 'r' asks for the default residuals, signalled by True.
            residuals = True
        else:
            residuals = requested[-1]

        dic['spec'] = self.plot_spec(
            data='data' in plots_set,
            ne='ne' in plots_set,
            # Fv and FvJy are drawn on the E N(E) panel, vFv on the
            # E^2 N(E) panel; the label_* flags switch the axis labels.
            ene=bool({'ene', 'Fv', 'FvJy'} & plots_set),
            eene=bool({'eene', 'vFv'} & plots_set),
            residuals=residuals,
            label_Fv='Fv' in plots_set,
            label_FvJy='FvJy' in plots_set,
            label_vFv='vFv' in plots_set,
        )

    # Each remaining plot maps onto the plot_<name> method of the same name.
    for name in ('corner', 'gof', 'qq', 'pit', 'trace', 'khat'):
        if name in plots_set:
            dic[name] = getattr(self, f'plot_{name}')()

    return dic
[docs]
@staticmethod
def get_plot_data(result: PosteriorResult) -> dict[str, PosteriorPlotData]:
    """Create the plot data of every dataset of a posterior result.

    Parameters
    ----------
    result : PosteriorResult
        The posterior result to extract plot data from.

    Returns
    -------
    dict
        Mapping from data name to its :class:`PosteriorPlotData`.
    """
    helper = result._helper
    names = helper.data_names

    # Derive one sub-key per dataset from the 'resd' seed of the helper.
    root_key = jax.random.PRNGKey(helper.seed['resd'])
    sub_keys = jax.random.split(root_key, len(names))

    plot_data = {}
    for name, sub_key in zip(names, sub_keys, strict=True):
        # The first element of the sub-key, cast to int, is handed to
        # PosteriorPlotData as its third argument.
        plot_data[name] = PosteriorPlotData(name, result, int(sub_key[0]))
    return plot_data
[docs]
def plot_trace(
    self,
    params: str | Sequence[str] | None = None,
    fig_path: str | None = None,
) -> Figure:
    """Draw the trace plot of posterior samples.

    Parameters
    ----------
    params : str or sequence of str, optional
        Parameters to plot. The default is all spectral parameters.
    fig_path : str, optional
        Path to save the figure. The default is ``None``.

    Returns
    -------
    Figure
        The Figure object containing trace plot.
    """
    helper = self._result._helper
    params = check_params(params, helper)

    # Parameters flagged in ``helper.params_log`` get a log-scaled axis.
    is_log = helper.params_log
    scales = []
    for p in params:
        scales.append('log' if is_log[p] else 'linear')

    labels_of = self.params_labels
    axis_labels = [labels_of[p] for p in params]

    figure = plot_trace(self._result._idata, params, scales, axis_labels)
    if fig_path:
        figure.savefig(fig_path, bbox_inches='tight')
    return figure
[docs]
def plot_corner(
    self,
    params: str | Sequence[str] | None = None,
    color: str | None = None,
    divergences: bool = True,
    bins: int | Sequence[int] = 40,
    hist_bin_factor: float | Sequence[float] = 1.5,
    fig_path: str | None = None,
) -> Figure:
    """Draw a corner plot of the posterior parameter samples.

    Parameters
    ----------
    params : str or sequence of str, optional
        Parameters to plot. The default is all spectral parameters.
    color : str, optional
        Color of the plot. The default is ``None``.
    divergences : bool, optional
        Whether to show divergent samples. The default is ``True``.
    bins : int or list of int, optional
        Number of histogram bins, given either as one value used for every
        dimension or as one integer per dimension. The default is 40.
    hist_bin_factor : float or list of float, optional
        Factor (or one factor per dimension) multiplying the bin
        specification of the 1-D histograms, generally used to give the
        1-D plots more resolution. The default is 1.5.
    fig_path : str, optional
        Path to save the figure. The default is ``None``.

    Returns
    -------
    Figure
        The Figure object containing corner plot.
    """
    helper = self._result._helper
    params = check_params(params, helper)

    # Parameters flagged in ``helper.params_log`` get a log-scaled axis.
    is_log = helper.params_log
    scales = []
    for p in params:
        scales.append('log' if is_log[p] else 'linear')

    titles_of = self.params_titles
    labels_of = self.params_labels

    corner_kwargs = dict(
        idata=self._result._idata,
        bins=bins,
        hist_bin_factor=hist_bin_factor,
        params=params,
        axes_scale=scales,
        levels=self.config.cl,
        titles=[titles_of[p] for p in params],
        labels=[labels_of[p] for p in params],
        color=color,
        divergences=divergences,
    )
    figure = plot_corner(**corner_kwargs)

    if fig_path:
        figure.savefig(fig_path, bbox_inches='tight')
    return figure
[docs]
def plot_khat(self) -> Figure:
    """Plot k-hat diagnostic of PSIS-LOO.

    The Pareto shape parameter of every channel is drawn against energy,
    one series per dataset. Values above 0.7 are marked with red crosses,
    and a dotted red line at 0.7 is added when any value exceeds it.

    Returns
    -------
    Figure
        The Figure object containing the k-hat plot.
    """
    k_threshold = 0.7

    config = self.config
    palette = self.colors
    alpha = config.alpha
    use_log = config.xscale == 'log'

    fig, ax = plt.subplots(1, 1, squeeze=True, tight_layout=True)

    khat = self._result.loo.pareto_k
    # Only draw the reference line when some value actually exceeds it.
    if np.any(khat.values > k_threshold):
        ax.axhline(k_threshold, color='r', lw=0.5, ls=':')

    for name, data in self.data.items():
        series_color = palette[name]
        marker = self._markers[name]
        k_values = khat.sel(channel=data.channel).values

        # The channel position and its horizontal error bar depend on
        # whether the x axis is logarithmic.
        if use_log:
            x = data.channel_emean
            x_err = data.channel_errors
        else:
            x = data.channel_emid
            x_err = 0.5 * data.channel_width

        ax.errorbar(
            x=x,
            xerr=x_err,
            y=k_values,
            alpha=alpha,
            color=series_color,
            fmt=f'{marker} ',
            label=name,
            lw=0.75,
            ms=2.4,
            mec=series_color,
            mfc='#FFFFFFCC',
        )

        # Highlight the channels above the threshold.
        above = k_values > k_threshold
        if np.any(above):
            ax.scatter(x=x[above], y=k_values[above], marker='x', c='r')

    ax.set_xscale(config.xscale)
    ax.set_xlabel('Energy [keV]')
    ax.set_ylabel(r'Shape Parameter $\hat{k}$')
    return fig