"""Helper for fitting and analysis."""
from __future__ import annotations
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, NamedTuple
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from jax import lax
from numpyro import handlers
from numpyro.infer.util import constrain_fn, unconstrain_fn
from elisa.infer.likelihood import (
_STATISTIC_BACK_NORMAL,
_STATISTIC_SPEC_NORMAL,
_STATISTIC_WITH_BACK,
chi2,
cstat,
pgstat,
pstat,
wstat,
)
from elisa.util.config import get_parallel_number
from elisa.util.misc import (
get_unit_latex,
progress_bar_factory,
)
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Literal
from numpyro.distributions import Distribution
from elisa.data.base import FixedData
from elisa.infer.fit import Fit
from elisa.infer.likelihood import Statistic
from elisa.models.model import CompiledModel, ModelInfo, ParamSetup
from elisa.util.typing import (
JAXArray,
JAXFloat,
ParamID,
ParamName,
ParamNameValMapping,
)
def check_params(
    params: str | Sequence[str] | None, helper: Helper
) -> list[str]:
    """Validate parameter names and return them in model order.

    Parameters
    ----------
    params : str, sequence of str, or None
        The parameter name(s) to check. If None, the parameters of interest
        (``helper.params_names['interest']``) are used.
    helper : Helper
        The helper whose ``params_names`` and ``params_setup`` are consulted.

    Returns
    -------
    list of str
        The validated names, ordered as in ``helper.params_names['all']``.

    Raises
    ------
    ValueError
        If `params` has an unsupported type, or names a parameter that is
        found neither in ``params_names['all']`` nor in ``params_setup``.
    RuntimeError
        If any requested parameter is forwarded (linked), fixed, or
        integrated-out.
    """
    params_names = helper.params_names
    all_params = set(params_names['all']) | set(helper.params_setup)

    # classify parameters by the name of their setup type
    forwarded = {
        k: v[0]
        for k, v in helper.params_setup.items()
        if v[1].name == 'Forwarded'
    }
    fixed = [k for k, v in helper.params_setup.items() if v[1].name == 'Fixed']
    integrated = [
        k for k, v in helper.params_setup.items() if v[1].name == 'Integrated'
    ]

    if params is None:
        params = set(params_names['interest'])

    elif isinstance(params, str):
        # check if params exist
        if params not in all_params:
            raise ValueError(f'parameter {params} is not exist')

        params = {params}

    elif isinstance(params, Iterable):
        # check if params exist
        params = {str(i) for i in params}
        if not params.issubset(all_params):
            # report only the names that are really unknown; names that are
            # present in ``params_setup`` are handled by the checks below
            params_err = params - all_params
            raise ValueError(f'parameters: {params_err} are not exist')

    else:
        raise ValueError('params must be str, or sequence of str')

    # names are sorted below so that the error messages are deterministic
    if params_err := params.intersection(forwarded):
        linked = {i: forwarded[i] for i in sorted(params_err)}
        info = ', '.join(f'{k} to {v}' for k, v in linked.items())
        raise RuntimeError(
            f"parameters are linked: {info}; corresponding parameters' "
            'name should be used'
        )

    if params_err := params.intersection(fixed):
        info = ', '.join(sorted(params_err))
        raise RuntimeError(f'parameters are fixed: {info}')

    if params_err := params.intersection(integrated):
        info = ', '.join(sorted(params_err))
        raise RuntimeError(f'parameters are integrated-out: {info}')

    return sorted(params, key=params_names['all'].index)
# def get_reparam(dist: Distribution) -> tuple[Reparam, Callable] | None:
# """Get reparam for a distribution."""
# # TODO: support more reparameterizers
# if isinstance(dist, (Normal, StudentT, Cauchy)):
# param = '_centered'
# suffix = '_decentered'
# reparam = LocScaleReparam()
# elif isinstance(dist, TransformedDistribution):
# suffix = '_base'
# reparam = TransformReparam()
# transforms = dist.transforms
# def inv(x):
# """Inverse transformation."""
# for t in reversed(transforms):
# x = t.inv(x)
# return x
# elif isinstance(dist, ProjectedNormal):
# suffix = '_normal'
# reparam = ProjectedNormalReparam()
# elif isinstance(dist, VonMises):
# suffix = '_unwrapped'
# reparam = CircularReparam()
# else:
# return None
#
# return reparam, inv
def get_helper(fit: Fit) -> Helper:
    """Get helper functions for fitting.

    Reads the model information, data, compiled models, statistics and seed
    stored on `fit`, builds the numpyro model together with the conversion,
    likelihood, simulation and batch-fitting closures that operate on it,
    and bundles everything into a :class:`Helper`.

    Parameters
    ----------
    fit : Fit
        The fit object. Its ``_model_info``, ``_data``, ``_model``, ``_stat``
        and ``_seed`` attributes are read.

    Returns
    -------
    Helper
        The helper holding data information and the functions defined here.

    Raises
    ------
    RuntimeError
        If the free and deterministic parameter names do not add up to the
        full list of parameter names derived from the model information.
    """
    # JAX devices must be set before importing optimistix
    # so we import it here
    import optimistix as optx

    model_info: ModelInfo = fit._model_info
    data: dict[str, FixedData] = fit._data
    model: dict[str, CompiledModel] = fit._model
    stat: dict[str, Statistic] = fit._stat
    seed0 = fit._seed
    rng_seed: dict[str, int] = {
        'mcmc': seed0,  # for MCMC
        'pred': seed0 + 1,  # for data simulation
        'resd': seed0 + 2,  # for random quantile residuals
    }

    # channel number of data
    ndata = {k: v.channel.size for k, v in data.items()}
    ndata['total'] = sum(ndata.values())

    # channel information
    channels = {f'{k}_channel': v.channel for k, v in data.items()}
    channels['channel'] = np.hstack(list(channels.values()))

    # number of free parameters
    nparam = len(model_info.sample)

    # degree of freedom
    dof = ndata['total'] - nparam

    # ======================== count data calculator ==========================
    on_names = [f'{i}_Non' for i in data]
    off_names = [
        f'{i}_Noff' for i in data if stat[i] in _STATISTIC_WITH_BACK
    ]
    back_ratio = {
        k: v.back_ratio
        for k, v in data.items()
        if stat[k] in _STATISTIC_WITH_BACK
    }
    spec_unit = {
        k: 1.0 / (v.spec_exposure * v.channel_width) for k, v in data.items()
    }

    @jax.jit
    def get_counts_data(counts: dict[str, JAXArray]) -> dict[str, JAXArray]:
        """Get count data, including "on", "off" and net counts."""
        counts = {k: jnp.asarray(v, float) for k, v in counts.items()}
        assert set(counts) == set(on_names + off_names)

        # counts in the "on" measurement of each dataset
        counts_data = {i: counts[i] for i in on_names}

        # counts in the "off" measurement of each dataset
        counts_data |= {i: counts[i] for i in off_names}

        # datasets without background fall back to zero "off" contribution
        net_counts = {
            i: (
                counts[f'{i}_Non']
                - back_ratio.get(i, 0.0) * counts.get(f'{i}_Noff', 0.0)
            )
            for i in data
        }

        # net spectrum [counts s^-1 keV^-1]
        counts_data |= {i: net_counts[i] * spec_unit[i] for i in data.keys()}

        # stack net spectrum of all channels of all datasets
        counts_data['channels'] = jnp.concatenate(
            [counts_data[i] for i in data.keys()],
            axis=-1,
        )

        # total net counts of datasets
        counts_data['total'] = jnp.sum(
            jnp.asarray([i.sum(axis=-1) for i in net_counts.values()]), axis=0
        )

        return counts_data

    # ======================== count data calculator ==========================

    obs_counts = {
        f'{k}_Non': (
            v.net_counts
            if stat[k] in _STATISTIC_SPEC_NORMAL
            else v.spec_counts
        )
        for k, v in data.items()
    }
    obs_counts |= {
        f'{k}_Noff': v.back_counts
        for k, v in data.items()
        if stat[k] in _STATISTIC_WITH_BACK
    }
    obs_data = get_counts_data(obs_counts)

    # ======================== count data simulator ===========================
    def simulator_factory(data_dist: Literal['norm', 'poisson'], *dist_args):
        """Factory to create data simulator."""

        def simulator(
            rng: np.random.Generator,
            model_values: np.ndarray,
            n: int,
        ):
            """Simulate data given random number generator and model values."""
            # a leading axis of length n is added only when n != 1
            if n != 1:
                shape = (n,) + model_values.shape
            else:
                shape = model_values.shape

            if data_dist == 'norm':
                # TODO: fix the negative counts by setting them to zeros
                return rng.normal(model_values, *dist_args, shape)
            elif data_dist == 'poisson':
                return rng.poisson(model_values, shape)
            else:
                raise NotImplementedError(f'{data_dist = }')

        return simulator

    simulators = {}
    sampling_dist: dict[str, tuple[Literal['norm', 'poisson'], tuple]] = {}
    for k, s in stat.items():
        d = data[k]
        name = f'{k}_Non'
        if s in _STATISTIC_SPEC_NORMAL:
            simulators[name] = simulator_factory('norm', d.spec_errors)
            sampling_dist[name] = ('norm', (d.spec_errors,))
        else:
            simulators[name] = simulator_factory('poisson')
            sampling_dist[name] = ('poisson', ())

        if s in _STATISTIC_WITH_BACK:
            name = f'{k}_Noff'
            if s in _STATISTIC_BACK_NORMAL:
                simulators[name] = simulator_factory('norm', d.back_errors)
                sampling_dist[name] = ('norm', (d.back_errors,))
            else:
                simulators[name] = simulator_factory('poisson')
                sampling_dist[name] = ('poisson', ())

    def simulate(
        rng_seed: int,
        model_values: dict[str, JAXArray],
        n: int = 1,
    ) -> dict[str, JAXArray]:
        """Simulate data given model values.

        Use numpy.random instead of numpyro.infer.Predictive for performance.
        """
        models = {i: model_values[f'{i}_model'] for i in simulators.keys()}
        rng = np.random.default_rng(int(rng_seed))
        sim = {k: v(rng, models[k], n) for k, v in simulators.items()}
        return get_counts_data(sim)

    # ======================== count data simulator ===========================

    # ======================== create numpyro model ===========================
    pname_to_latex: dict[ParamName, str] = {
        pname: model_info.latex[pid] for pid, pname in model_info.name.items()
    }
    pname_to_log: dict[ParamName, bool] = {
        pname: model_info.log[pid] for pid, pname in model_info.name.items()
    }
    pname_to_unit: dict[ParamName, str] = {
        pname: get_unit_latex(model_info.unit[pid], throw=False)
        for pid, pname in model_info.name.items()
    }
    pname_to_comp_latex: dict[ParamName, str] = {
        pname: model_info.pid_to_comp_latex[pid]
        for pid, pname in model_info.name.items()
    }

    # get model parameters priors
    pid_to_pname: dict[ParamID, ParamName] = model_info.name
    pname_to_pid: dict[ParamName, ParamID] = {
        v: k for k, v in pid_to_pname.items()
    }
    pid_to_prior: dict[ParamID, Distribution] = model_info.sample
    params_prior: dict[ParamName, Distribution] = {
        pid_to_pname[pid]: pid_to_prior[pid] for pid in pid_to_prior
    }

    # get deterministic value getter function
    deterministic: dict[ParamID, Callable] = model_info.deterministic

    # get the likelihood function for each dataset
    likelihood_wrapper = {
        'chi2': chi2,
        'cstat': cstat,
        'pstat': pstat,
        'wstat': wstat,
        'pgstat': pgstat,
    }
    likelihood: dict[str, Callable[[JAXArray], None]] = {
        k: likelihood_wrapper[stat[k]](v, model[k].eval)
        for k, v in data.items()
    }

    # get default re-parameterization of each parameter
    # reparams: dict[ParamName, tuple[Reparam, Callable]] = {
    #     name: reparam_and_inv
    #     for name, prior in params_prior.items ()
    #     if (reparam_and_inv := get_reparam(prior)) is not None
    # }

    def numpyro_model(predictive: bool = False) -> None:
        """The numpyro model for spectral fitting."""
        # TODO:
        #  figure out how to handle reparam transformation, so we can
        #    * give initial parameter value in the original space and
        #      transformed to reparameterized space,
        #    * find the classic confidence interval in the reparameterized
        #      space and then transform back to the original space.
        #  This is not trivial because transformation is not always bijective!
        #  The trick to handle confidence interval of composite parameter may
        #  also be used here to solve the second problem.
        # with numpyro.handlers.reparam(config=reparams):
        #     params_name_values = {
        #         name: numpyro.sample(name, dist)
        #         for name, dist in params_prior.items()
        #     }

        # get parameter value from prior
        params_name_values = {
            name: numpyro.sample(name, dist)
            for name, dist in params_prior.items()
        }
        params_id_values = {
            pname_to_pid[name]: value
            for name, value in params_name_values.items()
        }

        # store composite parameters into chains
        for pid, fn in deterministic.items():
            numpyro.deterministic(pid_to_pname[pid], fn(params_id_values))

        # the likelihood between observation and model for each dataset
        jax.tree.map(
            lambda f: f(params_name_values, predictive=predictive),
            likelihood,
        )

    # ======================== create numpyro model ===========================

    # =================== functions used in optimization ======================
    params_names = [
        f'{i[1]}.{i[2]}' if i[1] else i[2] for i in model_info.info if i[0]
    ]
    interest_names = [
        f'{i[1]}.{i[2]}' for i in model_info.info if all(i[:2])
    ]
    free_names = sorted(params_prior.keys(), key=params_names.index)
    deterministic_names = [pid_to_pname[i] for i in deterministic]
    deterministic_names = sorted(deterministic_names, key=params_names.index)

    # ensure if names are consistent
    if set(free_names + deterministic_names) != set(params_names):
        raise RuntimeError(
            f'{params_names = }, {free_names = }, {deterministic_names = }'
        )

    # names of the "on" (and "off", if any) measurements of each dataset
    data_group = {
        k: (f'{k}_Non', f'{k}_Noff')
        if v in _STATISTIC_WITH_BACK
        else (f'{k}_Non',)
        for k, v in stat.items()
    }

    @jax.jit
    def arr_to_dic(arr: JAXArray) -> ParamNameValMapping:
        """Convert free parameters from an array to a dict."""
        assert len(arr) == len(free_names)
        return dict(zip(free_names, arr, strict=True))

    @jax.jit
    def dic_to_arr(dic: ParamNameValMapping) -> JAXArray:
        """Convert free parameters from a dict to an array."""
        return jnp.array([dic[i] for i in free_names], float)

    @jax.jit
    def unconstr_arr_to_constr_dic(arr: JAXArray) -> ParamNameValMapping:
        """Convert free parameters array from unconstrained space to dict in
        constrained space.
        """
        return constrain_fn(
            model=numpyro_model,
            model_args=(),
            model_kwargs={},
            params=arr_to_dic(arr),
        )

    @jax.jit
    def constr_arr_to_unconstr_dic(arr: JAXArray) -> ParamNameValMapping:
        """Convert free parameters array from constrained space to dict in
        unconstrained space.
        """
        return unconstrain_fn(
            model=numpyro_model,
            model_args=(),
            model_kwargs={},
            params=arr_to_dic(arr),
        )

    @jax.jit
    def constr_arr_to_unconstr_arr(arr: JAXArray) -> JAXArray:
        """Convert free parameters array from constrained space into
        unconstrained space.
        """
        unconstr_dic = constr_arr_to_unconstr_dic(arr)
        return jnp.asarray([unconstr_dic[i] for i in free_names])

    @jax.jit
    def constr_dic_to_unconstr_arr(dic: ParamNameValMapping) -> JAXArray:
        """Convert free parameters dict from constrained space to array in
        unconstrained space.
        """
        constr_arr = dic_to_arr(dic)
        return constr_arr_to_unconstr_arr(constr_arr)

    # get default value of each parameter
    default_constr_dic = {
        pid_to_pname[k]: v for k, v in model_info.default.items()
    }
    default_constr_dic = {k: default_constr_dic[k] for k in free_names}
    default_constr_arr = jnp.array(
        [default_constr_dic[i] for i in free_names]
    )
    default_unconstr_arr = constr_arr_to_unconstr_arr(default_constr_arr)
    free_default: dict[str, dict[ParamName, JAXFloat] | JAXArray] = {
        'constr_dic': default_constr_dic,
        'constr_arr': default_constr_arr,
        'unconstr_arr': default_unconstr_arr,
    }

    def get_sites(
        unconstr_arr: JAXArray,
    ) -> dict[
        Literal['params', 'models', 'loglike'], dict[str, JAXFloat | JAXArray]
    ]:
        """Get parameters in constrained space, models values and log
        likelihood, given free parameters array in unconstrained space.
        """
        sites = constrain_fn(
            model=numpyro_model,
            model_args=(),
            model_kwargs={},
            params=arr_to_dic(unconstr_arr),
            return_deterministic=True,
        )
        params = get_params(sites)
        models = get_models(sites)
        loglike = get_loglike(sites)
        return {'params': params, 'models': models, 'loglike': loglike}

    def get_params(sites: Mapping) -> dict:
        """Get parameters' values in constrained space given numpyro model
        sites.
        """
        params = {i: sites[i] for i in params_names}
        return params

    def get_models(sites: Mapping) -> dict:
        """Get model values given numpyro model sites."""
        models = {k: sites[k] for k in data_group.keys()}
        models |= {
            f'{i}_model': sites[f'{i}_model']
            for v in data_group.values()
            for i in v
        }
        return models

    def get_loglike(sites: Mapping) -> dict:
        """Get log likelihood given numpyro model sites."""
        # log likelihood of each "on"/"off" measurement
        loglike_data = {
            i: sites[f'{i}_loglike'] for v in data_group.values() for i in v
        }

        # log likelihood of each dataset, summed over its last axis for
        # 'group' and concatenated over all datasets for 'channels'
        loglike_point = {i: sites[f'{i}_loglike'] for i in data_group.keys()}
        loglike_group = {k: v.sum(axis=-1) for k, v in loglike_point.items()}
        loglike_channels = jnp.concatenate(
            [loglike_point[i] for i in data_group.keys()], axis=-1
        )
        loglike_total = loglike_channels.sum(axis=-1)

        loglike = {
            'data': loglike_data,
            'point': loglike_point,
            'group': loglike_group,
            'channels': loglike_channels,
            'total': loglike_total,
        }
        return loglike

    @jax.jit
    def unconstr_dic_to_params_dic(
        dic: ParamNameValMapping,
    ) -> ParamNameValMapping:
        """Get parameters dict in constrained space,
        given a free parameters dict in unconstrained space.
        """
        return jax.jit(get_sites)(dic_to_arr(dic))['params']

    @jax.jit
    def unconstr_arr_to_params_array(arr: JAXArray) -> JAXArray:
        """Get parameters array in constrained space,
        given a free parameters array in unconstrained space.
        """
        unconstr_dic = arr_to_dic(arr)
        params_dic = unconstr_dic_to_params_dic(unconstr_dic)
        return jnp.array([params_dic[i] for i in params_names])

    @jax.jit
    def unconstr_covar(unconstr_arr: JAXArray) -> JAXArray:
        """Calculate covariance matrix of free parameters in unconstrained
        space, given a free parameters array in unconstrained space.
        """
        hess = jax.jit(jax.hessian(deviance_total))(unconstr_arr)
        covar = jnp.linalg.inv(hess)
        # the Hessian is taken of the deviance (-2 * log-likelihood),
        # hence the factor of two
        return 2.0 * covar

    @jax.jit
    def params_covar(
        unconstr_arr: JAXArray,
        unconstr_cov: JAXArray,
    ) -> jnp.ndarray:
        """Calculate covariance matrix of all parameters in constrained space,
        given values and covariance matrix of free parameters in unconstrained
        space.
        """
        jac = jax.jit(jax.jacobian(unconstr_arr_to_params_array))(unconstr_arr)
        return jac @ unconstr_cov @ jac.T

    @jax.jit
    def get_mle(unconstr_arr: JAXArray) -> tuple[JAXArray, JAXArray]:
        """Get the value and covariance matrix of all parameters in constrained
        space, given MLE of free parameters in unconstrained space.
        """
        params_arr = unconstr_arr_to_params_array(unconstr_arr)
        params_cov = params_covar(unconstr_arr, unconstr_covar(unconstr_arr))
        return params_arr, params_cov

    # NOTE:
    # the following functions will be used in simulation procedure,
    # so we do not jit it here, or data substitution will fail
    def loglike(unconstr_arr: JAXArray) -> dict[str, JAXArray]:
        """Calculate log-likelihood given free parameters array in
        unconstrained space.
        """
        return get_sites(unconstr_arr)['loglike']

    def deviance(unconstr_arr: JAXArray) -> dict:
        """Calculate total/group/point deviance given free parameters array in
        unconstrained space.
        """
        loglike_dic = loglike(unconstr_arr)

        # a plain function is used here: wrapping a new lambda in jax.jit on
        # every call would never hit the jit cache and force re-tracing
        def neg_double(x):
            return -2.0 * x

        point = jax.tree.map(neg_double, loglike_dic['point'])
        group = jax.tree.map(neg_double, loglike_dic['group'])
        total = jax.tree.map(neg_double, loglike_dic['total'])
        return {'total': total, 'group': group, 'point': point}

    def deviance_total(unconstr_arr: JAXArray) -> JAXFloat:
        """Calculate total deviance given free parameters array in
        unconstrained space.
        """
        return deviance(unconstr_arr)['total']

    def residual(unconstr_arr: JAXArray) -> JAXArray:
        """Calculate deviance residual (i.e. sqrt deviance) given free
        parameters array in unconstrained space.
        """
        loglike_dic = loglike(unconstr_arr)
        loglike_arr = jnp.hstack(list(loglike_dic['point'].values()))
        return jnp.sqrt(-2.0 * loglike_arr)

    # =================== functions used in optimization ======================

    # =============== functions used in simulation procedure ==================
    lm_solver = optx.LevenbergMarquardt(rtol=0.0, atol=1e-6)

    @jax.jit
    def fit_once(i: int, args: tuple) -> tuple:
        """Loop core, fit simulation data once."""
        result, init = args
        sim_data = result['data']

        # substitute observation data with simulation data
        new_data = {
            f'{j}_data': sim_data[j][i] for v in data_group.values() for j in v
        }
        new_residual = jax.jit(handlers.substitute(fn=residual, data=new_data))
        new_deviance = jax.jit(handlers.substitute(fn=deviance, data=new_data))
        new_sites = jax.jit(handlers.substitute(fn=get_sites, data=new_data))

        # fit simulation data
        res = optx.least_squares(
            fn=lambda p, _: new_residual(p),
            solver=lm_solver,
            y0=init[i],
            max_steps=1024,
            throw=False,
        )
        fitted_params = res.value
        grad_norm = jnp.linalg.norm(res.state.f_info.compute_grad())
        sites = new_sites(fitted_params)

        # update best fit params to result
        params = sites['params']
        result['params'] = jax.tree.map(
            lambda x, y: x.at[i].set(y),
            result['params'],
            params,
        )

        # update the best fit model to result
        models = sites['models']
        result['models'] = jax.tree.map(
            lambda x, y: x.at[i].set(y),
            result['models'],
            {k: models[k] for k in result['models']},
        )

        # update the deviance information to result
        dev = new_deviance(fitted_params)
        res_dev = result['deviance']
        res_dev['group'] = jax.tree.map(
            lambda x, y: x.at[i].set(y),
            res_dev['group'],
            dev['group'],
        )
        res_dev['point'] = jax.tree.map(
            lambda x, y: x.at[i].set(y),
            res_dev['point'],
            dev['point'],
        )
        res_dev['total'] = res_dev['total'].at[i].set(dev['total'])

        # a fit is flagged invalid if the deviance or the gradient norm is
        # NaN, or if the gradient norm exceeds 1e-3
        valid = jnp.bitwise_not(
            jnp.isnan(dev['total'])
            | jnp.isnan(grad_norm)
            | jnp.greater(grad_norm, 1e-3)
        )
        result['valid'] = result['valid'].at[i].set(valid)

        return result, init

    def fit_in_sequence(
        result: dict,
        init: JAXArray,
        run_str: str,
        progress: bool,
        update_rate: int,
    ) -> dict:
        """Fit simulation data in sequence."""
        n = len(result['valid'])

        if progress:
            pbar_factory = progress_bar_factory(
                n, 1, run_str=run_str, update_rate=update_rate
            )
            fn = pbar_factory(fit_once)
        else:
            fn = fit_once

        fit_jit = jax.jit(lambda *args: lax.fori_loop(0, n, fn, args)[0])
        result = fit_jit(result, init)
        return result

    def fit_in_parallel(
        result: dict,
        init: JAXArray,
        run_str: str,
        progress: bool,
        update_rate: int,
        n_parallel: int,
    ) -> dict:
        """Fit simulation data in parallel.

        Raises
        ------
        ValueError
            If the number of simulations is not a multiple of `n_parallel`.
        """
        n = len(result['valid'])
        n_parallel = int(n_parallel)

        # the leading axis is split into (n_parallel, batch) below, which is
        # only possible when n is a multiple of n_parallel; fail early with
        # a clear message instead of an opaque reshape error
        batch, remainder = divmod(n, n_parallel)
        if remainder:
            raise ValueError(
                f'number of simulations ({n}) must be a multiple of the '
                f'number of parallel processes ({n_parallel})'
            )

        if progress:
            pbar_factory = progress_bar_factory(
                n, n_parallel, run_str=run_str, update_rate=update_rate
            )
            fn = pbar_factory(fit_once)
        else:
            fn = fit_once

        fit_pmap = jax.pmap(lambda *args: lax.fori_loop(0, batch, fn, args)[0])
        reshape = lambda x: x.reshape((n_parallel, -1) + x.shape[1:])
        result = fit_pmap(
            jax.tree.map(reshape, result),
            jax.tree.map(reshape, init),
        )

        # merge the per-process batches back along the leading axis
        return jax.tree.map(jnp.concatenate, result)

    def batch_fit(
        init_params: dict[str, JAXArray],
        data: dict[str, JAXArray],
        parallel: bool = True,
        n_parallel: int | None = None,
        progress: bool = True,
        update_rate: int = 50,
        run_str: str = 'Fitting',
    ) -> dict:
        """Fit the given simulation data in batch.

        Parameters
        ----------
        init_params : dict
            The initial parameters values in constrained space.
        data : dict
            The simulation data corresponding to `init_params`.
        parallel : bool, optional
            Whether to fit in parallel, by default True.
        n_parallel : int, optional
            The number of parallel processes when `parallel` is ``True``.
            Defaults to ``jax.local_device_count()``.
        progress : bool, optional
            Whether to show progress bar, by default True.
        update_rate : int, optional
            The update rate of the progress bar, by default 50.
        run_str : str, optional
            The string to ahead progress bar during the run when `progress` is
            True. The default is 'Fitting'.

        Returns
        -------
        result : dict
            The simulation and fitting result.
        """
        n_parallel = get_parallel_number(n_parallel)

        init_params = jax.tree.map(jnp.array, init_params)
        assert set(init_params) == set(free_names)

        # check if all params shapes are the same
        param_shapes = [np.shape(v)[:-1] for v in init_params.values()]
        assert all(i == param_shapes[0] for i in param_shapes)

        # the data shape is (nsim, nchan)
        data_shapes = [np.shape(data[k])[:-1] for k in ndata if k != 'total']
        assert all(i == data_shapes[0] for i in data_shapes)
        assert all(len(i) == 1 for i in data_shapes)
        nsim = data_shapes[0][0]

        # get initial parameters arrays in unconstrained space,
        init = jnp.array([init_params[k] for k in free_names]).T
        assert init.ndim <= 2
        if init.ndim == 2:
            assert init.shape[0] == nsim
        if init.ndim == 1:
            # a single set of initial values is shared by all simulations
            init = jnp.full((nsim, len(init)), init)
        init = jax.vmap(constr_arr_to_unconstr_arr)(init)

        # fit result container
        result = {
            'data': data,
            'params': {k: jnp.empty(nsim) for k in params_names},
            'models': {
                i: jnp.empty((nsim, ndata[k]))
                for k, v in data_group.items()
                for i in [k, *map('{}_model'.format, v)]
            },
            'deviance': {
                'total': jnp.empty(nsim),
                'group': {k: jnp.empty(nsim) for k in data_group},
                'point': {k: jnp.empty((nsim, ndata[k])) for k in data_group},
            },
            'valid': jnp.full(nsim, True, bool),
        }

        # fit simulation data
        if parallel:
            result = fit_in_parallel(
                result, init, run_str, progress, update_rate, n_parallel
            )
        else:
            result = fit_in_sequence(
                result, init, run_str, progress, update_rate
            )

        return result

    def simulate_and_fit(
        seed: int,
        free_params: dict[ParamName, JAXArray],
        model_values: dict[str, JAXArray],
        n: int = 1,
        parallel: bool = True,
        n_parallel: int | None = None,
        progress: bool = True,
        update_rate: int = 50,
        run_str: str = 'Fitting',
    ) -> dict:
        """Simulate data and then fit the simulation data.

        Parameters
        ----------
        seed : int
            The random number generator seed used for data simulation.
        free_params : dict
            The free parameters values in constrained space, used as the
            initial values of the fits.
        model_values : dict
            The model values corresponding to `free_params`.
        n : int, optional
            The number of simulations of each model value, by default 1.
        parallel : bool, optional
            Whether to fit in parallel, by default True.
        n_parallel : int, optional
            The number of parallel processes when `parallel` is ``True``.
            Defaults to ``jax.local_device_count()``.
        progress : bool, optional
            Whether to show progress bar, by default True.
        update_rate : int, optional
            The update rate of the progress bar, by default 50.
        run_str : str, optional
            The string to ahead progress bar during the run when `progress` is
            True. The default is 'Fitting'.

        Returns
        -------
        result : dict
            The simulation and fitting result.
        """
        model_values = {
            f'{k}_model': model_values[f'{k}_model'] for k in simulators
        }

        # TODO: support posterior prediction with n > 1
        # check if all model shapes are the same
        shapes = [
            np.shape(model_values[f'{k}_Non_model'])[:-1]
            for k in ndata
            if k != 'total'
        ]
        assert all(i == shapes[0] for i in shapes)
        # batched model values cannot be combined with n > 1
        assert not (shapes[0] != () and n > 1)

        # simulate data
        sim_data = simulate(int(seed), model_values, int(n))

        # fit simulation data
        return batch_fit(
            free_params,
            sim_data,
            parallel,
            n_parallel,
            progress,
            update_rate,
            run_str,
        )

    # =============== functions used in simulation procedure ==================

    return Helper(
        ndata=ndata,
        nparam=nparam,
        dof=dof,
        data_names=list(data.keys()),
        statistic=stat,
        channels=channels,
        obs_data=obs_data,
        data=dict(data),
        model=dict(model),
        seed=rng_seed,
        sampling_dist=sampling_dist,
        numpyro_model=numpyro_model,
        params_names={
            'free': free_names,
            'deterministic': deterministic_names,
            'interest': interest_names,
            'all': params_names,
        },
        params_default=unconstr_dic_to_params_dic(
            dict(zip(free_names, default_unconstr_arr, strict=True))
        ),
        params_setup=model_info.setup,
        params_latex=pname_to_latex,
        params_unit=pname_to_unit,
        params_log=pname_to_log,
        params_comp_latex=pname_to_comp_latex,
        free_default=free_default,
        get_sites=get_sites,
        get_params=get_params,
        get_models=get_models,
        get_loglike=get_loglike,
        get_mle=get_mle,
        params_covar=params_covar,
        deviance_total=deviance_total,
        deviance=deviance,
        residual=residual,
        constr_arr_to_unconstr_arr=constr_arr_to_unconstr_arr,
        constr_dic_to_unconstr_arr=constr_dic_to_unconstr_arr,
        unconstr_dic_to_params_dic=unconstr_dic_to_params_dic,
        simulate=simulate,
        simulate_and_fit=simulate_and_fit,
        batch_fit=batch_fit,
    )
class Helper(NamedTuple):
    """Helper for fitting and analysis.

    An immutable bundle of data information, parameter metadata and the
    functions created by :func:`get_helper`.
    """

    ndata: dict[str, int]
    """The number of channels in each dataset and the total number of channels.
    """

    nparam: int
    """The number of free parameters in the model."""

    dof: int
    """The degree of freedom."""

    data_names: list[str]
    """Name of each data."""

    statistic: dict[str, Statistic]
    """The statistic used in each dataset."""

    channels: dict[str, np.ndarray]
    """Channel information of the datasets."""

    obs_data: dict[str, JAXArray]
    """The datasets of observations, including net counts, counts in the "on"
    and "off" measurements.
    """

    data: dict[str, FixedData]
    """FixedData instances."""

    model: dict[str, CompiledModel]
    """Compiled spectral models."""

    seed: dict[str, int]
    """Random number generator seed."""

    sampling_dist: dict[str, tuple[Literal['norm', 'poisson'], tuple]]
    """Sampling distribution of observation data, this is used for probability
    integral transform calculation.
    """

    numpyro_model: Callable[[bool], None]
    """The numpyro model for spectral fitting."""

    params_names: dict
    """The names of parameters in the model."""

    params_default: dict[str, JAXFloat]
    """The default values of parameters."""

    free_default: dict[str, dict[ParamName, JAXFloat] | JAXArray]
    """The default values of free parameters."""

    params_setup: dict[ParamName, tuple[ParamName, ParamSetup]]
    """The mapping from parameter name to a ``(name, setup)`` tuple. For a
    forwarded parameter, the first element is the name of the parameter it
    is linked to.
    """

    params_latex: dict[ParamName, str]
    """The LaTeX representation of parameters."""

    params_unit: dict[ParamName, str]
    """The unit of parameters."""

    params_log: dict[ParamName, bool]
    """Whether the parameters are in log space."""

    params_comp_latex: dict[ParamName, str]
    """The LaTeX representation of parameter's component."""

    get_sites: Callable[
        [JAXArray],
        dict[
            Literal['params', 'models', 'loglike'],
            dict[str, JAXFloat | JAXArray],
        ],
    ]
    """Get parameters in constrained space, models values and log likelihood,
    given free parameters array in unconstrained space.
    """

    get_params: Callable[[Mapping], dict]
    """Get parameters' values in constrained space given numpyro model sites.
    """

    get_models: Callable[[Mapping], dict]
    """Get model values given numpyro model sites."""

    get_loglike: Callable[[Mapping], dict]
    """Get log likelihood given numpyro model sites."""

    get_mle: Callable[[JAXArray], tuple[JAXArray, JAXArray]]
    """Get the value and covariance matrix of all parameters in constrained
    space, given MLE of free parameters in unconstrained space.
    """

    params_covar: Callable[[JAXArray, JAXArray], JAXArray]
    """Calculate covariance matrix of all parameters in constrained space,
    given values and covariance matrix of free parameters in unconstrained
    space.
    """

    deviance_total: Callable[[JAXArray], JAXFloat]
    """Calculate total deviance given free parameters array in unconstrained
    space.
    """

    deviance: Callable[[JAXArray], dict[str, JAXArray]]
    """Calculate total, group and point deviance given free parameters array in
    unconstrained space.
    """

    residual: Callable[[JAXArray], JAXArray]
    """Calculate deviance residual (i.e., sqrt deviance) given free parameters
    array in unconstrained space.
    """

    constr_arr_to_unconstr_arr: Callable[[JAXArray], JAXArray]
    """Convert free parameters array from constrained space into
    unconstrained space.
    """

    constr_dic_to_unconstr_arr: Callable[[ParamNameValMapping], JAXArray]
    """Convert free parameters dict from constrained space to array in
    unconstrained space.
    """

    unconstr_dic_to_params_dic: Callable[
        [ParamNameValMapping], ParamNameValMapping
    ]
    """Get parameters dict in constrained space, given a free parameters dict
    in unconstrained space.
    """

    simulate: Callable[[int, dict[str, JAXArray], int], dict[str, JAXArray]]
    """Function to simulate data."""

    simulate_and_fit: Callable[
        [int, dict, dict, int, bool, int, bool, int, str], dict
    ]
    """Function to simulate data and then fit the simulation data."""

    batch_fit: Callable[
        [dict[str, JAXArray], dict[str, JAXArray], bool, int, bool, int, str],
        dict,
    ]
    """Function to fit simulation data."""