elisa.infer.samplers.ensemble.numpyro

class NumpyroEnsembleSampler(walkers: int, *args, **kwargs)

Bases: MCMCKernel

Wrapper kernel to run the ensemble sampler as a single MCMC chain.

To collect the posterior correctly, retrieve the samples grouped by chain, so that each sample site has shape (n_parallel, n_steps, n_walkers, ...):

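from numpyro.infer import MCMC
from elisa.infer.samplers.ensemble.numpyro import NumpyroEnsembleSampler

kernel = NumpyroEnsembleSampler(...)
mcmc = MCMC(kernel, ...)
mcmc.run(...)
# group_by_chain=True keeps the leading n_parallel axis
samples = mcmc.get_samples(group_by_chain=True)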

Then merge the walkers belonging to the same sampler into a single chain of shape (n_parallel, n_walkers * n_steps, ...):

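import jax
import jax.numpy as jnp

# (n_parallel, n_steps, n_walkers, ...) -> (n_parallel, n_walkers, n_steps, ...)
samples = jax.tree.map(lambda x: jnp.swapaxes(x, 1, 2), samples)
# Flatten walkers and steps so each walker's trajectory stays contiguous:
# (n_parallel, n_walkers, n_steps, ...) -> (n_parallel, n_walkers * n_steps, ...)
samples = jax.tree.map(
    lambda x: jnp.reshape(
        x,
        (x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:]),
    ),
    samples,
)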

Attributes

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

is_ensemble_kernel

Denotes whether the kernel is an ensemble kernel.

sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample.
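The roles of default_fields and sample_field can be illustrated with a toy driver that reads fields off a sequence of states with getattr. This is a minimal sketch and not NumPyro's MCMC.run(); ToyState, ToyKernel and collect are hypothetical names, and the state is assumed to be a NamedTuple.

```python
from typing import NamedTuple

import numpy as np


class ToyState(NamedTuple):
    i: int          # iteration counter
    z: np.ndarray   # current sample


class ToyKernel:
    # Hypothetical class attributes playing the roles described above.
    sample_field = "z"          # field that holds the MCMC sample
    default_fields = ("z",)     # fields collected when none are requested
    is_ensemble_kernel = False  # this toy kernel is not an ensemble kernel


def collect(kernel, states, fields=None):
    """Stack the requested state fields, defaulting to kernel.default_fields."""
    fields = kernel.default_fields if fields is None else fields
    return {f: np.stack([getattr(s, f) for s in states]) for f in fields}


states = [ToyState(i, np.full(2, float(i))) for i in range(3)]
collected = collect(ToyKernel(), states)
samples = collected[ToyKernel.sample_field]  # array of shape (3, 2)
```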

Methods

get_diagnostics_str(state)

Given the current state, return the diagnostics string to be displayed in the progress bar.

init(rng_key, num_warmup, init_params, ...)

Initialize the MCMCKernel and return an initial state to begin sampling from.

postprocess_fn(args, kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.

property sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

postprocess_fn(args, kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.

Parameters:
  • args – Positional arguments to the model.

  • kwargs – Keyword arguments to the model.
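The function returned by postprocess_fn maps unconstrained values back onto each site's support and adds deterministic sites. The sketch below does this by hand for a hypothetical model with a real-valued site mu, a positive site sigma sampled in log space, and a deterministic site precision; the real method derives these transforms from the model itself, and make_postprocess_fn is an illustrative name only.

```python
import numpy as np


def make_postprocess_fn(args, kwargs):
    """Hypothetical postprocess factory for a toy model with sites
    `mu` (real), `sigma` (positive) and deterministic `precision`."""
    # The toy model takes no arguments, so args and kwargs are unused here.

    def postprocess(unconstrained):
        # `mu` already lives on the real line: identity transform.
        mu = unconstrained["mu"]
        # `sigma` is sampled in log space; exp maps it onto (0, inf).
        sigma = np.exp(unconstrained["sigma"])
        # Deterministic site computed from the constrained values.
        precision = 1.0 / sigma**2
        return {"mu": mu, "sigma": sigma, "precision": precision}

    return postprocess


fn = make_postprocess_fn((), {})
out = fn({"mu": np.array(1.5), "sigma": np.array(0.0)})
# exp(0) = 1, so out["sigma"] == 1.0 and out["precision"] == 1.0
```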

init(rng_key, num_warmup, init_params, model_args, model_kwargs)

Initialize the MCMCKernel and return an initial state to begin sampling from.

Parameters:
  • rng_key (Array) – Random number generator key to initialize the kernel.

  • num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.

  • init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.

  • model_args – Arguments provided to the model.

  • model_kwargs – Keyword arguments provided to the model.

Returns:

The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.

Parameters:
  • state – A pytree class representing the state for the kernel. For HMC, this is given by HMCState. In general, this could be any class that supports getattr.

  • model_args – Arguments provided to the model.

  • model_kwargs – Keyword arguments provided to the model.

Returns:

Next state.
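The init/sample contract (build an initial state once, then repeatedly map the current state to the next one) can be sketched with a toy random-walk Metropolis kernel. This is an illustration of the pattern, not the ensemble kernel documented here: ToyRandomWalk and RWState are hypothetical, and a NumPy Generator carried in the state stands in for a JAX PRNG key.

```python
from typing import NamedTuple

import numpy as np


class RWState(NamedTuple):
    z: np.ndarray                # current position (the MCMC sample)
    log_p: float                 # log density evaluated at z
    rng: np.random.Generator     # stands in for a JAX PRNG key


class ToyRandomWalk:
    """Hypothetical random-walk Metropolis kernel with an init/sample pair."""

    sample_field = "z"

    def __init__(self, log_density, step_size=0.5):
        self.log_density = log_density
        self.step_size = step_size

    def init(self, rng, num_warmup, init_params, model_args=(), model_kwargs=None):
        # num_warmup is ignored: this toy kernel performs no adaptation.
        z = np.asarray(init_params, dtype=float)
        return RWState(z, self.log_density(z), rng)

    def sample(self, state, model_args=(), model_kwargs=None):
        # Propose a symmetric Gaussian step around the current position.
        proposal = state.z + self.step_size * state.rng.standard_normal(state.z.shape)
        log_p_new = self.log_density(proposal)
        # Metropolis rule: accept with probability min(1, p(proposal) / p(z)).
        if np.log(state.rng.uniform()) < log_p_new - state.log_p:
            return RWState(proposal, log_p_new, state.rng)
        return state


# Target: 2-D standard normal (up to an additive constant).
kernel = ToyRandomWalk(lambda z: -0.5 * np.sum(z**2))
state = kernel.init(np.random.default_rng(0), 0, np.zeros(2))
draws = []
for _ in range(2000):
    state = kernel.sample(state)
    draws.append(getattr(state, kernel.sample_field))
draws = np.array(draws)
```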

class NumPyroAIES(walkers: int, *args, **kwargs)

Bases: NumpyroEnsembleSampler
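Judging by its name, NumPyroAIES presumably wraps NumPyro's affine-invariant ensemble sampler (AIES); that is an assumption, as is the choice of move below. The sketch shows one sweep of the Goodman and Weare stretch move, one of the moves used by affine-invariant ensemble samplers, updating walkers one at a time in plain NumPy; stretch_move and its parameter a are illustrative, not this class's API.

```python
import numpy as np


def stretch_move(walkers, log_prob, rng, a=2.0):
    """One sweep of the stretch move, updating walkers one at a time."""
    walkers = walkers.copy()
    n_walkers, ndim = walkers.shape
    log_p = np.array([log_prob(w) for w in walkers])
    accepted = np.zeros(n_walkers, dtype=bool)
    for k in range(n_walkers):
        # Pick a different walker j uniformly from the rest of the ensemble.
        j = rng.choice([i for i in range(n_walkers) if i != k])
        # Draw the stretch factor z from g(z) proportional to 1/sqrt(z) on [1/a, a].
        z = ((a - 1.0) * rng.uniform() + 1.0) ** 2 / a
        proposal = walkers[j] + z * (walkers[k] - walkers[j])
        log_p_new = log_prob(proposal)
        # Acceptance ratio includes the z**(ndim - 1) factor of the move.
        log_ratio = (ndim - 1) * np.log(z) + log_p_new - log_p[k]
        if np.log(rng.uniform()) < log_ratio:
            walkers[k], log_p[k], accepted[k] = proposal, log_p_new, True
    return walkers, accepted


# Usage: 8 walkers exploring a 2-D standard normal.
rng = np.random.default_rng(1)
walkers = rng.standard_normal((8, 2))
for _ in range(200):
    walkers, accepted = stretch_move(walkers, lambda x: -0.5 * np.sum(x**2), rng)
```

Updating walkers sequentially lets each proposal use the other walkers' latest positions; a parallel implementation would instead split the ensemble into two halves and update one half against the other.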

Attributes

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

is_ensemble_kernel

Denotes whether the kernel is an ensemble kernel.

sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample.

Methods

get_diagnostics_str(state)

Given the current state, return the diagnostics string to be displayed in the progress bar.

init(rng_key, num_warmup, init_params, ...)

Initialize the MCMCKernel and return an initial state to begin sampling from.

postprocess_fn(args, kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.

get_diagnostics_str(state)

Given the current state, return the diagnostics string to be displayed in the progress bar.
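A diagnostics string of this kind is typically a short formatted summary of the kernel state. The sketch below assumes, hypothetically, that the state exposes a per-walker mean acceptance probability and averages it over the ensemble; the field name mean_accept_prob and the exact string format are illustrative assumptions, not what NumPyroAIES is confirmed to print.

```python
from typing import NamedTuple

import numpy as np


class EnsembleState(NamedTuple):
    # Hypothetical field: running mean acceptance probability per walker.
    mean_accept_prob: np.ndarray


def get_diagnostics_str(state):
    """Summarise the ensemble's mean acceptance probability for a progress bar."""
    return "acc. prob={:.2f}".format(float(np.mean(state.mean_accept_prob)))


msg = get_diagnostics_str(EnsembleState(np.array([0.2, 0.3, 0.4])))
# msg == "acc. prob=0.30"
```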

class NumPyroESS(walkers: int, *args, **kwargs)

Bases: NumpyroEnsembleSampler
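By its name, NumPyroESS presumably wraps NumPyro's ensemble slice sampler (ESS); this is an assumption. The sketch below shows the idea behind one ESS move, the differential move: each walker is updated by one-dimensional slice sampling (stepping out, then shrinkage) along the difference of two other walkers. The function name, the scale factor mu, the bracket width of 1 and max_steps are hypothetical choices for illustration.

```python
import numpy as np


def ess_differential_sweep(walkers, log_prob, rng, mu=1.0, max_steps=50):
    """One sweep of slice sampling along differential directions.

    Each walker k is updated in turn along mu * (x_i - x_j), where i and j
    are two other, distinct walkers.
    """
    walkers = walkers.copy()
    n_walkers = walkers.shape[0]
    for k in range(n_walkers):
        others = [m for m in range(n_walkers) if m != k]
        i, j = rng.choice(others, size=2, replace=False)
        direction = mu * (walkers[i] - walkers[j])
        x0 = walkers[k].copy()

        def f(t):
            # Log density along the line x0 + t * direction.
            return log_prob(x0 + t * direction)

        # Slice level: log p(x0) minus an Exponential(1) draw.
        log_y = f(0.0) - rng.exponential()
        # Randomly placed initial bracket of width 1 containing t = 0.
        left = -rng.uniform()
        right = left + 1.0
        # Stepping out, with at most max_steps - 1 expansions split at random.
        n_left = int(np.floor(max_steps * rng.uniform()))
        n_right = max_steps - 1 - n_left
        while n_left > 0 and f(left) > log_y:
            left -= 1.0
            n_left -= 1
        while n_right > 0 and f(right) > log_y:
            right += 1.0
            n_right -= 1
        # Shrinkage: draw uniformly in the bracket, shrinking toward t = 0.
        while True:
            t = rng.uniform(left, right)
            if f(t) > log_y:
                walkers[k] = x0 + t * direction
                break
            if t < 0.0:
                left = t
            else:
                right = t
    return walkers


# Usage: 8 walkers exploring a 2-D standard normal.
rng = np.random.default_rng(2)
walkers = rng.standard_normal((8, 2))
for _ in range(200):
    walkers = ess_differential_sweep(walkers, lambda x: -0.5 * np.sum(x**2), rng)
```

Because the direction is built from other walkers, its scale adapts to the spread of the ensemble, which is what lets slice sampling along it work without hand-tuning a step size per dimension.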

Attributes

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

is_ensemble_kernel

Denotes whether the kernel is an ensemble kernel.

sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample.

Methods

get_diagnostics_str(state)

Given the current state, return the diagnostics string to be displayed in the progress bar.

init(rng_key, num_warmup, init_params, ...)

Initialize the MCMCKernel and return an initial state to begin sampling from.

postprocess_fn(args, kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.