# Source code for elisa.infer.samplers.ensemble.numpyro
from __future__ import annotations
import jax
from numpyro.infer.ensemble import AIES, ESS
from numpyro.infer.mcmc import MCMCKernel
from numpyro.util import is_prng_key
class NumpyroEnsembleSampler(MCMCKernel):
    """Wrapper kernel to run the ensemble sampler as a single MCMC chain.

    The wrapped numpyro ensemble kernel (chosen by the concrete subclass via
    the ``_kernel`` class attribute) receives a batch of ``walkers`` PRNG
    keys, so one numpyro "chain" of this kernel advances a whole ensemble of
    walkers. Independent ensembles can then be run with
    ``chain_method='parallel'`` or ``chain_method='sequential'``.

    Parameters
    ----------
    walkers : int
        Number of walkers in each ensemble. NOTE(review): numpyro documents
        that its ensemble kernels need an even number of walkers greater
        than 1 -- not validated here.
    *args, **kwargs
        Forwarded unchanged to the wrapped ensemble kernel's constructor.

    Notes
    -----
    To collect the posterior correctly, get samples in shape of
    ``(n_parallel, n_steps, n_walkers)`` by

    .. code-block:: python

        from numpyro.infer import MCMC

        kernel = NumpyroEnsembleSampler(...)
        mcmc = MCMC(kernel, ...)
        mcmc.run(...)
        samples = mcmc.get_samples(group_by_chain=True)

    Then combine the walkers from the same sampler by

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        samples = jax.tree.map(lambda x: jnp.swapaxes(x, 1, 2), samples)
        samples = jax.tree.map(
            lambda x: jnp.reshape(
                x,
                (x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:]),
            ),
            samples,
        )
    """

    # Ensemble kernel class to wrap; must be set by each concrete subclass.
    _kernel: type[AIES | ESS]

    def __init__(self, walkers: int, *args, **kwargs):
        self._walkers = int(walkers)
        self._sampler = self._kernel(*args, **kwargs)

    @property
    def sample_field(self):
        """Name of the state field holding the sampled values."""
        return 'z'

    def postprocess_fn(self, args, kwargs):
        """Return the wrapped kernel's postprocess function, vmapped.

        The ``jax.vmap`` maps the wrapped function over the leading (walker)
        axis of each sample drawn by this kernel.
        """
        return jax.vmap(self._sampler.postprocess_fn(args, kwargs))

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Initialize one ensemble of ``walkers`` walkers.

        Parameters
        ----------
        rng_key
            A single (unbatched) PRNG key. It is split into one key per
            walker before being handed to the wrapped kernel.
        num_warmup
            Number of warmup steps, forwarded to the wrapped kernel.
        init_params
            ``None``, or a pytree whose every leaf has a leading dimension
            of size ``walkers``.
        model_args, model_kwargs
            Forwarded to the wrapped kernel.

        Returns
        -------
        The initial state returned by the wrapped kernel's ``init``.

        Raises
        ------
        NotImplementedError
            If `rng_key` is not a single PRNG key (vectorized chain mode).
        ValueError
            If any leaf of `init_params` lacks a leading dimension equal to
            the number of walkers.
        """
        if not is_prng_key(rng_key):
            raise NotImplementedError(
                "EnsembleSampler only supports chain_method='parallel' or "
                "chain_method='sequential'. Please put in a feature request "
                'if it would be useful to be used in vectorized mode.'
            )
        if init_params is not None:
            for param in jax.tree.leaves(init_params):
                # jax.numpy.shape also handles Python scalars and 0-d arrays,
                # which have no usable ``.shape[0]``.
                shape = jax.numpy.shape(param)
                if not shape or shape[0] != self._walkers:
                    raise ValueError(
                        'The batch dimension of each param must match the '
                        f'number of walkers ({self._walkers}), got a param '
                        f'with shape {shape}'
                    )
        # One key per walker: the wrapped ensemble kernel treats the batch
        # of keys as its set of walkers.
        rng_keys = jax.random.split(rng_key, self._walkers)
        return self._sampler.init(
            rng_keys, num_warmup, init_params, model_args, model_kwargs
        )

    def sample(self, state, model_args, model_kwargs):
        """Advance the ensemble by delegating to the wrapped kernel."""
        return self._sampler.sample(state, model_args, model_kwargs)
class NumPyroAIES(NumpyroEnsembleSampler):
    """numpyro's ``AIES`` ensemble kernel, run as a single MCMC chain."""

    _kernel = AIES

    def get_diagnostics_str(self, state):
        """Progress-bar text showing the mean acceptance probability."""
        mean_accept = state.inner_state.mean_accept_prob
        return 'acc. prob={:.2f}'.format(mean_accept)
class NumPyroESS(NumpyroEnsembleSampler):
    """numpyro's ``ESS`` ensemble kernel, run as a single MCMC chain."""

    _kernel = ESS