elisa.infer.samplers.ns.jaxns
Nested sampling with jaxns.
This module is adapted from https://github.com/pyro-ppl/numpyro/raw/master/numpyro/contrib/nested_sampling.py
- class JAXNSSampler(model, *, constructor_kwargs=None, termination_kwargs=None)
Bases: object
(EXPERIMENTAL) A wrapper for jaxns, a nested sampling package based on JAX.
See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research.
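For readers new to nested sampling, the following is a minimal sketch of the basic nested sampling loop in plain Python. It is not the algorithm implemented by jaxns or by JAXNSSampler: it replaces the likelihood-constrained sampling step with naive rejection sampling from the prior, which is only practical for toy problems. The toy likelihood (a standard normal density), the Uniform(-5, 5) prior, and every function name below are assumptions chosen for illustration.

```python
import math
import random


def log_likelihood(x: float) -> float:
    """Toy log-likelihood (assumption): the standard normal log-density."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def sample_prior(rng: random.Random) -> float:
    """Draw one point from the toy prior (assumption), Uniform(-5, 5)."""
    return rng.uniform(-5.0, 5.0)


def log_add_exp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def nested_sampling(num_live: int = 200, num_iter: int = 1200, seed: int = 0) -> float:
    """Estimate the log-evidence log Z with a basic nested sampling loop."""
    rng = random.Random(seed)
    # Log-likelihoods of the live points, drawn initially from the prior.
    live_logl = [log_likelihood(sample_prior(rng)) for _ in range(num_live)]
    log_z = -math.inf
    log_x_prev = 0.0  # log of the enclosed prior volume, X_0 = 1
    for i in range(1, num_iter + 1):
        # The lowest-likelihood live point defines the new threshold L*.
        worst = min(range(num_live), key=live_logl.__getitem__)
        logl_star = live_logl[worst]
        # Deterministic volume shrinkage: X_i = exp(-i / num_live).
        log_x = -i / num_live
        # Weight w_i = X_{i-1} - X_i, computed in log space.
        log_w = log_x_prev + math.log1p(-math.exp(log_x - log_x_prev))
        log_z = log_add_exp(log_z, logl_star + log_w)
        log_x_prev = log_x
        # Replace the discarded point with a prior draw satisfying L(x) > L*
        # (naive rejection sampling; real samplers do this far more efficiently).
        while True:
            logl = log_likelihood(sample_prior(rng))
            if logl > logl_star:
                break
        live_logl[worst] = logl
    # The remaining live points share the final volume X_final equally.
    log_x_each = log_x_prev - math.log(num_live)
    for logl in live_logl:
        log_z = log_add_exp(log_z, logl + log_x_each)
    return log_z
```

For this toy problem the exact evidence is 0.1 · erf(5/√2), which is very close to 0.1, so nested_sampling() should return a value near log(0.1) ≈ -2.30. Its statistical scatter shrinks roughly as 1/sqrt(num_live).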
Note
To enumerate over a discrete latent variable, you can add the keyword infer={"enumerate": "parallel"} to the corresponding sample statement.
Note
To improve performance, consider enabling x64 mode at the beginning of your NumPyro program by calling numpyro.enable_x64().
References
[1] Joshua G. Albert, JAXNS: a high-performance nested sampling package based on JAX (https://arxiv.org/abs/2012.15286)
- Parameters:
model (callable) – a callable containing NumPyro primitives.
constructor_kwargs (dict) – additional keyword arguments used to construct the upstream jaxns.NestedSampler instance.
termination_kwargs (dict) – keyword arguments controlling when the sampler terminates. Please refer to the upstream jaxns.NestedSampler.__call__() method.
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler
>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
>>>
>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
>>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
>>> print(jnp.mean(samples['coefs'], axis=0))
[0.93661342 1.95034876 2.86123884]
Methods
- Plot diagnostics of the result.
- get_samples(rng_key, num_samples): Draw samples from the weighted samples collected during the run.
- Get weighted samples and their corresponding log weights.
- Print a summary of the result.
- run(rng_key, *args, **kwargs): Run the nested sampler and collect weighted samples.
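Drawing equally weighted samples from the weighted samples of a run, as get_samples is described to do, is commonly done by categorical resampling with probabilities proportional to the weights. The sketch below shows that technique in JAX. Whether JAXNSSampler does exactly this is an assumption, and the helper name resample_weighted and the dict-of-arrays sample layout are hypothetical.

```python
import jax.numpy as jnp
from jax import random


def resample_weighted(rng_key, samples, log_weights, num_samples):
    """Draw equally weighted samples from weighted ones (hypothetical helper).

    samples: dict mapping site names to arrays whose leading axis indexes draws.
    log_weights: unnormalized log weights, one per draw.
    """
    # random.categorical treats log_weights as unnormalized logits, so index k
    # is chosen with probability proportional to exp(log_weights[k]).
    idx = random.categorical(rng_key, log_weights, shape=(num_samples,))
    return {name: values[idx] for name, values in samples.items()}


weighted = {"x": jnp.array([0.0, 1.0])}
log_w = jnp.log(jnp.array([0.25, 0.75]))
draws = resample_weighted(random.PRNGKey(0), weighted, log_w, 20_000)
print(draws["x"].mean())
```

Resampling with replacement repeats high-weight draws, so the effective sample size of the weighted set, not num_samples, bounds how much independent information the output carries.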
- run(rng_key, *args, **kwargs)
Run the nested sampler and collect weighted samples.
- Parameters:
rng_key (random.PRNGKey) – random number generator key used for sampling.
args – positional arguments passed to the model.
kwargs – keyword arguments passed to the model.