Source code for elisa.infer.samplers.ensemble.zeus

from __future__ import annotations

from itertools import permutations
from typing import TYPE_CHECKING

import numpy as np
import zeus
import zeus.ensemble as zeus_ensemble
import zeus.moves as zeus_moves

from elisa.infer.samplers.ensemble.base import (
    EnsembleSampler,
    EnsembleSamplerState,
)

if TYPE_CHECKING:
    from multiprocessing import Queue

    from numpy.typing import NDArray


[docs] class ZeusSampler(EnsembleSampler):
[docs] def get_sampling_fn( self, chains: int, warmup: int, steps: int, thinning: int, tune: bool | None, warmup_kwargs: dict, sampling_kwargs: dict, ): ndim = self._ndim if tune is None: tune = True warmup_kwargs.setdefault('verbose', False) sampling_kwargs.setdefault('verbose', False) log_prob_fn = self._log_prob_with_blobs blobs_dtype = self._blobs_dtype def sampling_fn( sampler_id: int, state: EnsembleSamplerState, queue: Queue, ): np.random.set_state(state.random_state) start = state.coords sampler1 = zeus.EnsembleSampler( chains, ndim, log_prob_fn, args=None, kwargs=None, tune=tune, pool=None, vectorize=True, blobs_dtype=blobs_dtype, **warmup_kwargs, ) sampler2 = zeus.EnsembleSampler( chains, ndim, log_prob_fn, args=None, kwargs=None, tune=tune, pool=None, vectorize=True, blobs_dtype=blobs_dtype, **sampling_kwargs, ) queue.put((sampler_id, 'warmup')) for _ in sampler1.sample( start=start, iterations=warmup, progress=False, ): queue.put((sampler_id, 'update')) if warmup: start = sampler1.get_last_sample() sampler1.reset() queue.put((sampler_id, 'sample')) for _ in sampler2.sample( start=start, iterations=steps, thin_by=thinning, progress=False, ): queue.put((sampler_id, 'update')) queue.put((sampler_id, 'finish')) samples = sampler2.get_blobs() state = EnsembleSamplerState( coords=sampler2.get_last_sample(), random_state=np.random.get_state(legacy=True), ) return samples, state return sampling_fn
[docs] def get_random_state(self, seed: int) -> tuple: return np.random.RandomState(seed).get_state(legacy=True)
[docs] class DifferentialMove: """Improved DifferentialMove of zeus for reproducibility.""" def __init__(self, tune=True, mu0=1.0): self.tune: bool = tune self.mu0: float = mu0 self._nsamples: int | None = None self._perms: NDArray[int] | None = None self._nperms: int | None = None
[docs] def get_direction(self, X: NDArray[float], mu: float): nsamples = X.shape[0] if nsamples != self._nsamples: self._nsamples = nsamples self._perms = np.array(list(permutations(np.arange(nsamples), 2))) self._nperms = len(self._perms) idx = np.random.choice(self._nperms, self._nsamples, replace=False) pairs = self._perms[idx].T if not self.tune: mu = self.mu0 return 2.0 * mu * (X[pairs[0]] - X[pairs[1]]), self.tune
# Monkey patching the DifferentialMove for reproducibility zeus_ensemble.DifferentialMove = DifferentialMove zeus_moves.DifferentialMove = DifferentialMove