Source code for elisa.infer.samplers.ensemble.emcee

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from emcee import EnsembleSampler as Sampler, State as EmceeState

from elisa.infer.samplers.ensemble.base import (
    EnsembleSampler,
    EnsembleSamplerState,
)

if TYPE_CHECKING:
    from multiprocessing import Queue

    from numpy.random import Generator


class EmceeSampler(EnsembleSampler):
    """Ensemble sampler backed by :class:`emcee.EnsembleSampler`."""

    @staticmethod
    def _to_emcee_random_state(random_state):
        """Convert ``random_state`` into a form emcee actually honours.

        emcee keeps a legacy ``np.random.RandomState`` internally and applies
        the state of an incoming ``emcee.State`` via ``RandomState.set_state``,
        silently ignoring anything that method rejects. A
        ``np.random.Generator`` is rejected, so passing one would discard the
        seed without any error. A ``Generator`` is therefore turned into an
        MT19937 ``RandomState`` state tuple seeded from a draw of that
        generator (which advances the generator). Any other value, such as
        the tuple emcee itself reports, is returned unchanged.
        """
        if isinstance(random_state, np.random.Generator):
            seed = int(random_state.integers(2**32))
            return np.random.RandomState(seed).get_state()
        return random_state

    def get_sampling_fn(
        self,
        chains: int | None,
        warmup: int,
        steps: int,
        thinning: int,
        tune: bool | None,
        warmup_kwargs: dict,
        sampling_kwargs: dict,
    ):
        """Build a function that runs one emcee warmup + sampling session.

        Parameters
        ----------
        chains : int or None
            Number of walkers. ``None`` means ``4 * ndim``.
        warmup : int
            Number of warmup iterations; these are not stored.
        steps : int
            Number of stored iterations in the sampling phase.
        thinning : int
            Forwarded to emcee as ``thin_by``: emcee makes ``thin_by``
            proposals between consecutive stored iterations.
        tune : bool or None
            Forwarded to emcee's ``sample``; ``None`` means ``False``.
        warmup_kwargs : dict
            Extra keyword arguments for the warmup-phase emcee sampler
            constructor.
        sampling_kwargs : dict
            Extra keyword arguments for the sampling-phase emcee sampler
            constructor.

        Returns
        -------
        callable
            ``sampling_fn(sampler_id, state, queue)``. It returns
            ``(samples, state)``, where ``samples`` is the blobs array
            recorded during the sampling phase and ``state`` is the final
            :class:`EnsembleSamplerState`. Progress is reported by putting
            ``(sampler_id, msg)`` on ``queue``: ``'warmup'``, then
            ``'update'`` per warmup iteration, ``'sample'``, then
            ``'update'`` per stored iteration, and finally ``'finish'``.
        """
        ndim = self._ndim
        if chains is None:
            chains = 4 * ndim
        if tune is None:
            tune = False

        # Hoisted so the returned closure does not need to capture ``self``.
        log_prob_fn = self._log_prob_with_blobs
        blobs_dtype = self._blobs_dtype
        to_emcee_random_state = self._to_emcee_random_state

        def make_sampler(extra_kwargs: dict) -> Sampler:
            # Both phases share every setting except the user-provided kwargs.
            return Sampler(
                chains,
                ndim,
                log_prob_fn,
                pool=None,
                args=None,
                kwargs=None,
                vectorize=True,
                blobs_dtype=blobs_dtype,
                parameter_names=None,
                **extra_kwargs,
            )

        def sampling_fn(
            sampler_id: int, state: EnsembleSamplerState, queue: Queue
        ):
            emcee_state = EmceeState(
                coords=state.coords,
                random_state=to_emcee_random_state(state.random_state),
            )

            # Build both samplers up front so bad kwargs fail before any
            # progress message is sent.
            warmup_sampler = make_sampler(warmup_kwargs)
            main_sampler = make_sampler(sampling_kwargs)

            queue.put((sampler_id, 'warmup'))
            for s in warmup_sampler.sample(
                emcee_state,
                iterations=warmup,
                tune=tune,
                store=False,
                progress=False,
            ):
                emcee_state = s
                queue.put((sampler_id, 'update'))

            # The final warmup state (coords, log_prob, blobs, RNG state)
            # seeds the sampling phase.
            queue.put((sampler_id, 'sample'))
            for s in main_sampler.sample(
                emcee_state,
                iterations=steps,
                tune=tune,
                thin_by=thinning,
                store=True,
                progress=False,
            ):
                emcee_state = s
                queue.put((sampler_id, 'update'))

            queue.put((sampler_id, 'finish'))
            samples = main_sampler.get_blobs()
            state = EnsembleSamplerState(
                coords=emcee_state.coords,
                random_state=emcee_state.random_state,
            )
            return samples, state

        return sampling_fn

    def get_random_state(self, seed: int) -> Generator:
        """Return a NumPy ``Generator`` seeded with ``seed``.

        emcee cannot consume a ``Generator`` directly; ``get_sampling_fn``
        converts it into a legacy ``RandomState`` state before sampling.
        """
        return np.random.default_rng(seed)