Source code for elisa.infer.samplers.ns.ultranest
from __future__ import annotations
import warnings
from functools import partial
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
import numpy as np
from ultranest import ReactiveNestedSampler, read_file
from elisa.infer.samplers.util import ravel_params_names, uniform_reparam_model
if TYPE_CHECKING:
from collections.abc import Callable
from numpy.typing import NDArray
[docs]
class UltraNestSampler:
def __init__(
self,
numpyro_model: Callable,
model_args: tuple = (),
model_kwargs: dict | None = None,
seed: int = 42,
ignore_nan: bool = False,
**kwargs: dict,
):
if ignore_nan:
warnings.warn(
'setting `ignore_nan` to True may fail to spot potential '
'issues of the model',
Warning,
)
self._model_info = mi = uniform_reparam_model(
numpyro_model,
model_args,
model_kwargs,
rng_seed=seed,
)
@jax.jit
@jax.vmap
def log_prob_fn(cube_and_derived):
log_p = mi.log_prob_fn(mi.unravel(cube_and_derived[: mi.ndim]))
if ignore_nan:
log_p = jnp.nan_to_num(log_p, nan=-1e300)
return log_p
self._log_prob_fn = log_prob_fn
self._sampler = None
self._sampler_constructor = partial(ReactiveNestedSampler, **kwargs)
self._seed = seed
[docs]
def run(
self,
viz_sample_names: list[str] | None = None,
read_file_config: dict | None = None,
**kwargs: dict,
) -> dict[str, NDArray[float]]:
mi = self._model_info
if viz_sample_names is None:
viz_sample_names = mi.params_names
else:
viz_sample_names = list(map(str, viz_sample_names))
@jax.jit
@jax.vmap
def transform(cube):
samples = mi.postprocess_fn(mi.unravel(cube))
viz = jnp.hstack([samples[i].ravel() for i in viz_sample_names])
return jnp.append(cube, viz)
params_names = []
for i in mi.params_dtype:
shape = i[2]
name = f'u_{i[0]}'
params_names.extend(ravel_params_names(name, shape))
samples_dtype = mi.params_dtype + mi.deterministic_dtype
derived_names = []
for i in viz_sample_names:
filtered = list(filter(lambda x: x[0] == i, samples_dtype))
if any(filtered):
shape = filtered[0][2]
derived_names.extend(ravel_params_names(i, shape))
if read_file_config is None:
prev_state = np.random.get_state()
np.random.seed(self._seed)
sampler = self._sampler = self._sampler_constructor(
param_names=params_names,
loglike=lambda x: jax.device_get(self._log_prob_fn(x)),
transform=lambda x: jax.device_get(transform(x)),
derived_param_names=derived_names,
vectorized=True,
)
sampler.run(**kwargs)
np.random.set_state(prev_state)
u_samples = sampler.results['samples'][:, : mi.ndim]
else:
read_file_config = dict(read_file_config)
read_file_config['x_dim'] = mi.ndim
sequence, final = read_file(**read_file_config)
results = sequence | final
u_samples = results['samples'][:, : mi.ndim]
u_samples = jax.vmap(mi.unravel)(u_samples)
samples = jax.vmap(mi.postprocess_fn)(u_samples)
samples = jax.device_get(samples)
return samples
[docs]
def print_results(self, use_unicode: bool = True):
if not hasattr(self._sampler, 'results'):
raise RuntimeError(
'no results found, please run the sampler first.'
)
self._sampler.print_results(use_unicode=use_unicode)
@property
def ess(self) -> int:
if not hasattr(self._sampler, 'results'):
return 0
else:
return int(self._sampler.results['ess'])
@property
def lnZ(self) -> tuple[float | None, float | None]:
if not hasattr(self._sampler, 'results'):
return None, None
else:
return (
self._sampler.results['logz'],
self._sampler.results['logzerr'],
)