elisa.infer.samplers.ns.jaxns

Nested sampling with jaxns.

This module is adapted from pyro-ppl/numpyro.

class JAXNSSampler(model, *, constructor_kwargs=None, termination_kwargs=None)

Bases: object

(EXPERIMENTAL) A wrapper for jaxns, a nested sampling package based on JAX.

See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research.
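For readers new to the method, the toy sketch below (plain NumPy, not jaxns code) shows the core nested sampling loop described in [1]: repeatedly discard the lowest-likelihood live point, assume the enclosed prior volume shrinks as X_i ≈ exp(-i / num_live), and accumulate the evidence Z ≈ Σ L_i (X_{i-1} - X_i). The model is an assumption chosen so the answer is known: a standard normal likelihood under a Uniform(-5, 5) prior, for which log Z ≈ log(0.1). Because this likelihood depends only on |x|, the constrained prior {L > L*} is an interval that can be sampled exactly; real samplers such as jaxns need far more elaborate constrained-sampling schemes.

```python
import numpy as np


def log_likelihood(x):
    # Illustrative likelihood: standard normal log density N(x; 0, 1)
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)


def toy_nested_sampling(num_live=200, num_iter=2000, seed=0):
    """Toy estimate of log Z for N(x; 0, 1) under a Uniform(-5, 5) prior."""
    rng = np.random.default_rng(seed)
    live = rng.uniform(-5.0, 5.0, size=num_live)  # initial draws from the prior
    live_logl = log_likelihood(live)
    log_z = -np.inf
    log_x_prev = 0.0  # log prior volume, X_0 = 1
    for i in range(1, num_iter + 1):
        worst = np.argmin(live_logl)
        # Expected shrinkage of the enclosed prior volume: log X_i = -i / num_live
        log_x = -i / num_live
        # Log of the shell width w_i = X_{i-1} - X_i
        log_w = log_x_prev + np.log1p(-np.exp(log_x - log_x_prev))
        log_z = np.logaddexp(log_z, live_logl[worst] + log_w)
        # Replace the worst point by a prior draw constrained to L(x) > L(worst).
        # For this toy likelihood that region is exactly |x| < |worst|.
        radius = abs(live[worst])
        live[worst] = rng.uniform(-radius, radius)
        live_logl[worst] = log_likelihood(live[worst])
        log_x_prev = log_x
    # Add the evidence still held by the final live points: X_final * mean(L)
    log_z = np.logaddexp(
        log_z, log_x_prev + np.logaddexp.reduce(live_logl) - np.log(num_live)
    )
    return log_z


print(toy_nested_sampling(), np.log(0.1))  # estimate vs. analytic value
```

Increasing num_live reduces the scatter of the estimate, roughly as 1 / sqrt(num_live).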

Note

To enumerate over a discrete latent variable, add the keyword argument infer={"enumerate": "parallel"} to the corresponding sample statement.
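As a rough, simplified illustration of what parallel enumeration achieves (an assumption about the mechanics, not NumPyro's code): the model is evaluated for every value of the discrete latent at once, and the latent is then summed out in log space. The JAX sketch below does this by hand for a hypothetical two-component Gaussian mixture with one discrete assignment per observation; the function name and model are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def mixture_log_likelihood(x, log_weights, locs):
    # Hypothetical model: one z_n per observation (as inside a numpyro.plate), e.g.
    #   z = numpyro.sample("z", dist.Categorical(logits=log_weights),
    #                      infer={"enumerate": "parallel"})
    # joint[k, n] = log p(z_n = k) + log p(x_n | z_n = k), unit-scale Normal components
    joint = log_weights[:, None] + norm.logpdf(x[None, :], loc=locs[:, None], scale=1.0)
    # Sum each z_n out in log space, then sum over independent observations
    return logsumexp(joint, axis=0).sum()


x = jnp.array([0.0, 0.5, -1.2])
log_weights = jnp.log(jnp.array([0.3, 0.7]))
locs = jnp.array([-1.0, 1.0])
print(mixture_log_likelihood(x, log_weights, locs))
```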

Note

To improve performance, consider enabling x64 mode at the beginning of your NumPyro program by calling numpyro.enable_x64().
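In plain JAX, the same 64-bit switch is the jax_enable_x64 configuration flag; numpyro.enable_x64() essentially performs this update for you. A minimal sketch, to be run before any arrays are created:

```python
import jax

# Enable 64-bit floats globally; do this at program start, before creating arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64
```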

References

  1. JAXNS: a high-performance nested sampling package based on JAX, Joshua G. Albert (https://arxiv.org/abs/2012.15286)

Parameters:
  • model (callable) – a callable containing NumPyro primitives.

  • constructor_kwargs (dict) – additional keyword arguments to construct an upstream jaxns.NestedSampler instance.

  • termination_kwargs (dict) – keyword arguments that control when the sampler terminates. Please refer to the upstream jaxns.NestedSampler.__call__() method.

Example

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from elisa.infer.samplers.ns.jaxns import JAXNSSampler

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
>>>
>>> ns = JAXNSSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
>>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
>>> print(jnp.mean(samples['coefs'], axis=0))
[0.93661342 1.95034876 2.86123884]

Methods

diagnostics()

Plot diagnostics of the result.

get_samples(rng_key, num_samples)

Draws samples from the weighted samples collected from the run.

get_weighted_samples()

Gets weighted samples and their corresponding log weights.

print_summary()

Print summary of the result.

run(rng_key, *args, **kwargs)

Run the nested sampler and collect weighted samples.

run(rng_key, *args, **kwargs)

Run the nested sampler and collect weighted samples.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.

  • args – The arguments needed by the model.

  • kwargs – The keyword arguments needed by the model.

get_samples(rng_key, num_samples)

Draws samples from the weighted samples collected from the run.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used to draw samples.

  • num_samples (int) – The number of samples.

Returns:

A dict of posterior samples, keyed by sample site name.
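One common way to turn weighted samples into equally weighted draws is multinomial resampling: draw indices with replacement, with probabilities proportional to the weights. The sketch below assumes the weighted samples are a dict of arrays sharing a leading axis, plus unnormalised log weights; resample is a hypothetical helper and not necessarily how get_samples is implemented.

```python
import jax
import jax.numpy as jnp


def resample(rng_key, samples, log_weights, num_samples):
    """Hypothetical helper: multinomial resampling of weighted samples.

    Assumes `samples` is a dict of arrays whose leading axis has length
    `log_weights.shape[0]`; the log weights need not be normalised.
    """
    probs = jax.nn.softmax(log_weights)  # numerically stable normalisation
    # Draw indices with replacement, proportional to the weights
    idx = jax.random.choice(
        rng_key, log_weights.shape[0], shape=(num_samples,), replace=True, p=probs
    )
    return {name: value[idx] for name, value in samples.items()}


weighted = {"theta": jnp.array([0.1, 0.2, 0.3, 0.4])}
log_w = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
draws = resample(jax.random.PRNGKey(0), weighted, log_w, num_samples=1000)
print(draws["theta"].shape)  # (1000,)
```

Resampling discards the weights at the cost of duplicated draws, which is why the number of requested samples is passed explicitly.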

get_weighted_samples()

Gets weighted samples and their corresponding log weights.
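Log weights are usually normalised in log space before use. A sketch, assuming the log weights may be unnormalised, that also computes the Kish effective sample size, a common check of how many weighted samples are effectively informative; the helper name is hypothetical:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def normalised_weights_and_ess(log_weights):
    # Hypothetical helper: normalise in log space so the weights sum to one
    weights = jnp.exp(log_weights - logsumexp(log_weights))
    # Kish effective sample size (sum w)^2 / sum w^2, where sum w = 1 here
    ess = 1.0 / jnp.sum(weights**2)
    return weights, ess


w, ess = normalised_weights_and_ess(jnp.log(jnp.array([1.0, 1.0, 2.0])))
print(w, ess)
```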

print_summary()

Print summary of the result. This is a wrapper of jaxns.utils.summary().
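The exact report produced by jaxns.utils.summary() depends on the installed jaxns version. As an illustrative sketch only, per-parameter posterior statistics such as a weighted mean and standard deviation can be computed from weighted samples and their log weights like this (weighted_mean_std is a hypothetical helper; values are assumed to hold one scalar parameter):

```python
import jax.numpy as jnp


def weighted_mean_std(values, log_weights):
    # Hypothetical helper: convert (possibly unnormalised) log weights to probabilities
    w = jnp.exp(log_weights - jnp.max(log_weights))
    w = w / jnp.sum(w)
    mean = jnp.sum(w * values)
    std = jnp.sqrt(jnp.sum(w * (values - mean) ** 2))
    return mean, std


mean, std = weighted_mean_std(jnp.array([0.0, 2.0]), jnp.zeros(2))
print(mean, std)
```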

diagnostics()

Plot diagnostics of the result. This is a wrapper of jaxns.plotting.plot_diagnostics() and jaxns.plotting.plot_cornerplot().