Source code for elisa.infer.fit

"""Model fit in a maximum likelihood or Bayesian way."""

from __future__ import annotations

import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from importlib import metadata
from typing import TYPE_CHECKING

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx
import xarray as xr
from iminuit import Minuit
from numpyro.infer import init_to_value
from numpyro.infer.barker import BarkerMH, BarkerMHState
from numpyro.infer.ensemble import (
    AIES,
    AIESState,
    EnsembleSamplerState,
    ESSState,
)
from numpyro.infer.hmc import HMC, NUTS, HMCState
from numpyro.infer.mcmc import MCMC, MCMCKernel
from numpyro.infer.sa import SA, SAState

from elisa import __version__ as elisa_version
from elisa.data.base import FixedData, ObservationData
from elisa.infer.helper import Helper, get_helper
from elisa.infer.likelihood import _STATISTIC_OPTIONS
from elisa.infer.results import MLEResult, PosteriorResult
from elisa.infer.samplers.blackjax.nuts import BlackJAXNUTS, BlackJAXNUTSState
from elisa.infer.samplers.ensemble.emcee import EmceeSampler
from elisa.infer.samplers.ensemble.numpyro import (
    NumPyroAIES,
    NumpyroEnsembleSampler,
    NumPyroESS,
)
from elisa.infer.samplers.ensemble.zeus import ZeusSampler
from elisa.infer.samplers.ns.jaxns import JAXNSSampler
from elisa.models.model import Model, get_model_info
from elisa.util.config import get_parallel_number
from elisa.util.misc import (
    add_suffix,
    build_namespace,
    make_pretty_table,
)

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any, Literal

    from jaxlib.xla_client import Device
    from prettytable import PrettyTable

    from elisa.infer.likelihood import Statistic
    from elisa.models.model import ModelInfo
    from elisa.util.typing import Array, ArrayLike, JAXArray, JAXFloat


[docs] class Fit(ABC): """Abstract base class for model fitting. Parameters ---------- data : Data or sequence of Data The observation data. model : Model or sequence of Model The model used to fit the data. stat : {'chi2', 'cstat', 'pstat', 'pgstat', 'wstat'} or sequence, optional The likelihood option for the data and model. Available likelihood options are: * ``'chi2'``: Gaussian data * ``'cstat'``: Poisson data * ``'pstat'``: Poisson data with known background * ``'pgstat'``: Poisson data with Gaussian background * ``'wstat'``: Poisson data with Poisson background The default is None, which means automatically choosing the suitable likelihood options for the datasets and models. seed : int, optional Seed of random number generator used for fit. The default is 42. """ # TODO: # - fit multiple sources to one dataset (with multiple responses) # - fit data background given response and model _lm: Callable[[JAXArray], JAXArray] | None = None _ns: JAXNSSampler | None = None def __init__( self, data: ObservationData | Sequence[ObservationData], model: Model | Sequence[Model], stat: Statistic | Sequence[Statistic] | None = None, seed: int = 42, ): inputs = self._parse_input(data, model, stat) data: list[FixedData] = inputs[0] models: list[Model] = inputs[1] stats: list[Statistic] = inputs[2] # if a component is not fit with all datasets, # add names of data sets to be fit with it as its name/latex suffix data_names = [d.name for d in data] data_to_cid = { n: m._comps_id for n, m in zip(data_names, models, strict=True) } cid_to_comp = {c._id: c for m in models for c in m._comps} cid = list(cid_to_comp.keys()) comps = list(cid_to_comp.values()) cid_to_data_suffix = { i: ( '+'.join(i for i in data_names if i not in names) # keep order if ( names := [n for n in data_names if i not in data_to_cid[n]] ) else '' ) for i in cid } data_suffix = list(cid_to_data_suffix.values()) cname = [comp.name for comp in comps] name_with_data_suffix = list( map(''.join, zip(cname, data_suffix, strict=True)) ) num_suffix = build_namespace(name_with_data_suffix)['suffix_num'] cname = add_suffix(cname, num_suffix, True) cname = add_suffix(cname, data_suffix, False) cid_to_name = dict(zip(cid, cname, strict=True)) latex = [comp.latex for comp in comps] latex = add_suffix(latex, num_suffix, True, latex=True) latex = add_suffix(latex, data_suffix, False, latex=True, mathrm=True) cid_to_latex = dict(zip(cid, latex, strict=True)) # get model info self._model_info: ModelInfo = get_model_info( comps, cid_to_name, cid_to_latex ) # first filter out duplicated models then compile the remaining models, # this is intended to avoid re-compilation of the same model models_id = [id(m) for m in models] mid_to_model = dict(zip(models_id, models, strict=True)) compiled_model = { mid: m.compile(model_info=self._model_info) for mid, m in mid_to_model.items() } data_to_mid = dict(zip(data_names, models_id, strict=True)) self._model = { name: compiled_model[mid] for name, mid in data_to_mid.items() } # store data, stat, seed self._data: dict[str, FixedData] = dict( zip(data_names, data, strict=True) ) self._stat: dict[str, Statistic] = dict( zip(data_names, stats, strict=True) ) self._seed: int = int(seed) # make model information table self._make_info_table() self.__helper: Helper | None = None def _optimize_lm( self, unconstr_init: JAXArray, max_steps: int = 131072, throw: bool = True, verbose: bool = False, ) -> tuple[JAXArray, JAXFloat]: """Search MLE by Levenberg-Marquardt algorithm of :mod:`optimistix`.""" if verbose: verbose = frozenset({'step', 'loss'}) else: verbose = frozenset() if getattr(self, '_lm_verbose', verbose) != verbose: self._lm = None if self._lm is None: lm_solver = optx.LevenbergMarquardt( rtol=0.0, atol=1e-6, verbose=verbose ) residual = jax.jit(lambda x, aux: self._helper.residual(x)) def lm(init): res = optx.least_squares( fn=residual, solver=lm_solver, y0=init, max_steps=max_steps, throw=throw, ) grad_norm = jnp.linalg.norm(res.state.f_info.compute_grad()) return res.value, grad_norm self._lm = jax.jit(lm) return self._lm(jnp.asarray(unconstr_init, float)) def _optimize_ns(self, max_steps=131072, verbose=False) -> JAXArray: """Search MLE using nested sampling of :mod:`jaxns`.""" if self._ns is None: self._ns = JAXNSSampler( self._helper.numpyro_model, constructor_kwargs={ 'max_samples': max_steps, 'parameter_estimation': True, 'verbose': verbose, }, ) t0 = time.time() print('Start searching MLE...') self._ns.run(rng_key=jax.random.PRNGKey(self._helper.seed['mcmc'])) print(f'Search completed in {time.time() - t0:.2f} s') ns = self._ns samples = ns._results.samples loglike = [ samples[f'{i}_loglike'].sum(axis=-1) for i in self._data.keys() ] mle_idx = np.sum(loglike, axis=0).argmax() mle = jax.tree.map(lambda s: s[mle_idx], samples) mle = {i: mle[i] for i in self._helper.params_names['free']} return self._helper.constr_dic_to_unconstr_arr(mle) @property def _helper(self) -> Helper: if self.__helper is None: self.__helper = get_helper(self) return self.__helper
[docs] def summary(self, file=None) -> None: """Print the summary of fitting setup. Parameters ---------- file : file-like An object with a ``write(string)`` method. This is passed to :py:func:`print`. """ print(repr(self), file=file)
@property @abstractmethod def _tab_config(self) -> tuple[str, frozenset[str]]: """Model information table's title and excluded table fields.""" pass def __repr__(self) -> str: return ( f'\n{self._tab_config[0]}\n\n' f'{self._tab_likelihood.get_string()}\n\n' f'{self._tab_params.get_string()}\n' ) def _repr_html_(self) -> str: """The repr in Jupyter notebook environment.""" return ( f'<details open><summary><b>{self._tab_config[0]}</b></summary>' f'<br/>{self._tab_likelihood.get_html_string(format=True)}' f'<br/>{self._tab_params.get_html_string(format=True)}' '</details>' ) def _make_info_table(self): fields = ('Data', 'Model', 'Statistic') rows = tuple( zip( self._data, (m.name for m in self._model.values()), self._stat.values(), strict=True, ) ) self._tab_likelihood: PrettyTable = make_pretty_table(fields, rows) fields = ('No.', 'Component', 'Parameter', 'Value', 'Bound', 'Prior') mask = np.isin(fields, tuple(self._tab_config[1])) fields = np.array(fields)[~mask].tolist() rows = np.array(self._model_info.info)[:, ~mask].tolist() self._tab_params: PrettyTable = make_pretty_table(fields, rows) @staticmethod def _parse_input( data: ObservationData | Sequence[ObservationData], model: Model | Sequence[Model], stat: Statistic | Sequence[Statistic] | None, ) -> tuple[list[FixedData], list[Model], list[Statistic]]: """Check if data, model, and stat are correct and return lists.""" # ====================== some helper functions ======================== def get_list( inputs: Any, name: str, expect_type, type_name: str ) -> list: """Check the model/data/stat, and return a list.""" if isinstance(inputs, expect_type): input_list = [inputs] elif isinstance(inputs, Sequence): if not inputs: raise ValueError(f'{name} list is empty') if not all(isinstance(i, expect_type) for i in inputs): raise ValueError(f'all {name} must be a valid {type_name}') input_list = list(inputs) else: raise ValueError(f'got wrong type {type(inputs)} for {name}') return input_list def get_stat(d: FixedData) -> Statistic: """Get the default stat for the data.""" # 'pstat' is used only when specified explicitly by user if d.spec_poisson: if d.has_back: if d.back_poisson: return 'wstat' return 'pgstat' else: return 'cstat' else: return 'chi2' def check_stat(d: FixedData, s: Statistic): """Check if data type and likelihood are matched.""" name = d.name if not d.spec_poisson and s != 'chi2': raise ValueError( f'{name} data has Gaussian uncertainties, ' 'use Gaussian statistic (chi2) instead' ) if s == 'chi2': if np.any(d.net_errors == 0.0): raise ValueError( f'{name} data has zero uncertainties, ' 'and Gaussian statistic (chi2) will be invalid; ' 'grouping the data may fix this error' ) elif s == 'cstat': if d.has_back: back = 'Poisson' if d.back_poisson else 'Gaussian' stat1 = 'W' if d.back_poisson else 'PG' stat2 = 'w' if d.back_poisson else 'pg' raise ValueError( f'{name} data has {back} background, ' 'and using C-statistic (cstat) is invalid; ' f'use {stat1}-statistic ({stat2}stat) instead' ) elif s == 'pstat': if not d.has_back: raise ValueError( f'{name} data has no background, ' 'and using P-statistic (pstat) is invalid; ' 'use C-statistic (cstat) instead' ) elif s == 'pgstat': if not d.has_back: raise ValueError( f'{name} data has no background, ' 'and using PG-statistic (pgstat) is invalid; ' 'use C-statistic (cstat) instead' ) if np.any(d.back_errors == 0.0): raise ValueError( f'{name} data has zero background uncertainties, ' 'and PG-statistic (pgstat) will be invalid; ' 'grouping the data may fix this error' ) elif s == 'wstat' and not (d.has_back and d.back_poisson): if not d.has_back: raise ValueError( f'{name} data has no background, ' 'and using W-statistic (wstat) is invalid; ' 'use C-statistic (cstat) instead' ) if not d.back_poisson: raise ValueError( f'{name} data has Gaussian background, ' 'and using W-statistic (wstat) is invalid; ' 'use PG-statistic (pgstat) instead' ) # ====================== some helper functions ======================== # get data data_list: list[FixedData] = [ d.get_fixed_data() for d in get_list(data, 'data', ObservationData, 'Data') ] # check if data are used multiple times if len(list(map(id, data_list))) != len(data_list): count = {d: data_list.count(d) for d in set(data_list)} raise ValueError( 'data cannot be used multiple times: ' + ', '.join( f'{k.name} ({v})' for k, v in count.items() if v > 1 ) ) # check if data name is unique name_list = [d.name for d in data_list] if len(set(name_list)) != len(data_list): raise ValueError( f'data names are not unique: {", ".join(set(name_list))}, ' "please give a unique name in Data(..., name='NAME'), " "or set data.name='NAME'" ) # get model model_list: list[Model] = get_list(model, 'model', Model, 'Model') # check if the model type is additive flag = [i.type == 'add' for i in model_list] if not all(flag): err = (j for i, j in enumerate(model_list) if not flag[i]) err = ', '.join(f"'{i}'" for i in err) msg = f'got models which are not additive type: {err}' raise TypeError(msg) # get stat stat_list: list[Statistic] if stat is None: stat_list: list[Statistic] = [get_stat(d) for d in data_list] else: stat_list: list[Statistic] = get_list(stat, 'stat', str, 'str') # check the stat option flag = [i in _STATISTIC_OPTIONS for i in stat_list] if not all(flag): err = ', '.join( f"'{j}'" for i, j in enumerate(stat_list) if not flag[i] ) supported = ', '.join(f"'{i}'" for i in _STATISTIC_OPTIONS) msg = f'unexpected stat: {err}; supported are {supported}' raise ValueError(msg) nd = len(data_list) nm = len(model_list) ns = len(stat_list) # check model number if nm == 1: model_list *= nd elif nm != nd: msg = f'number of model ({nm}) and data ({nd}) are not matched' raise ValueError(msg) # check stat number if ns == 1: stat_list *= nd elif ns != nd: msg = f'number of data ({nd}) and stat ({ns}) are not matched' raise ValueError(msg) # check if correctly using stat for d, s in zip(data_list, stat_list, strict=True): check_stat(d, s) return data_list, model_list, stat_list
[docs] class MaxLikeFit(Fit): _tab_config = ('Maximum Likelihood Fit', frozenset({'Prior'})) def _optimize_minuit( self, unconstr_init: JAXArray, ncall: int | None = None, throw: bool = True, verbose: int | bool = False, ) -> Minuit: """Search MLE using Minuit algorithm of :mod:`iminuit`.""" deviance = jax.jit(self._helper.deviance_total) deviance.ndata = self._helper.ndata['total'] minuit = Minuit( deviance, np.array(unconstr_init), grad=jax.jit(jax.grad(deviance)), name=self._helper.params_names['free'], ) if throw: minuit.throw_nan = True minuit.print_level = int(verbose) # TODO: test if simplex can be used to "polish" the initial guess minuit.strategy = 2 minuit.migrad(ncall=ncall, iterate=10) return minuit
[docs] def mle( self, init: ArrayLike | dict | None = None, method: Literal['minuit', 'lm', 'ns'] = 'minuit', max_steps: int = None, throw: bool = True, verbose: int | bool = False, ) -> MLEResult: """Search Maximum Likelihood Estimation (MLE) for the model. Parameters ---------- init : dict, optional Initial guess for the maximum likelihood estimation. method : {'minuit', 'lm', 'ns'}, optional Optimization algorithm used to find the MLE. Available options are: * ``'minuit'``: Migrad algorithm of :mod:`iminuit`. * ``'lm'``: Levenberg-Marquardt algorithm of :mod:`optimistix`. * ``'ns'``: Nested sampling of :mod:`jaxns`. This option first search MLE globally, then polish it with local minimization. The default is 'minuit'. Other Parameters ---------------- max_steps : int, optional The maximum number of steps the solver can take. The default is 131072. throw : bool, optional Whether to report any failures of the solver. Defaults to True. verbose : int or bool, optional Whether to print fit progress information. The default is False. Returns ------- MLEResult The MLE result. """ if init is None: init = self._helper.free_default['constr_dic'] else: init = self._helper.free_default['constr_dic'] | dict(init) init_unconstr = self._helper.constr_dic_to_unconstr_arr(init) max_steps = 131072 if max_steps is None else int(max_steps) if method == 'lm': # use Levenberg-Marquardt algorithm to find MLE init_unconstr, _ = self._optimize_lm( init_unconstr, max_steps, throw, bool(verbose) ) elif method == 'ns': # use nested sampling to find MLE init_unconstr = self._optimize_ns(max_steps, verbose) else: if method != 'minuit': raise ValueError(f'unsupported optimization method {method}') minuit = self._optimize_minuit( init_unconstr, max_steps, throw, verbose ) return MLEResult(minuit, self._helper)
[docs] class BayesFit(Fit): _tab_config = ('Bayesian Fit', frozenset({'Bound'})) def _generate_results( self, samples: dict[str, Array], ess: dict[str, int], reff: float, lnZ: tuple[float | None, float | None] = (None, None), sample_stats: dict[str, Any] | None = None, sampler_state: Any | None = None, attrs: dict[str, Any] | None = None, inference_library: str | None = None, ) -> PosteriorResult: helper = self._helper samples = jax.device_get(samples) params = helper.get_params(samples) models = helper.get_models(samples) posterior = params | models posterior_predictive = helper.simulate(helper.seed['pred'], models, 1) loglike = helper.get_loglike(samples) group = {f'{k}_total': v for k, v in loglike['group'].items()} loglike = ( loglike['data'] | loglike['point'] | group | {'channels': loglike['channels']} | {'total': loglike['total']} ) # get observation counts data obs_data = helper.obs_data # coords and dims of arviz.InferenceData coords = dict(helper.channels) dims = {'channels': ['channel']} for i in helper.data_names: dim = [f'{i}_channel'] dims[i] = dims[f'{i}_Non'] = dims[f'{i}_Non_model'] = dim if f'{i}_Noff' in obs_data: dims[f'{i}_Noff'] = dims[f'{i}_Noff_model'] = dim # additional attrs for each group of arviz.InferenceData if attrs is None: attrs = {} else: attrs = dict(attrs) attrs['elisa_version'] = elisa_version attrs |= { 'inference_library': inference_library, 'inference_library_version': metadata.version(inference_library), } # create InferenceData idata = az.from_dict( posterior=posterior, posterior_predictive=posterior_predictive, sample_stats=sample_stats, log_likelihood=loglike, observed_data=obs_data, coords=coords, dims=dims, posterior_attrs=attrs, posterior_predictive_attrs=attrs, sample_stats_attrs=attrs, log_likelihood_attrs=attrs, observed_data_attrs=attrs, ) # add extra statistics to idata ess = ess | {'reff': reff} evidence = { 'lnZ': lnZ[0] if lnZ[0] is not None else np.nan, 'lnZ_error': lnZ[1] if lnZ[1] is not None else np.nan, } ess = xr.Dataset(ess, attrs=attrs) evidence = xr.Dataset(evidence, attrs=attrs) idata.add_groups( group_dict={'ess': ess, 'evidence': evidence}, warn_on_custom_groups=False, ) return PosteriorResult( helper=self._helper, idata=idata, ml_optimize=self._optimize_lm, sampler_state=sampler_state, ) def _check_init(self, init: dict[str, float] | None) -> dict[str, float]: if init is None: init = self._helper.free_default['constr_dic'] else: init = self._helper.free_default['constr_dic'] | dict(init) return init @staticmethod def _set_numpyro_mcmc_post_warmup_state(mcmc: MCMC, state: Any) -> None: if state is None: return assert isinstance(mcmc, MCMC) kernel: MCMCKernel = mcmc.sampler kernel_state_types = { BarkerMH: BarkerMHState, BlackJAXNUTS: BlackJAXNUTSState, NumpyroEnsembleSampler: EnsembleSamplerState, HMC: HMCState, SA: SAState, NUTS: HMCState, } ensemble_state_types = { NumPyroAIES: AIESState, NumPyroESS: ESSState, } for kt, st in kernel_state_types.items(): if isinstance(kernel, kt): if not isinstance(state, st): raise ValueError( f'post_warmup_state must be {st.__name__}' ) break if isinstance(kernel, NumpyroEnsembleSampler): kernel_type = kernel.__class__ is_type = ensemble_state_types.get(kernel_type, object) if not isinstance(state.inner_state, is_type): raise ValueError( f'post_warmup_state must be state for {kernel_type}' ) mcmc.post_warmup_state = state def _get_ess( self, samples: dict[str, Array], chains: int, ) -> tuple[dict[str, int], float]: helper = self._helper params_names = helper.params_names # effective sample size params = helper.get_params(samples) ess = az.ess(params) ess = {k: int(ess[k].values) for k in params.keys()} # relative mcmc efficiency # the calculation of reff is according to arviz loo: # https://github.com/arviz-devs/arviz/blob/1b0b9cb050e3b757e1551d3a1f7a8f8e2773bc36/arviz/stats/stats.py#L776 if chains == 1: reff = 1.0 else: # use only free parameters to calculate reff free = {k: params[k] for k in params_names['free']} reff_p = az.ess(free, method='mean', relative=True) reff = np.hstack(list(reff_p.data_vars.values())).mean() return ess, reff def _generate_result_from_numpyro( self, sampler: MCMC, kernel_library: str = 'numpyro', ) -> PosteriorResult: samples = sampler.get_samples(group_by_chain=True) if isinstance(sampler.sampler, NumpyroEnsembleSampler): samples = jax.tree.map(lambda x: jnp.swapaxes(x, 1, 2), samples) samples = jax.tree.map( lambda x: jnp.reshape( x, (x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:]), ), samples, ) # sample stats rename = {'num_steps': 'n_steps'} sample_stats = {} for k, v in sampler.get_extra_fields(group_by_chain=True).items(): name = rename.get(k, k) value = jax.device_get(v).copy() sample_stats[name] = value if 'tree_depth' not in sample_stats and 'num_steps' in sample_stats: num_steps = sample_stats['num_steps'] sample_stats['tree_depth'] = np.log2(num_steps).astype(int) + 1 ess, reff = self._get_ess(samples, sampler.num_chains) return self._generate_results( samples=samples, ess=ess, reff=reff, sample_stats=sample_stats, sampler_state=sampler.last_state, inference_library=kernel_library, ) def _run_numpyro_mcmc( self, kernel: type[MCMCKernel], warmup: int, steps: int, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, post_warmup_state: Any = None, extra_fields: tuple[str, ...] = (), kernel_library: str = 'numpyro', **kernel_kwargs: dict, ): """Run the regular sampler of numpyro.""" if not issubclass(kernel, MCMCKernel): raise ValueError('kernel must be a subclass of numpyro MCMCKernel') warmup = int(warmup) steps = int(steps) thinning = int(thinning) device_count = jax.local_device_count() chains = int(chains) if chains is not None else device_count kernel_kwargs['model'] = self._helper.numpyro_model rng_key = jax.random.PRNGKey(self._helper.seed['mcmc']) # TODO: option to let sampler starting from MLE if issubclass(kernel, NumpyroEnsembleSampler): # set randomize_split to True to improve mixing kernel_kwargs.setdefault('randomize_split', True) init = self._check_init(init) init = self._helper.constr_dic_to_unconstr_arr(init) n_params = len(init) default_walkers = 4 * n_params walkers = kernel_kwargs.pop('walkers', default_walkers) if walkers is None: walkers = default_walkers else: walkers = int(walkers) kernel_kwargs['walkers'] = walkers jitter = 0.1 * jnp.abs(init) low = init - jitter high = init + jitter rng_key, init_key = jax.random.split(rng_key, 2) init = jax.random.uniform( init_key, shape=(n_params, chains, walkers), minval=low[:, None, None], maxval=high[:, None, None], ) if chains == 1: # remove the chains dim if run a single sampler init = jnp.squeeze(init, axis=1) init = dict( zip(self._helper.params_names['free'], init, strict=True) ) else: init_strategy = init_to_value(values=self._check_init(init)) if init is not None: kernel_kwargs['init_strategy'] = init_strategy init = None else: kernel_kwargs.setdefault('init_strategy', init_strategy) sampler = MCMC( kernel(**kernel_kwargs), num_warmup=warmup, num_samples=steps * thinning, num_chains=chains, thinning=thinning, chain_method=chain_method, progress_bar=progress, ) self._set_numpyro_mcmc_post_warmup_state(sampler, post_warmup_state) sampler.run(rng_key, extra_fields=extra_fields, init_params=init) return self._generate_result_from_numpyro( sampler=sampler, kernel_library=kernel_library, )
[docs] def nuts( self, warmup: int = 2000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, post_warmup_state: HMCState | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`numpyro`'s implementation of No-U-Turn Sampler (NUTS). .. note:: If the chains are not converged well, see ref [2]_ for more information on how to fine-tune NUTS. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 2000. steps : int, optional Number of steps to run for each chain. The default is 5000. chains : int, optional Number of MCMC chains to run. If there are not enough devices available, chains will run in sequence. Defaults to the number of ``jax.local_device_count()``. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : HMCState, optional The state before the sampling phase. The sampling will start from the given state if provided. **kwargs : dict Extra parameters passed to :class:`numpyro.infer.NUTS`. The default for `dense_mass` is ``True``. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo (https://www.jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf) .. [2] NumPyro tutorial: `Bad posterior geometry and how to deal with it <https://num.pyro.ai/en/stable/tutorials/bad_posterior_geometry.html>`__ """ kwargs.setdefault('dense_mass', True) return self._run_numpyro_mcmc( kernel=NUTS, warmup=warmup, steps=steps, chains=chains, thinning=thinning, init=init, chain_method=chain_method, progress=progress, post_warmup_state=post_warmup_state, extra_fields=('energy', 'num_steps'), **kwargs, )
[docs] def barkermh( self, warmup: int = 5000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, post_warmup_state: BarkerMHState | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`numpyro`'s implementation of ``BarkerMH`` sampler. .. note:: This is a gradient-based MCMC algorithm of Metropolis-Hastings type that uses a skew-symmetric proposal distribution that depends on the gradient of the potential (the Barker proposal [1]_). In particular the proposal distribution is skewed in the direction of the gradient at the current sample. This algorithm is expected to be particularly effective for low to moderate dimensional models, where it may be competitive with HMC and NUTS. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 5000. steps : int, optional Number of steps to run for each chain. The default is 10000. chains : int, optional Number of MCMC chains to run. If there are not enough devices available, chains will run in sequence. Defaults to the number of ``jax.local_device_count()``. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : BarkerMHState, optional The state before the sampling phase. The sampling will start from the given state if provided. **kwargs : dict Extra parameters passed to :class:`numpyro.infer.BarkerMH`. The default for `dense_mass` is ``True``. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] The Barker proposal: combining robustness and efficiency in gradient-based MCMC (https://doi.org/10.1111/rssb.12482), Samuel Livingstone and Giacomo Zanella. """ kwargs.setdefault('dense_mass', True) return self._run_numpyro_mcmc( kernel=BarkerMH, warmup=warmup, steps=steps, chains=chains, thinning=thinning, init=init, chain_method=chain_method, progress=progress, post_warmup_state=post_warmup_state, **kwargs, )
[docs] def sa( self, warmup: int = 70000, steps: int = 5000, chains: int | None = None, thinning: int = 2, init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, post_warmup_state: SAState | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`numpyro`'s implementation of Sample Adaptive (SA) MCMC. .. note:: This is a gradient-free sampler. It is fast in terms of n_eff / s, but requires **many** warmup (burn-in) steps. If the result does not converge satisfactorily, consider increasing the values of `warmup` and/or `adapt_state_size`, or providing better initial parameter estimates via the `init` argument. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 70000. steps : int, optional Number of steps to run. The default is 5000. chains : int, optional Number of MCMC chains to run. If there are not enough devices available, chains will run in sequence. Defaults to the number of ``jax.local_device_count()``. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 2. init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : SAState, optional The state before the sampling phase. The sampling will start from the given state if provided. **kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. The default for `adapt_state_size` is ``5 * D``, where `D` is the dimension of model parameters. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc), Michael Zhu """ nparams = len(self._helper.params_names['free']) kwargs.setdefault('adapt_state_size', 5 * nparams) return self._run_numpyro_mcmc( kernel=SA, warmup=warmup, steps=steps, chains=chains, thinning=thinning, init=init, chain_method=chain_method, progress=progress, post_warmup_state=post_warmup_state, **kwargs, )
[docs] def blackjax_nuts( self, warmup: int = 2000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, post_warmup_state: HMCState | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`blackjax`'s implementation of No-U-Turn Sampler (NUTS). .. note:: If the chains are not converged well, see ref [2]_ for more information on how to fine-tune NUTS. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 2000. steps : int, optional Number of steps to run for each chain. The default is 5000. chains : int, optional Number of MCMC chains to run. If there are not enough devices available, chains will run in sequence. Defaults to the number of ``jax.local_device_count()``. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : HMCState, optional The state before the sampling phase. The sampling will start from the given state if provided. **kwargs : dict Extra parameters passed to :class:`BlackJAXNUTS`. The default for `dense_mass` is ``True``. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo (https://www.jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf) .. [2] NumPyro tutorial: `Bad posterior geometry and how to deal with it <https://num.pyro.ai/en/stable/tutorials/bad_posterior_geometry.html>`__ """ return self._run_numpyro_mcmc( kernel=BlackJAXNUTS, warmup=warmup, steps=steps, chains=chains, thinning=thinning, init=init, chain_method=chain_method, progress=progress, post_warmup_state=post_warmup_state, extra_fields=('energy', 'num_steps'), kernel_library='blackjax', **kwargs, )
[docs] def aies( self, warmup: int = 5000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, chain_method: str = 'parallel', n_parallel: int | None = None, progress: bool = True, post_warmup_state: EnsembleSamplerState | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`numpyro`'s Affine-Invariant Ensemble Sampling (AIES). Affine-invariant ensemble sampling [1]_ is a gradient-free method that informs Metropolis-Hastings proposals by sharing information between chains. Suitable for low to moderate dimensional models. Generally, `chains` should be at least twice the dimensionality of the model. .. note:: This sampler must be used with even number `chains` > 1. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 5000. steps : int, optional Number of steps to run for each chain. The default is 5000. chains : int, optional Number of MCMC chains to run. Defaults to 4 * `D`, where `D` is the dimension of model parameters. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.infer.MCMC`. n_parallel : int, optional Number of parallel samplers to run. The default is ``jax.local_device_count()``. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : EnsembleSamplerState, optional The state before the sampling phase. The sampling will start from the given state if provided. This does not take effect when `n_parallel`>=2. **kwargs : dict Extra parameters passed to :class:`numpyro.infer.AIES`. The default for `moves` is ``{AIES.StretchMove(): 1.0}``. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] *emcee: The MCMC Hammer* (https://iopscience.iop.org/article/10.1086/670067), Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman. """ kwargs['walkers'] = chains # use the same default moves as in emcee kwargs.setdefault('moves', {AIES.StretchMove(): 1.0}) return self._run_numpyro_mcmc( kernel=NumPyroAIES, warmup=warmup, steps=steps, chains=n_parallel, thinning=thinning, init=init, chain_method=chain_method, progress=progress, post_warmup_state=post_warmup_state, **kwargs, )
[docs] def ess( self, warmup: int = 5000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, chain_method: str = 'parallel', n_parallel: int | None = None, progress: bool = True, post_warmup_state: EnsembleSamplerState | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`numpyro`'s Ensemble Slice Sampling (ESS). Ensemble slice sampling [1]_ is a gradient free method that finds better slice sampling directions by sharing information between chains. Suitable for low to moderate dimensional models. Generally, `chains` should be at least twice the dimensionality of the model. .. note:: This sampler must be used with even number `chains` > 1. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 5000. steps : int, optional Number of steps to run for each chain. The default is 5000. chains : int, optional Number of MCMC chains to run. Defaults to 4 * `D`, where `D` is the dimension of model parameters. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.infer.MCMC`. n_parallel : int, optional Number of parallel samplers to run. The default is ``jax.local_device_count()``. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : EnsembleSamplerState, optional The state before the sampling phase. The sampling will start from the given state if provided. This does not take effect when `n_parallel`>=2. **kwargs : dict Extra parameters passed to :class:`numpyro.infer.ESS`. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference (https://academic.oup.com/mnras/article/508/3/3589/6381726), Minas Karamanis, Florian Beutler, and John A. Peacock. .. [2] Ensemble slice sampling (https://link.springer.com/article/10.1007/s11222-021-10038-2), Minas Karamanis, Florian Beutler. """ kwargs['walkers'] = chains return self._run_numpyro_mcmc( kernel=NumPyroESS, warmup=warmup, steps=steps, chains=n_parallel, thinning=thinning, init=init, chain_method=chain_method, progress=progress, post_warmup_state=post_warmup_state, **kwargs, )
[docs] def emcee( self, warmup: int = 5000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, n_parallel: int | None = None, progress: bool = True, post_warmup_state: Sequence | None = None, tune: bool = False, ignore_nan: bool = False, warmup_kwargs: dict | None = None, sampling_kwargs: dict | None = None, ) -> PosteriorResult: """Run :mod:`emcee`'s affine-invariant ensemble sampling. Affine-invariant ensemble sampling [1]_ is a gradient-free method that informs Metropolis-Hastings proposals by sharing information between chains. Suitable for low to moderate dimensional models. Generally, `chains` should be at least twice the dimensionality of the model. .. note:: This sampler must be used with even `chains` > 1. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 5000. steps : int, optional Number of steps to run for each chain. The default is 5000. chains : int, optional Number of MCMC chains to run. Defaults to 4 * `D`, where `D` is the dimension of model parameters. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. n_parallel : int, optional Number of parallel samplers to run. The default is ``jax.local_device_count()``. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : sequence, optional The state before the sampling phase. The sampling will start from the given state if provided. tune : bool, optional If True, the parameters of some moves will be automatically tuned. ignore_nan : bool, optional Whether to transform a NaN log probability to a large negative number (-1e300). The default is False. .. warning:: Setting ``ignore_nan=True`` may fail to spot potential issues with model computation. warmup_kwargs: dict, optional Extra parameters passed to :class:`emcee.EnsembleSampler` for warm-up phase. sampling_kwargs: dict | None = None, Extra parameters passed to :class:`emcee.EnsembleSampler` for sampling phase. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] *emcee: The MCMC Hammer* (https://iopscience.iop.org/article/10.1086/670067), Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman. """ init = self._check_init(init) n_parallel = get_parallel_number(n_parallel) sampler = EmceeSampler( numpyro_model=self._helper.numpyro_model, init_params=init, ignore_nan=ignore_nan, seed=self._helper.seed['mcmc'], ) samples, states = sampler.run( warmup=warmup, steps=steps, chains=chains, thinning=thinning, n_parallel=n_parallel, tune=tune, progress=progress, states=post_warmup_state, warmup_kwargs=warmup_kwargs, sampling_kwargs=sampling_kwargs, ) ess, reff = self._get_ess(samples, n_parallel) return self._generate_results( samples=samples, ess=ess, reff=reff, sampler_state=states, inference_library='emcee', )
[docs] def zeus( self, warmup: int = 3000, steps: int = 5000, chains: int | None = None, thinning: int = 1, init: dict[str, float] | None = None, n_parallel: int | None = None, progress: bool = True, post_warmup_state: Sequence | None = None, tune: bool = True, ignore_nan: bool = False, warmup_kwargs: dict | None = None, sampling_kwargs: dict | None = None, ) -> PosteriorResult: """Run :mod:`zeus`' ensemble slice sampling. Ensemble slice sampling [1]_ is a gradient free method that finds better slice sampling directions by sharing information between chains. Suitable for low to moderate dimensional models. Generally, `chains` should be at least twice the dimensionality of the model. .. note:: This sampler must be used with even number `chains` > 1. Parameters ---------- warmup : int, optional Number of warmup steps. The default is 5000. steps : int, optional Number of steps to run for each chain. The default is 5000. chains : int, optional Number of MCMC chains to run. Defaults to 4 * `D`, where `D` is the dimension of model parameters. thinning: int, optional For each chain, every `thinning` step is retained, and the other steps are discarded. The total steps for each chain are `steps` * `thinning`. The default is 1. init : dict, optional Initial parameter for sampler to start from. n_parallel : int, optional Number of parallel samplers to run. The default is ``jax.local_device_count()``. progress : bool, optional Whether to show progress bars during sampling. The default is True. post_warmup_state : sequence, optional The state before the sampling phase. The sampling will start from the given state if provided. tune : bool, optional If True, the parameters of some moves will be automatically tuned. ignore_nan : bool, optional Whether to transform a NaN log probability to a large negative number (-1e300). The default is False. .. warning:: Setting ``ignore_nan=True`` may fail to spot potential issues with model computation. warmup_kwargs: dict, optional Extra parameters passed to :class:`zeus.EnsembleSampler` for warm-up phase. sampling_kwargs: dict | None = None, Extra parameters passed to :class:`zeus.EnsembleSampler` for sampling phase. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference (https://academic.oup.com/mnras/article/508/3/3589/6381726), Minas Karamanis, Florian Beutler, and John A. Peacock. .. [2] Ensemble slice sampling (https://link.springer.com/article/10.1007/s11222-021-10038-2), Minas Karamanis, Florian Beutler. """ init = self._check_init(init) n_parallel = get_parallel_number(n_parallel) sampler = ZeusSampler( numpyro_model=self._helper.numpyro_model, init_params=init, ignore_nan=ignore_nan, seed=self._helper.seed['mcmc'], ) samples, states = sampler.run( warmup=warmup, steps=steps, chains=chains, thinning=thinning, n_parallel=n_parallel, tune=tune, progress=progress, states=post_warmup_state, warmup_kwargs=warmup_kwargs, sampling_kwargs=sampling_kwargs, ) ess, reff = self._get_ess(samples, n_parallel) return self._generate_results( samples=samples, ess=ess, reff=reff, sampler_state=states, inference_library='zeus-mcmc', )
[docs] def jaxns( self, max_samples: int = 2**17, num_live_points: int | None = None, s: int | None = None, k: int | None = None, c: int | None = None, devices: list[Device] | None = None, difficult_model: bool = False, parameter_estimation: bool = False, verbose: bool = False, term_cond: dict | None = None, **kwargs: dict, ) -> PosteriorResult: """Run :mod:`jaxns`'s implementation of nested sampling. .. note:: Parameters `s`, `k`, and `c` are defined in the paper [1]_. For more information of the sampler parameters, see ref [1]_ [2]_. Parameters ---------- max_samples : int, optional Maximum number of posterior samples. The default is 131072. num_live_points : int, optional Approximate number of live points. The default is `c` * (`k` + 1). s : int, optional Number of slices per dimension. The default is 5. k : int, optional Number of phantom samples. The default is 0. c : int, optional Number of parallel Markov chains. The default is 30 * `D`, where `D` is the dimension of model parameters. It takes effect only for num_live_points=None. devices : list, optional Devices to use. Defaults to all available devices. difficult_model : bool, optional If True, uses more robust default settings (`s` = 10 and `c` = 50 * `D`). It takes effect only for `num_live_points` = None, `s` = None or `c` = None. Defaults to False. parameter_estimation : bool, optional If True, uses more robust default settings for parameter estimation (`k` = `D`). It takes effect only for `k` = None. Defaults to False. verbose : bool, optional Print progress information. The default is False. term_cond : dict, optional Termination conditions for the sampling. The default is as in :class:`jaxns.TermCondition`. **kwargs : dict Extra parameters passes to :class:`jaxns.DefaultNestedSampler`. Returns ------- PosteriorResult The posterior sampling result. References ---------- .. [1] `Phantom-Powered Nested Sampling <https://arxiv.org/abs/2312.11330>`__ .. [2] `JAXNS API doc <https://jaxns.readthedocs.io/en/latest/api/jaxns/index.html#jaxns.DefaultNestedSampler>`__ """ constructor_kwargs = { 'max_samples': max_samples, 'num_live_points': num_live_points, 's': s, 'k': k, 'c': c, 'devices': devices, 'difficult_model': difficult_model, 'parameter_estimation': parameter_estimation, 'verbose': verbose, } constructor_kwargs.update(kwargs) termination_kwargs = {'dlogZ': 1e-4} if term_cond is not None: termination_kwargs.update(term_cond) sampler = JAXNSSampler( self._helper.numpyro_model, constructor_kwargs=constructor_kwargs, termination_kwargs=termination_kwargs, ) print('Running nested sampling of JAXNS...') t0 = time.time() sampler.run(rng_key=jax.random.PRNGKey(self._helper.seed['mcmc'])) print(f'Sampling completed in {time.time() - t0:.2f} s') helper = self._helper result = sampler._results # get posterior samples total = result.total_num_samples rng_key = jax.random.PRNGKey(helper.seed['mcmc']) samples = jax.tree.map( lambda x: x[None, ...], sampler.get_samples(rng_key, total), ) # effective sample size overall_ess = int(result.ESS) ess = dict.fromkeys(self._helper.params_names['all'], overall_ess) # relative mcmc efficiency reff = float(overall_ess / result.total_num_samples) # model evidence lnZ = (float(result.log_Z_mean), float(result.log_Z_uncert)) return self._generate_results( samples=samples, ess=ess, reff=reff, lnZ=lnZ, inference_library='jaxns', )
[docs] def nautilus( self, ess: int = 3000, ignore_nan: bool = False, *, constructor_kwargs: dict | None = None, termination_kwargs: dict | None = None, ) -> PosteriorResult: """Run :mod:`nautilus`'s implementation of nested sampling. Parameters ---------- ess : int, optional The desired effective sample size. ignore_nan : bool, optional Whether to transform a NaN log probability to a large negative number (-1e300). The default is False. .. warning:: Setting ``ignore_nan=True`` may fail to spot potential issues with model computation. constructor_kwargs : dict, optional Extra parameters passed to :class:`nautilus.Sampler`. termination_kwargs : dict, optional Extra parameters passed to :class:`nautilus.Sampler.run()`. """ from elisa.infer.samplers.ns.nautilus import NautilusSampler if constructor_kwargs is None: constructor_kwargs = {} else: constructor_kwargs = dict(constructor_kwargs) constructor_kwargs.setdefault('pool', get_parallel_number(None)) if termination_kwargs is None: termination_kwargs = {} else: termination_kwargs = dict(termination_kwargs) termination_kwargs['n_eff'] = int(ess) print('Running nested sampling of Nautilus...') sampler = NautilusSampler( numpyro_model=self._helper.numpyro_model, seed=self._helper.seed['mcmc'], ignore_nan=ignore_nan, **constructor_kwargs, ) t0 = time.time() samples = sampler.run(**termination_kwargs) print(f'Sampling completed in {time.time() - t0:.2f} s') # format posterior samples samples = jax.tree.map(lambda x: x[None], samples) # effective sample size ess_overall = sampler.ess ess = dict.fromkeys(self._helper.params_names['all'], ess_overall) # relative mcmc efficiency total_sample = len(samples[self._helper.params_names['all'][0]]) reff = float(ess_overall / total_sample) # model evidence lnZ = (sampler.lnZ, None) return self._generate_results( samples=samples, ess=ess, reff=reff, lnZ=lnZ, inference_library='nautilus-sampler', )
[docs] def ultranest( self, ess: int = 3000, ignore_nan: bool = False, viz_params: list[str] | None = None, print_result: bool = True, *, constructor_kwargs: dict | None = None, termination_kwargs: dict | None = None, read_file_config: dict | None = None, ) -> PosteriorResult: """Run :mod:`ultranest`'s implementation of nested sampling. Parameters ---------- ess : int, optional The desired effective sample size. ignore_nan : bool, optional Whether to transform a NaN log probability to a large negative number (-1e300). The default is False. .. warning:: Setting ``ignore_nan=True`` may fail to spot potential issues with model computation. viz_params : list, optional Parameters to visualize during sampling. The default is all. print_result : bool, optional Whether to print sampling result. The default is True. constructor_kwargs : dict, optional Extra parameters passed to :class:`ultranest.ReactiveNestedSampler`. termination_kwargs : dict, optional Extra parameters passed to :class:`ultranest.ReactiveNestedSampler.run()`. read_file_config : dict, optional Read the log file from a previous run. The dictionary should contain the log directory and other optional parameters. It should be noted that when providing this keyword argument, the sampler will not run, but read the log file instead and make sure the data and model settings are the same as the previous run. """ from elisa.infer.samplers.ns.ultranest import UltraNestSampler if constructor_kwargs is None: constructor_kwargs = {} else: constructor_kwargs = dict(constructor_kwargs) if termination_kwargs is None: termination_kwargs = {} else: termination_kwargs = dict(termination_kwargs) termination_kwargs['min_ess'] = ess if viz_params is None: viz_params = self._helper.params_names['all'] print('Running nested sampling of UltraNest...') t0 = time.time() sampler = UltraNestSampler( numpyro_model=self._helper.numpyro_model, seed=self._helper.seed['mcmc'], ignore_nan=ignore_nan, **constructor_kwargs, ) samples = sampler.run( viz_sample_names=viz_params, read_file_config=read_file_config, **termination_kwargs, ) if print_result: sampler.print_results() print(f'Sampling completed in {time.time() - t0:.2f} s') # format posterior samples samples = jax.tree.map(lambda x: x[None], samples) # effective sample size ess_overall = sampler.ess ess = dict.fromkeys(self._helper.params_names['all'], ess_overall) # relative mcmc efficiency total_sample = len(samples[self._helper.params_names['all'][0]]) reff = float(ess_overall / total_sample) # model evidence lnZ = sampler.lnZ return self._generate_results( samples=samples, ess=ess, reff=reff, lnZ=lnZ, inference_library='ultranest', )