elisa.infer.helper

Helper for fitting and analysis.

check_params(params: str | Sequence[str] | None, helper: Helper) -> list[str]

get_helper(fit: Any) -> Helper

Get helper functions for fitting.

class Helper(ndata: dict[str, int], nparam: int, dof: int, data_names: list[str], statistic: dict[str, Statistic], channels: dict[str, np.ndarray], obs_data: dict[str, JAXArray], data: dict[str, FixedData], model: dict[str, CompiledModel], seed: dict[str, int], sampling_dist: dict[str, tuple[Literal['norm', 'poisson'], tuple]], numpyro_model: Callable[[bool], None], params_names: dict, params_default: dict[str, JAXFloat], free_default: dict[str, dict[ParamName, JAXFloat] | JAXArray], params_setup: dict[ParamName, tuple[ParamName, ParamSetup]], params_latex: dict[ParamName, str], params_unit: dict[ParamName, str], params_log: dict[ParamName, bool], params_comp_latex: dict[ParamName, str], get_sites: Callable[[JAXArray], dict[Literal['params', 'models', 'loglike'], dict[str, JAXFloat | JAXArray]]], get_params: Callable[[Mapping], dict], get_models: Callable[[Mapping], dict], get_loglike: Callable[[Mapping], dict], get_mle: Callable[[JAXArray], tuple[JAXArray, JAXArray]], params_covar: Callable[[JAXArray, JAXArray], JAXArray], deviance_total: Callable[[JAXArray], JAXFloat], deviance: Callable[[JAXArray], dict[str, JAXArray]], residual: Callable[[JAXArray], JAXArray], constr_arr_to_unconstr_arr: Callable[[JAXArray], JAXArray], constr_dic_to_unconstr_arr: Callable[[ParamNameValMapping], JAXArray], unconstr_dic_to_params_dic: Callable[[ParamNameValMapping], ParamNameValMapping], simulate: Callable[[int, dict[str, JAXArray], int], dict[str, JAXArray]], simulate_and_fit: Callable[[int, dict, dict, int, bool, int, bool, int, str], dict], batch_fit: Callable[[dict[str, JAXArray], dict[str, JAXArray], bool, int, bool, int, str], dict])

Bases: NamedTuple

Helper for fitting and analysis.

Methods

count(value, /)

Return number of occurrences of value.

index(value[, start, stop])

Return first index of value.
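Because Helper is a NamedTuple, its fields can be read by attribute or by position, and the inherited tuple methods such as count and index behave as they do for any tuple. The sketch below uses a hypothetical, much smaller ToyHelper with three of the documented field names; it is not the real class.

```python
from typing import NamedTuple


class ToyHelper(NamedTuple):
    """Hypothetical stand-in carrying a few of Helper's fields."""

    nparam: int
    dof: int
    data_names: list[str]


h = ToyHelper(nparam=3, dof=97, data_names=["NaI", "BGO"])

# Attribute access and positional access refer to the same field.
assert h.dof == h[1]

# Inherited tuple methods: count occurrences and locate a value.
print(h.count(3))   # 1 (only nparam equals 3)
print(h.index(97))  # 1 (dof is the second field)

# NamedTuples are immutable; _replace returns a modified copy.
h2 = h._replace(dof=96)
print(h.dof, h2.dof)  # 97 96
```

Immutability means a helper can be shared safely between fitting and analysis code; any "modification" produces a new object.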

ndata: dict[str, int]

The number of channels in each dataset and the total number of channels.

nparam: int

The number of free parameters in the model.

dof: int

The degrees of freedom.
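The usual definition of the degrees of freedom is the total number of channels minus the number of free parameters. A minimal sketch, assuming ndata stores the per-dataset channel counts plus the total under a key named "total" (that key name is a hypothetical choice for illustration):

```python
def degrees_of_freedom(ndata: dict[str, int], nparam: int) -> int:
    """Total number of channels minus number of free parameters.

    Assumes the total is stored under the hypothetical key "total".
    """
    per_dataset = sum(v for k, v in ndata.items() if k != "total")
    if per_dataset != ndata["total"]:
        raise ValueError("per-dataset channel counts do not add up to the total")
    return ndata["total"] - nparam


ndata = {"NaI": 120, "BGO": 80, "total": 200}
print(degrees_of_freedom(ndata, nparam=4))  # 196
```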

data_names: list[str]

The name of each dataset.

statistic: dict[str, Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat']]

The statistic used in each dataset.

channels: dict[str, ndarray]

Channel information of the datasets.

obs_data: dict[str, Array]

The observed data of each dataset, including the net counts and the counts in the “on” and “off” measurements.

data: dict[str, FixedData]

FixedData instances.

model: dict[str, CompiledModel]

Compiled spectral models.

seed: dict[str, int]

Random number generator seeds.

sampling_dist: dict[str, tuple[Literal['norm', 'poisson'], tuple]]

The sampling distribution of the observed data, used for the probability integral transform (PIT) calculation.
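The probability integral transform maps each observed value through the CDF of its sampling distribution; if the model is adequate, the PIT values are approximately uniform on [0, 1]. For a discrete distribution such as the Poisson, a randomized PIT (a uniform draw between the CDF just below and at the observed value) is a common choice. The sketch below handles the two families named in sampling_dist, but the argument layouts ("norm", (mean, std)) and ("poisson", (mean,)) are assumptions, not the documented tuple contents.

```python
import math
import random


def normal_cdf(x: float, mean: float, std: float) -> float:
    return 0.5 * (1.0 + math.erf((x - mean) / (std * math.sqrt(2.0))))


def poisson_cdf(k: int, mean: float) -> float:
    """P(X <= k) for X ~ Poisson(mean); 0 for k < 0. Fine for modest means."""
    if k < 0:
        return 0.0
    term = math.exp(-mean)  # P(X = 0)
    total = term
    for i in range(1, k + 1):
        term *= mean / i    # P(X = i) from P(X = i - 1)
        total += term
    return total


def pit(value, dist, rng: random.Random) -> float:
    """Randomized PIT; the tuple layouts of ``dist`` are assumed."""
    name, args = dist
    if name == "norm":
        mean, std = args
        return normal_cdf(value, mean, std)
    if name == "poisson":
        (mean,) = args
        lower = poisson_cdf(int(value) - 1, mean)  # P(X < value)
        upper = poisson_cdf(int(value), mean)      # P(X <= value)
        # Spread the probability mass of the observed count uniformly.
        return lower + rng.random() * (upper - lower)
    raise ValueError(f"unknown distribution {name!r}")


rng = random.Random(42)
print(pit(0.0, ("norm", (0.0, 1.0)), rng))  # 0.5
print(pit(3, ("poisson", (3.0,)), rng))     # somewhere in [P(X<3), P(X<=3))
```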

numpyro_model: Callable[[bool], None]

The numpyro model for spectral fitting.

params_names: dict

The names of parameters in the model.

params_default: dict[str, Array]

The default values of parameters.

free_default: dict[str, dict[str, Array] | Array]

The default values of free parameters.

params_setup: dict[str, tuple[str, ParamSetup]]

The mapping from forwarded parameter names to the corresponding parameter names and their ParamSetup.

params_latex: dict[str, str]

The LaTeX representation of parameters.

params_unit: dict[str, str]

The units of the parameters.

params_log: dict[str, bool]

Whether the parameters are in log space.

params_comp_latex: dict[str, str]

The LaTeX representation of each parameter’s component.

get_sites: Callable[[Array], dict[Literal['params', 'models', 'loglike'], dict[str, Array]]]

Get the parameters in constrained space, the model values, and the log likelihood, given a free-parameter array in unconstrained space.
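The shape of the returned dict can be illustrated with a toy model. The sketch below is hypothetical: a one-parameter constant-rate Poisson model whose single free parameter is log(rate) in unconstrained space. The observed counts are closed over, so the function takes only the free-parameter array, mirroring the documented signature; only the three top-level keys follow the documentation.

```python
import math


def make_get_sites(counts: dict[str, list[int]]):
    """Build a toy get_sites for a constant-rate Poisson model (hypothetical)."""

    def get_sites(free_unconstr: list[float]) -> dict[str, dict]:
        rate = math.exp(free_unconstr[0])  # log(rate) -> rate > 0
        models = {name: [rate] * len(n) for name, n in counts.items()}
        # Poisson log likelihood: sum of k*log(mu) - mu - log(k!).
        loglike = {
            name: sum(k * math.log(mu) - mu - math.lgamma(k + 1)
                      for k, mu in zip(n, models[name]))
            for name, n in counts.items()
        }
        return {"params": {"rate": rate}, "models": models, "loglike": loglike}

    return get_sites


get_sites = make_get_sites({"data1": [1, 3, 2]})
sites = get_sites([math.log(2.0)])
print(sorted(sites))  # ['loglike', 'models', 'params']
```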

get_params: Callable[[Mapping], dict]

Get the parameter values in constrained space, given numpyro model sites.

get_models: Callable[[Mapping], dict]

Get model values given numpyro model sites.

get_loglike: Callable[[Mapping], dict]

Get the log likelihood given numpyro model sites.

get_mle: Callable[[Array], tuple[Array, Array]]

Get the MLE and its error for all parameters in constrained space, given the MLE of the free parameters in unconstrained space.
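One common way to attach an error to an MLE is to invert the curvature (Hessian) of the negative log-likelihood in unconstrained space and then propagate it to constrained space with the delta method. The sketch below does this for a single Poisson rate fitted as theta = log(rate), using a central finite-difference second derivative; it illustrates the idea and is not ELISA's implementation. For this model the answer is known analytically: rate = sum/N with error sqrt(sum)/N.

```python
import math


def nll(theta: float, counts: list[int]) -> float:
    # Poisson negative log-likelihood in terms of theta = log(rate), constants dropped.
    rate = math.exp(theta)
    return sum(rate - k * theta for k in counts)


def mle_and_error(counts: list[int], h: float = 1e-4) -> tuple[float, float]:
    n, s = len(counts), sum(counts)
    theta_hat = math.log(s / n)  # closed-form MLE of log(rate)
    # Central finite-difference estimate of d2(NLL)/d(theta)2 at the MLE.
    curv = (nll(theta_hat + h, counts) - 2.0 * nll(theta_hat, counts)
            + nll(theta_hat - h, counts)) / h**2
    var_theta = 1.0 / curv
    rate_hat = math.exp(theta_hat)
    # Delta method: d(rate)/d(theta) = rate, so sd(rate) = rate * sd(theta).
    err_rate = rate_hat * math.sqrt(var_theta)
    return rate_hat, err_rate


rate, err = mle_and_error([4, 6, 5, 5])
print(rate, err)  # about 5.0 and sqrt(20)/4, i.e. about 1.118
```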

params_covar: Callable[[Array, Array], Array]

Calculate the covariance matrix of all parameters in constrained space, given the values and covariance matrix of the free parameters in unconstrained space.
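To first order, a covariance matrix C in unconstrained space maps to J C Jᵀ in constrained space, where J is the Jacobian of the unconstrained-to-constrained transformation evaluated at the given values. The NumPy sketch below computes J by central differences for a hypothetical mapping with two log-scale free parameters a and b plus a derived parameter a*b; the mapping and the numbers are illustrative only.

```python
import numpy as np


def to_constrained(theta: np.ndarray) -> np.ndarray:
    """Hypothetical mapping: two log-scale free params plus one derived param."""
    a, b = np.exp(theta)
    return np.array([a, b, a * b])  # the third entry depends on both free params


def jacobian(f, x: np.ndarray, h: float = 1e-6) -> np.ndarray:
    # Central-difference Jacobian; entry (i, j) approximates d f_i / d x_j.
    cols = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        cols.append((f(x + step) - f(x - step)) / (2.0 * h))
    return np.stack(cols, axis=1)


def params_covar_sketch(theta: np.ndarray, cov_theta: np.ndarray) -> np.ndarray:
    J = jacobian(to_constrained, theta)
    return J @ cov_theta @ J.T  # first-order (delta-method) propagation


theta = np.log(np.array([2.0, 3.0]))
cov_theta = np.array([[0.01, 0.002], [0.002, 0.04]])
cov = params_covar_sketch(theta, cov_theta)
print(np.sqrt(np.diag(cov)))  # 1-sigma errors of a, b and a*b
```

Note that the derived parameter gets a full row and column in the result, so its correlation with the free parameters is preserved.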

deviance_total: Callable[[Array], Array]

Calculate the total deviance given a free-parameter array in unconstrained space.

deviance: Callable[[Array], dict[str, Array]]

Calculate the total, group, and point deviance given a free-parameter array in unconstrained space.
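For Poisson data (the C-statistic case), the point deviance of channel i is 2(μᵢ − nᵢ + nᵢ ln(nᵢ/μᵢ)), with the logarithmic term taken as zero when nᵢ = 0. Summing the point deviance over a dataset gives its group deviance, and summing the groups gives the total. The sketch below assumes that statistic and takes already-evaluated model values; the output keys "total", "group", and "point" mirror the description above but are assumptions about the actual return layout.

```python
import math


def poisson_point_deviance(n: int, mu: float) -> float:
    # 2 * [mu - n + n*log(n/mu)]; the n*log term vanishes when n == 0.
    term = n * math.log(n / mu) if n > 0 else 0.0
    return 2.0 * (mu - n + term)


def deviance_sketch(counts: dict[str, list[int]],
                    models: dict[str, list[float]]) -> dict:
    point = {
        name: [poisson_point_deviance(n, mu)
               for n, mu in zip(counts[name], models[name])]
        for name in counts
    }
    group = {name: sum(d) for name, d in point.items()}  # per dataset
    total = sum(group.values())                          # all datasets
    return {"total": total, "group": group, "point": point}


dev = deviance_sketch({"d1": [0, 2], "d2": [3]}, {"d1": [1.0, 2.0], "d2": [3.0]})
print(dev["point"]["d1"])  # [2.0, 0.0]
print(dev["total"])        # 2.0
```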

residual: Callable[[Array], Array]

Calculate the deviance residual (i.e., the square root of the deviance) given a free-parameter array in unconstrained space.
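In the common textbook definition, the deviance residual attaches the sign of (nᵢ − μᵢ) to the square root of the point deviance, so that a plot of residuals shows the direction of the misfit. The Poisson sketch below follows that convention; whether the sign is applied here is an assumption.

```python
import math


def poisson_deviance_residual(n: int, mu: float) -> float:
    term = n * math.log(n / mu) if n > 0 else 0.0
    point_dev = 2.0 * (mu - n + term)
    # Clip tiny negative values from rounding, then attach the sign of n - mu.
    return math.copysign(math.sqrt(max(point_dev, 0.0)), n - mu)


print(poisson_deviance_residual(0, 2.0))  # -2.0
print(poisson_deviance_residual(9, 4.0))  # positive: more counts than predicted
```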

constr_arr_to_unconstr_arr: Callable[[Array], Array]

Convert a free-parameter array from constrained space to unconstrained space.

constr_dic_to_unconstr_arr: Callable[[dict[str, Array]], Array]

Convert a free-parameter dict from constrained space to an array in unconstrained space.

unconstr_dic_to_params_dic: Callable[[dict[str, Array]], dict[str, Array]]

Get the parameter dict in constrained space, given a free-parameter dict in unconstrained space.
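numpyro's default bijections (looked up via its biject_to registry) map a positive parameter to the real line with a log and an interval-bounded parameter with a scaled logit. The sketch below shows such a round trip for two hypothetical free parameters, "norm" (positive) and "index" (bounded in (-3, 5)). The two converter functions borrow the documented attribute names for readability but are simplified stand-ins: the real unconstr_dic_to_params_dic may also return fixed and derived parameters.

```python
import math

# Hypothetical supports of two free parameters; dict order defines the array layout.
BOUNDS = {"norm": (0.0, math.inf), "index": (-3.0, 5.0)}


def to_unconstrained(name: str, x: float) -> float:
    lo, hi = BOUNDS[name]
    if hi == math.inf:
        return math.log(x - lo)         # (lo, inf) -> R
    u = (x - lo) / (hi - lo)            # (lo, hi) -> (0, 1)
    return math.log(u / (1.0 - u))      # logit: (0, 1) -> R


def to_constrained(name: str, z: float) -> float:
    lo, hi = BOUNDS[name]
    if hi == math.inf:
        return lo + math.exp(z)
    return lo + (hi - lo) / (1.0 + math.exp(-z))  # inverse of the scaled logit


def constr_dic_to_unconstr_arr(params: dict[str, float]) -> list[float]:
    return [to_unconstrained(k, params[k]) for k in BOUNDS]


def unconstr_dic_to_params_dic(free: dict[str, float]) -> dict[str, float]:
    return {k: to_constrained(k, v) for k, v in free.items()}


arr = constr_dic_to_unconstr_arr({"norm": 2.0, "index": 1.0})
back = unconstr_dic_to_params_dic(dict(zip(BOUNDS, arr)))
print(back)  # approximately {'norm': 2.0, 'index': 1.0}
```

Working in unconstrained space lets optimizers and samplers move freely without ever proposing an out-of-bounds value.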

simulate: Callable[[int, dict[str, Array], int], dict[str, Array]]

Function to simulate data.
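The signature Callable[[int, dict[str, Array], int], dict[str, Array]] suggests a seed, a dict of model values, and a number of simulations; that reading is an assumption. Under it, a minimal NumPy sketch draws Poisson realizations of each dataset's predicted counts (Gaussian data would need a normal draw instead):

```python
import numpy as np


def simulate_sketch(seed: int, model_counts: dict[str, np.ndarray],
                    n: int) -> dict[str, np.ndarray]:
    """Draw ``n`` Poisson realizations per dataset (argument meaning is assumed)."""
    rng = np.random.default_rng(seed)
    # lam of shape (nchan,) broadcasts against size (n, nchan).
    return {name: rng.poisson(mu, size=(n, mu.size))
            for name, mu in model_counts.items()}


sims = simulate_sketch(0, {"d1": np.array([1.0, 5.0, 20.0])}, n=1000)
print(sims["d1"].shape)  # (1000, 3)
```

Passing the seed explicitly makes each batch of simulations reproducible.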

simulate_and_fit: Callable[[int, dict, dict, int, bool, int, bool, int, str], dict]

Function to simulate data and then fit the simulated data.
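Simulate-and-fit loops of this kind underlie parametric bootstrap and posterior predictive checks. The toy version below uses a one-parameter constant-rate Poisson model, whose MLE has the closed form mean(counts), so no optimizer is needed; the function and its arguments are hypothetical and unrelated to the real positional signature.

```python
import numpy as np


def simulate_and_fit_sketch(seed: int, rate: float, nchan: int,
                            nsim: int) -> dict[str, np.ndarray]:
    rng = np.random.default_rng(seed)
    sims = rng.poisson(rate, size=(nsim, nchan))  # one row per simulated dataset
    # For a constant-rate Poisson model the MLE is the mean count per channel.
    rate_hat = sims.mean(axis=1)
    return {"data": sims, "rate": rate_hat}


res = simulate_and_fit_sketch(seed=1, rate=4.0, nchan=50, nsim=2000)
print(res["rate"].mean(), res["rate"].std())  # near 4.0 and sqrt(4/50), about 0.28
```

The spread of the refitted rates approximates the sampling uncertainty of the estimator, which is what the bootstrap uses.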

batch_fit: Callable[[dict[str, Array], dict[str, Array], bool, int, bool, int, str], dict]

Function to fit simulated data.