Source code for elisa.infer.samplers.ns.nautilus
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
import multiprocess as mp
import nautilus
import nautilus.pool as nautilus_pool
from elisa.infer.samplers.util import uniform_reparam_model
if TYPE_CHECKING:
from collections.abc import Callable
from numpy.typing import NDArray
[docs]
class NautilusSampler:
    """Run nested sampling of a NumPyro model with :mod:`nautilus`.

    The model is first passed through ``uniform_reparam_model``; the
    resulting log-probability function is handed to
    :class:`nautilus.Sampler` as the likelihood, together with an identity
    prior transform.

    Parameters
    ----------
    numpyro_model : Callable
        The model, forwarded to ``uniform_reparam_model``.
    model_args : tuple, optional
        Positional arguments forwarded to ``uniform_reparam_model``.
    model_kwargs : dict, optional
        Keyword arguments forwarded to ``uniform_reparam_model``.
    seed : int, optional
        Used both as ``rng_seed`` of ``uniform_reparam_model`` and as
        ``seed`` of :class:`nautilus.Sampler`. The default is 42.
    ignore_nan : bool, optional
        If True, NaN log-probabilities are replaced by ``-1e300`` and a
        warning is emitted on construction. The default is False.
    **kwargs : dict
        Extra keyword arguments forwarded to :class:`nautilus.Sampler`.
        ``vectorized`` is always overwritten: it is False when a ``pool``
        key is present and True otherwise.
    """

    def __init__(
        self,
        numpyro_model: Callable,
        model_args: tuple = (),
        model_kwargs: dict | None = None,
        seed: int = 42,
        ignore_nan: bool = False,
        **kwargs: dict,
    ):
        if ignore_nan:
            warnings.warn(
                'setting `ignore_nan` to True may fail to spot potential '
                'issues of the model',
                Warning,
            )

        self._model_info = mi = uniform_reparam_model(
            numpyro_model,
            model_args,
            model_kwargs,
            rng_seed=seed,
        )

        @jax.jit
        def log_prob_fn(cube_and_derived):
            # Only the leading ``mi.ndim`` entries are used; anything after
            # them in the input vector is ignored here.
            log_p = mi.log_prob_fn(mi.unravel(cube_and_derived[: mi.ndim]))
            if ignore_nan:
                log_p = jnp.nan_to_num(log_p, nan=-1e300)
            return log_p

        use_pool = 'pool' in kwargs
        kwargs['vectorized'] = not use_pool
        if not use_pool:
            # Without a pool the likelihood receives a batch of points.
            log_prob_fn = jax.jit(jax.vmap(log_prob_fn))

        # Global state to put back once the sampler has been built; the
        # falsy/None defaults mean "nothing to restore".
        prev_method = ''
        prev_pool = None
        try:
            if use_pool:
                current_method = mp.get_start_method()
                if current_method != 'spawn':
                    mp.set_start_method('spawn', force=True)
                    prev_method = current_method
                # monkey patching the pool for compatibility with JAX
                prev_pool = nautilus_pool.Pool
                nautilus_pool.Pool = mp.Pool

            self._sampler = nautilus.Sampler(
                prior=lambda x: x,
                likelihood=lambda x: jax.device_get(log_prob_fn(x)),
                n_dim=mi.ndim,
                pass_dict=False,
                seed=seed,
                **kwargs,
            )
        finally:
            # Undo the global changes even when the constructor raises, so
            # a failed construction does not leak the patched state.
            if prev_method:
                mp.set_start_method(prev_method, force=True)
            if prev_pool is not None:
                nautilus_pool.Pool = prev_pool

    def run(self, **kwargs) -> dict[str, NDArray[float]]:
        """Run the sampler and return equally weighted posterior samples.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments forwarded to the sampler's ``run`` method.
            ``verbose`` defaults to True, and ``discard_exploration`` is
            always set to True.

        Returns
        -------
        dict
            Posterior samples, obtained by mapping the sampler's equally
            weighted points through the model's ``unravel`` and
            ``postprocess_fn`` and fetching them with ``jax.device_get``.

        Raises
        ------
        RuntimeError
            If the sampler's ``run`` method reports failure.
        """
        kwargs.setdefault('verbose', True)
        kwargs['discard_exploration'] = True

        sampler = self._sampler
        if not sampler.run(**kwargs):
            raise RuntimeError(
                'Sampling failed due to limits were reached, please set a '
                'larger `n_like_max` or `timeout`. You can also resume the '
                'sampler from previous one, providing `filepath` and `resume`.'
            )

        # Only the points are needed; the remaining outputs are dropped.
        points, *_ = sampler.posterior(
            return_as_dict=False,
            equal_weight=True,
        )
        model_info = self._model_info
        unraveled = jax.vmap(model_info.unravel)(points)
        samples = jax.vmap(model_info.postprocess_fn)(unraveled)
        return jax.device_get(samples)

    @property
    def ess(self) -> int:
        """The sampler's ``n_eff`` value, converted to ``int``."""
        return int(self._sampler.n_eff)

    @property
    def lnZ(self) -> float | None:
        """The sampler's ``log_z`` value."""
        return self._sampler.log_z