Source code for elisa.infer.samplers.ensemble.base

from __future__ import annotations

import threading
import warnings
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, NamedTuple

import jax
import jax.numpy as jnp
import multiprocess as mp
import numpy as np
from numpyro.infer.initialization import init_to_value
from tqdm.auto import tqdm

from elisa.infer.samplers.util import get_model_info

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from multiprocessing import Queue
    from typing import Any

    from numpy.typing import NDArray


[docs] class EnsembleSamplerState(NamedTuple): """Ensemble sampler state.""" coords: NDArray[float] """The current positions of the walkers in the parameter space.""" random_state: Any """The state of random number generator."""
[docs] class EnsembleSampler(metaclass=ABCMeta): def __init__( self, numpyro_model: Callable, init_params: dict[str, float] | None = None, ignore_nan: bool = False, seed: int = 42, model_args: tuple = (), model_kwargs: dict | None = None, ): """The base class for ensemble samplers. Parameters ---------- numpyro_model : callable The numpyro model to sample from. init_params : dict, optional The initial parameters for the model. ignore_nan : bool, optional Whether to transform log probability of NaN to -1e300. The default is False. seed : int, optional The random seed for reproducibility. model_args : tuple, optional The arguments to pass to `numpyro_model`. model_kwargs : dict, optional The keyword arguments to pass to `numpyro_model`. """ if ignore_nan: warnings.warn( 'setting `ignore_nan` to True may fail to spot potential ' 'issues of the model', Warning, ) if init_params is None: init_params = {} info = get_model_info( model=numpyro_model, init_strategy=init_to_value(values=init_params), model_args=model_args, model_kwargs=model_kwargs, validate_grad=False, rng_seed=seed, ) ndim = info.ndim blobs_dtype = info.params_dtype + info.deterministic_dtype blobs_names = [dt[0] for dt in blobs_dtype] init = info.init_ravel unravel = info.unravel log_prob_fn = info.log_prob_fn postprocess_fn = info.postprocess_fn @jax.vmap @jax.jit def _log_prob_with_blobs(z): z = unravel(z) log_prob = log_prob_fn(z) blobs = postprocess_fn(z) return log_prob, [blobs[i] for i in blobs_names] @jax.jit def log_prob_with_blobs(z): log_prob, blobs = _log_prob_with_blobs(z) if ignore_nan: log_prob = jnp.nan_to_num(log_prob, nan=-np.inf) return [ (log_prob[i], *(b[i] for b in blobs)) for i in range(len(log_prob)) ] self._ndim = ndim self._init = init self._log_prob_with_blobs = log_prob_with_blobs self._blobs_dtype = blobs_dtype self._seed = seed
[docs] def run( self, warmup: int = 5000, steps: int = 5000, chains: int | None = None, thinning: int = 1, n_parallel: int = 4, tune: bool | None = None, progress: bool = True, states: Sequence[EnsembleSamplerState] | None = None, warmup_kwargs: dict | None = None, sampling_kwargs: dict | None = None, ) -> tuple[dict[str, NDArray[float]], tuple[EnsembleSamplerState, ...]]: """Run the sampler. Parameters ---------- warmup : int, optional The warmup (burn-in) steps. The default is 5000. steps : int, optional The sampling steps. The default is 5000. chains : int, optional The number of walkers. The default is 4 * ndim. thinning : int Stores every `thinning` samples in the chain. When this is set, `steps` * `thinning` proposals will be made. The default is 1. n_parallel : int, optional Number of parallel samplers. The default is 4. tune : bool, optional Whether to tune the parameters of moves. Defaults to the corresponding sampler's default. progress : bool, optional Whether to display a progress bar. The default is True. states : sequence, optional The initial states of the samplers. warmup_kwargs: dict, optional Extra parameters passed to sampler constructor for warm-up phase. sampling_kwargs: dict | None = None, Extra parameters passed to sampler constructor for sampling phase. Returns ------- samples : dict The posterior samples. states : tuple The states of the samplers. """ ndim = self._ndim if chains is None: chains = 4 * ndim if states is None: seed0 = np.random.default_rng(self._seed).integers(2**32) seed_init, *seeds_mcmc = seed0 + np.arange(n_parallel + 1) rng_init = np.random.default_rng(seed_init) jitter = rng_init.uniform(0.9, 1.1, (n_parallel, chains, ndim)) init = jitter * np.full((chains, ndim), self._init) states = [ EnsembleSamplerState(i, self.get_random_state(j)) for i, j in zip(init, seeds_mcmc, strict=True) ] else: states = list(states) if len(states) != n_parallel: raise ValueError('states number should match n_parallel') if not all(isinstance(s, EnsembleSamplerState) for s in states): raise ValueError('states must be EnsembleSamplerState') for s in states: if s.coords.shape != (chains, ndim): raise ValueError( f"states' coords must have shape ({chains}, {ndim})" ) warmup = 0 if warmup_kwargs is None: warmup_kwargs = {} else: warmup_kwargs = dict(warmup_kwargs) if sampling_kwargs is None: sampling_kwargs = {} else: sampling_kwargs = dict(sampling_kwargs) old_method = mp.get_start_method() if old_method != 'spawn': mp.set_start_method('spawn', force=True) else: old_method = '' pbars = { i: tqdm( total=warmup + steps, desc='Initializing... ', position=i, disable=not progress, ) for i in range(1, n_parallel + 1) } queue = mp.Manager().Queue() listener_thread = threading.Thread( target=self._progress_listener, args=(queue, pbars, n_parallel), ) listener_thread.start() sampling_fn = self.get_sampling_fn( chains=chains, warmup=warmup, steps=steps, thinning=thinning, tune=tune, warmup_kwargs=warmup_kwargs, sampling_kwargs=sampling_kwargs, ) with mp.Pool(processes=n_parallel) as pool: results = [] for i in range(n_parallel): r = pool.apply_async( sampling_fn, args=(i + 1, states[i], queue), ) results.append(r) results = [r.get() for r in results] pool.close() pool.join() listener_thread.join() if old_method: mp.set_start_method(old_method, force=True) # get samples samples = [r[0] for r in results] # reshape samples from (n_step, n_walker) to (n_walker, n_step) samples = list(map(np.transpose, samples)) # merge chains from the same sampler into a single one samples = list(map(np.concatenate, samples)) # stack samples of different samplers samples = np.vstack(samples) # make samples a dict samples = {i: samples[i] for i in samples.dtype.names} # get states states = tuple(r[1] for r in results) return samples, states
@staticmethod def _progress_listener( queue: Queue, pbars: dict[int, tqdm], num_samplers: int, ): finished = 0 while finished < num_samplers: sampler_id, msg = queue.get() if msg == 'update': pbars[sampler_id].update(1) elif msg == 'warmup': pbars[sampler_id].set_description( f'Warm-up sampler {sampler_id}' ) elif msg == 'sample': pbars[sampler_id].set_description( f'Running sampler {sampler_id}', ) elif msg == 'finish': pbars[sampler_id].set_description( f'Finished sampler {sampler_id}', ) pbars[sampler_id].close() finished += 1
[docs] @abstractmethod def get_sampling_fn( self, chains: int, warmup: int, steps: int, thinning: int, tune: bool | None, warmup_kwargs: dict, sampling_kwargs: dict, ) -> Callable[ [int, EnsembleSamplerState, Queue], tuple[NDArray, EnsembleSamplerState], ]: """Generate the sampling function.""" pass
[docs] @abstractmethod def get_random_state(self, seed: int) -> Any: """Get the random state for the sampler.""" pass