elisa.infer.samplers.ensemble.numpyro
- class NumpyroEnsembleSampler(walkers: int, *args, **kwargs)
Bases: MCMCKernel
Wrapper kernel to run the ensemble sampler as a single MCMC chain.
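To collect the posterior correctly, retrieve the samples grouped by chain, so that each sample site has shape (n_parallel, n_steps, n_walkers, ...), where n_parallel is the number of chains, i.e. of ensemble samplers run in parallel:

    from numpyro.infer import MCMC
    from elisa.infer.samplers.ensemble.numpyro import NumpyroEnsembleSampler

    kernel = NumpyroEnsembleSampler(...)
    mcmc = MCMC(kernel, ...)
    mcmc.run(...)
    samples = mcmc.get_samples(group_by_chain=True)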
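Then merge the walkers of each sampler into a single sample axis, giving arrays of shape (n_parallel, n_walkers * n_steps, ...):

    import jax
    import jax.numpy as jnp

    # (n_parallel, n_steps, n_walkers, ...) -> (n_parallel, n_walkers, n_steps, ...)
    samples = jax.tree.map(lambda x: jnp.swapaxes(x, 1, 2), samples)
    # Flatten walkers and steps: -> (n_parallel, n_walkers * n_steps, ...)
    samples = jax.tree.map(
        lambda x: jnp.reshape(
            x,
            (x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:]),
        ),
        samples,
    )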
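The axis manipulation above can be checked on dummy data. The sketch below uses NumPy instead of JAX only so that it runs without a JAX installation (np.swapaxes and np.reshape follow the same semantics as their jax.numpy counterparts), and a plain dict comprehension in place of jax.tree.map, which is enough for a flat dict of sites. The helper merge_walkers and the site names "a" and "b" are hypothetical. The final check shows that, after merging, each walker's trajectory occupies a contiguous block of n_steps draws.

```python
import numpy as np


def merge_walkers(x):
    """Hypothetical helper mirroring the two jax.tree.map steps above.

    (n_parallel, n_steps, n_walkers, ...) -> (n_parallel, n_walkers * n_steps, ...)
    """
    x = np.swapaxes(x, 1, 2)  # -> (n_parallel, n_walkers, n_steps, ...)
    return np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:]))


n_parallel, n_steps, n_walkers = 2, 3, 4
rng = np.random.default_rng(0)

# Dummy grouped-by-chain samples: a scalar site "a" and a length-5 vector site "b"
samples = {
    "a": rng.normal(size=(n_parallel, n_steps, n_walkers)),
    "b": rng.normal(size=(n_parallel, n_steps, n_walkers, 5)),
}
merged = {name: merge_walkers(x) for name, x in samples.items()}

print(merged["a"].shape)  # (2, 12)
print(merged["b"].shape)  # (2, 12, 5)

# Walker w of chain c now spans draws w * n_steps .. (w + 1) * n_steps - 1
c, w = 1, 2
assert np.array_equal(merged["a"][c, w * n_steps:(w + 1) * n_steps], samples["a"][c, :, w])
```

Swapping the walker and step axes before reshaping is what keeps each walker's draws together; reshaping directly would interleave the walkers step by step.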
Attributes
default_fields – The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
is_ensemble_kernel – Denotes whether the kernel is an ensemble kernel.
sample_field – The attribute of the state object passed to sample() that denotes the MCMC sample.
Methods
get_diagnostics_str(state) – Given the current state, returns the diagnostics string to be added to the progress bar for diagnostic purposes.
init(rng_key, num_warmup, init_params, ...) – Initialize the MCMCKernel and return an initial state to begin sampling from.
postprocess_fn(args, kwargs) – Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.
sample(state, model_args, model_kwargs) – Given the current state, return the next state using the given transition kernel.
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
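To make the role of sample_field concrete, here is a simplified, pure-Python sketch of an MCMC driver. It is not NumPyro's implementation: the kernel class, its reduced init/sample signatures (no rng_key or model arguments), the ToyState fields, and the random-walk Metropolis update targeting a standard normal are all hypothetical. The driver calls init() once, calls sample() repeatedly, and keeps only the state attribute named by sample_field.

```python
import math
import random
from typing import NamedTuple


class ToyState(NamedTuple):
    i: int        # iteration counter
    z: float      # current position; the field named by sample_field
    log_p: float  # log density at z


class ToyMetropolisKernel:
    """Hypothetical random-walk Metropolis kernel targeting N(0, 1)."""

    sample_field = "z"

    def __init__(self, step_size=1.0, seed=0):
        self.step_size = step_size
        self.rng = random.Random(seed)

    @staticmethod
    def log_density(z):
        return -0.5 * z * z  # unnormalized log density of N(0, 1)

    def init(self, init_z):
        return ToyState(0, init_z, self.log_density(init_z))

    def sample(self, state):
        proposal = state.z + self.rng.gauss(0.0, self.step_size)
        log_p_new = self.log_density(proposal)
        # Metropolis acceptance; 1 - random() lies in (0, 1], so the log is finite
        if math.log(1.0 - self.rng.random()) < log_p_new - state.log_p:
            return ToyState(state.i + 1, proposal, log_p_new)
        return ToyState(state.i + 1, state.z, state.log_p)


def run_chain(kernel, init_z, num_samples):
    """Simplified driver: keep only getattr(state, kernel.sample_field)."""
    state = kernel.init(init_z)
    draws = []
    for _ in range(num_samples):
        state = kernel.sample(state)
        draws.append(getattr(state, kernel.sample_field))
    return draws


draws = run_chain(ToyMetropolisKernel(), init_z=0.0, num_samples=1000)
print(len(draws), sum(draws) / len(draws))
```

The other state fields (here i and log_p) are bookkeeping for the kernel; only the field named by sample_field ends up in the collected posterior.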
- postprocess_fn(args, kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
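As an illustration of what the returned function does, the sketch below hand-writes one for a hypothetical model with a positive site sigma, a site p on the unit interval, and a deterministic site odds = p / (1 - p). It uses the transforms NumPyro's biject_to picks for those supports (exp for positive, sigmoid for the unit interval); the function names and sites are assumptions, and a real kernel derives the transforms from the model instead of hard-coding them.

```python
import math


def make_postprocess_fn(model_args=(), model_kwargs=None):
    """Hypothetical stand-in for postprocess_fn(args, kwargs).

    Returns a function mapping unconstrained site values to constrained ones,
    plus the model's deterministic site. The toy model ignores its arguments.
    """

    def postprocess(unconstrained):
        sigma = math.exp(unconstrained["sigma"])  # R -> (0, inf)
        p = 1.0 / (1.0 + math.exp(-unconstrained["p"]))  # R -> (0, 1)
        odds = p / (1.0 - p)  # deterministic site computed from p
        return {"sigma": sigma, "p": p, "odds": odds}

    return postprocess


postprocess = make_postprocess_fn()
print(postprocess({"sigma": 0.0, "p": 0.0}))  # {'sigma': 1.0, 'p': 0.5, 'odds': 1.0}
```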
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- class NumPyroAIES(walkers: int, *args, **kwargs)
Bases: NumpyroEnsembleSampler
Attributes
default_fields – The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
is_ensemble_kernel – Denotes whether the kernel is an ensemble kernel.
sample_field – The attribute of the state object passed to sample() that denotes the MCMC sample.
Methods
get_diagnostics_str(state) – Given the current state, returns the diagnostics string to be added to the progress bar for diagnostic purposes.
init(rng_key, num_warmup, init_params, ...) – Initialize the MCMCKernel and return an initial state to begin sampling from.
postprocess_fn(args, kwargs) – Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.
sample(state, model_args, model_kwargs) – Given the current state, return the next state using the given transition kernel.
- class NumPyroESS(walkers: int, *args, **kwargs)
Bases: NumpyroEnsembleSampler
Attributes
default_fields – The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
is_ensemble_kernel – Denotes whether the kernel is an ensemble kernel.
sample_field – The attribute of the state object passed to sample() that denotes the MCMC sample.
Methods
get_diagnostics_str(state) – Given the current state, returns the diagnostics string to be added to the progress bar for diagnostic purposes.
init(rng_key, num_warmup, init_params, ...) – Initialize the MCMCKernel and return an initial state to begin sampling from.
postprocess_fn(args, kwargs) – Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.
sample(state, model_args, model_kwargs) – Given the current state, return the next state using the given transition kernel.