# Source code for elisa.infer.samplers.ns.jaxns

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""Nested sampling of jaxns.

This module is adapted from
https://github.com/pyro-ppl/numpyro/raw/master/numpyro/contrib/nested_sampling.py
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax import random
from numpyro.handlers import reparam, seed, trace
from numpyro.infer import Predictive
from numpyro.infer.util import (
    _guess_max_plate_nesting,
    _validate_model,
    log_density,
)

from elisa.infer.samplers.util import UniformReparam


class JAXNSSampler:
    """
    (EXPERIMENTAL) A wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_ ,
    a nested sampling package based on JAX.

    See reference [1] for details on the meaning of each parameter. Please
    consider citing this reference if you use the nested sampler in your
    research.

    Latent sample sites are reparametrized with ``UniformReparam``, so the
    sampler explores Uniform(0, 1) base variables named ``u_<site name>``. The
    weighted samples stored by :meth:`run` are mapped back to the original
    parameter domains and also contain the model's deterministic sites.

    .. note:: To enumerate over a discrete latent variable, you can add the
        keyword `infer={"enumerate": "parallel"}` to the corresponding `sample`
        statement.

    .. note:: To improve the performance, please consider enabling x64 mode at
        the beginning of your NumPyro program ``numpyro.enable_x64()``.

    **References**

    1. *JAXNS: a high-performance nested sampling package based on JAX*,
       Joshua G. Albert (https://arxiv.org/abs/2012.15286)

    :param callable model: a call with NumPyro primitives
    :param dict constructor_kwargs: additional keyword arguments to construct
        an upstream :class:`jaxns.public.DefaultNestedSampler` instance.
        Missing entries are filled in by :meth:`run` with ``num_live_points``
        (25 times the number of latent dimensions), ``devices``
        (``jax.devices()``) and ``max_samples`` (``1e4``). The dict is copied,
        so the caller's object is never modified; inspect
        ``self.constructor_kwargs`` after :meth:`run` to see what was used.
    :param dict termination_kwargs: keyword arguments used to build the
        upstream :class:`jaxns.TerminationCondition`. ``dlogZ`` defaults to
        ``1e-4`` when not given. Copied like ``constructor_kwargs``.

    **Example**

    .. doctest::

        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from elisa.infer.samplers.ns.jaxns import JAXNSSampler

        >>> true_coefs = jnp.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(0), (2000, 3))
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
        >>>
        >>> def model(data, labels):
        ...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
        ...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
        ...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
        ...                           obs=labels)
        >>>
        >>> ns = JAXNSSampler(model)
        >>> ns.run(random.PRNGKey(2), data, labels)
        >>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
        >>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
        >>> print(jnp.mean(samples['coefs'], axis=0))  # doctest: +SKIP
        [0.93661342 1.95034876 2.86123884]
    """

    def __init__(
        self,
        model,
        *,
        constructor_kwargs=None,
        termination_kwargs=None,
    ):
        # Imported here so a missing optional ``jaxns`` fails at construction.
        from jaxns.utils import NestedSamplerResults

        self.model = model
        # Copy the user-supplied dicts: ``run`` fills in defaults in place, and
        # writing into the caller's dict would leak model-specific defaults
        # (e.g. ``num_live_points``) into any other sampler sharing that dict.
        self.constructor_kwargs = (
            dict(constructor_kwargs) if constructor_kwargs is not None else {}
        )
        self.termination_kwargs = (
            dict(termination_kwargs) if termination_kwargs is not None else {}
        )
        self._samples = None
        self._log_weights = None
        self._results: NestedSamplerResults | None = None

    def _require_results(self):
        """Return the stored results, raising ``RuntimeError`` if :meth:`run`
        has not been called yet."""
        if self._results is None:
            raise RuntimeError(
                'JAXNSSampler.run(...) method should be called first to obtain results.'
            )
        return self._results

    def run(self, rng_key, *args, **kwargs):
        """
        Run the nested samplers and collect weighted samples.

        The weighted samples, transformed back to the original parameter
        domains, are stored on the instance and are accessible through
        :meth:`get_samples`, :meth:`get_weighted_samples`,
        :meth:`print_summary` and :meth:`diagnostics`.

        :param random.PRNGKey rng_key: Random number generator key to be used
            for the sampling.
        :param args: The arguments needed by the `model`.
        :param kwargs: The keyword arguments needed by the `model`.
        """
        import tensorflow_probability.substrates.jax.distributions as tfpd
        from jaxns import Model, Prior, TerminationCondition
        from jaxns.public import DefaultNestedSampler

        rng_sampling, rng_predictive = random.split(rng_key)

        # Trace the model once to discover its sites. This is done eagerly on
        # purpose: wrapping the effectful handler stack in ``jax.jit`` leaks
        # JAX tracers into the recorded trace and rejects non-array model
        # arguments.
        prototype_trace = trace(seed(self.model, rng_key)).get_trace(
            *args, **kwargs
        )
        param_names = [
            site['name']
            for site in prototype_trace.values()
            if site['type'] == 'sample'
            and not site['is_observed']
            and site['infer'].get('enumerate', '') != 'parallel'
        ]
        deterministics = [
            site['name']
            for site in prototype_trace.values()
            if site['type'] == 'deterministic'
        ]

        # reparam the model so that latent sites have Uniform(0, 1) priors
        reparam_model = reparam(
            self.model, config={k: UniformReparam() for k in param_names}
        )

        # enable enumerate if needed
        has_enum = any(
            site['type'] == 'sample'
            and site['infer'].get('enumerate', '') == 'parallel'
            for site in prototype_trace.values()
        )
        if has_enum:
            from numpyro.contrib.funsor import (
                enum,
                log_density as log_density_,
            )

            max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
            _validate_model(prototype_trace)
            reparam_model = enum(reparam_model, -max_plate_nesting - 1)
        else:
            log_density_ = log_density

        def log_likelihood(params):
            # jaxns hands back the dict built by ``prior_model``; rename its
            # keys to the ``u_<name>`` base sites of the reparametrized model.
            params = {f'u_{k}': v for k, v in params.items()}
            return log_density_(reparam_model, args, kwargs, params)[0]

        # use NestedSampler with identity prior chain
        def prior_model():
            params = {}
            for name in param_names:
                shape = prototype_trace[name]['fn'].shape()
                param = yield Prior(
                    tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)),
                    name=f'u_{name}',
                )
                params[name] = param
            return params

        model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

        default_constructor_kwargs = dict(
            num_live_points=model.U_ndims * 25,
            devices=jax.devices(),
            max_samples=1e4,
        )
        default_termination_kwargs = dict(dlogZ=1e-4)
        # Fill in missing values with defaults, so users can inspect what was
        # actually used via ``self.constructor_kwargs`` and
        # ``self.termination_kwargs``.
        for key, value in default_constructor_kwargs.items():
            self.constructor_kwargs.setdefault(key, value)
        for key, value in default_termination_kwargs.items():
            self.termination_kwargs.setdefault(key, value)

        default_ns = DefaultNestedSampler(
            model=model,
            **self.constructor_kwargs,
        )

        # TODO: check if this is necessary
        # jit when running on single device
        if len(default_ns.nested_sampler.devices) == 1:
            run_default_ns = jax.jit(default_ns)
        else:
            run_default_ns = default_ns

        termination_reason, state = run_default_ns(
            rng_sampling,
            term_cond=TerminationCondition(**self.termination_kwargs),
        )
        results = default_ns.to_results(
            termination_reason=termination_reason, state=state
        )

        # Transform the Uniform(0, 1) base samples back to the original
        # domains (and evaluate deterministic sites) via the reparam model.
        # NOTE(review): samples are NOT sliced to ``results.total_num_samples``
        # here; this assumes ``to_results`` already trims padded entries so
        # that ``results.samples`` and ``results.log_dp_mean`` stay aligned —
        # confirm against the pinned jaxns version.
        predictive = Predictive(
            reparam_model,
            results.samples,
            return_sites=param_names + deterministics,
        )
        samples = predictive(rng_predictive, *args, **kwargs)
        # replace base samples in jaxns results by transformed samples
        self._results = results._replace(samples=samples)

    def get_samples(self, rng_key, num_samples):
        """
        Draws samples from the weighted samples collected from the run.

        Resampling is done with replacement, using the log weights returned
        by :meth:`get_weighted_samples`.

        :param random.PRNGKey rng_key: Random number generator key to be used
            to draw samples.
        :param int num_samples: The number of samples.
        :return: a dict of posterior samples
        :raises RuntimeError: if :meth:`run` has not been called.
        """
        # Check first so the user gets the informative error, not an import one.
        self._require_results()
        from jaxns import resample

        weighted_samples, log_weights = self.get_weighted_samples()
        return resample(
            rng_key,
            weighted_samples,
            log_weights,
            S=num_samples,
            replace=True,
        )

    def get_weighted_samples(self):
        """
        Gets weighted samples and their corresponding log weights.

        :return: a tuple ``(samples, log_weights)`` where ``samples`` are the
            transformed samples stored by :meth:`run` and ``log_weights`` is
            the ``log_dp_mean`` field of the jaxns results.
        :raises RuntimeError: if :meth:`run` has not been called.
        """
        results = self._require_results()
        return results.samples, results.log_dp_mean

    def print_summary(self):
        """
        Print summary of the result. This is a wrapper of
        :func:`jaxns.utils.summary`.

        :raises RuntimeError: if :meth:`run` has not been called.
        """
        results = self._require_results()
        from jaxns import summary

        summary(results)

    def diagnostics(self):
        """
        Plot diagnostics of the result. This is a wrapper of
        :func:`jaxns.plotting.plot_diagnostics` and
        :func:`jaxns.plotting.plot_cornerplot`.

        :raises RuntimeError: if :meth:`run` has not been called.
        """
        results = self._require_results()
        from jaxns import plot_cornerplot, plot_diagnostics

        plot_diagnostics(results)
        plot_cornerplot(results)