elisa.infer.samplers.blackjax.nuts

class BlackJAXNUTSState(i: int, z: dict, z_grad: dict, potential_energy: float, energy: float, r: dict, num_steps: int, tree_depth: int, accept_prob: float, mean_accept_prob: float, diverging: bool, step_size: float, inverse_mass_matrix: MetricTypes, adapt_state: window_adaptation.WindowAdaptationState, rng_key: jax.Array)

Bases: NamedTuple

State of the BlackJAX NUTS kernel.

Methods

count(value, /)

Return number of occurrences of value.

index(value[, start, stop])

Return first index of value.

i: int

The iteration number.

z: dict

Python collection representing values (unconstrained samples from the posterior) at latent sites.

z_grad: dict

Gradient of potential energy w.r.t. latent sample sites.

potential_energy: float

Potential energy computed at the given value of z.

energy: float

Sum of potential energy and kinetic energy of the current state.

r: dict

The current momentum variable. If this is None, a new momentum variable will be drawn at the beginning of each sampling step.

num_steps: int

Number of steps in the Hamiltonian trajectory (for diagnostics).

tree_depth: int

Tree depth of the current trajectory.

accept_prob: float

Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.

mean_accept_prob: float

Mean acceptance probability up to the current iteration, during warmup adaptation or sampling (for diagnostics).
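
A running mean of this kind is usually updated incrementally, so no history of per-iteration acceptance probabilities needs to be stored. The sketch below shows that update in plain Python. The helper name update_mean and the input values are hypothetical, and whether the counter restarts after warmup is an assumption this page does not settle.

```python
def update_mean(mean: float, new_value: float, n: int) -> float:
    """Return the mean of n values, given the mean of the first n - 1."""
    return mean + (new_value - mean) / n


# Made-up per-iteration acceptance probabilities.
accept_probs = [0.9, 0.7, 0.8, 1.0]

mean_accept_prob = 0.0
for n, accept_prob in enumerate(accept_probs, start=1):
    mean_accept_prob = update_mean(mean_accept_prob, accept_prob, n)

print(round(mean_accept_prob, 3))
```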

diverging: bool

Whether the current trajectory is diverging.
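
To illustrate how the energy and diverging fields relate: a common convention in HMC and NUTS implementations is to flag a trajectory as divergent when the total energy changes by more than a threshold along it (compare the divergence_threshold argument of BlackJAXNUTS). The sketch below assumes a diagonal inverse mass matrix and uses hypothetical helper names; it is not the BlackJAX implementation.

```python
def kinetic_energy(r: dict, inverse_mass_diag: dict) -> float:
    """0.5 * r^T M^{-1} r for a diagonal inverse mass matrix (an assumption)."""
    return 0.5 * sum(inverse_mass_diag[name] * r[name] ** 2 for name in r)


def total_energy(potential_energy: float, r: dict, inverse_mass_diag: dict) -> float:
    """Sum of potential and kinetic energy, as described for the energy field."""
    return potential_energy + kinetic_energy(r, inverse_mass_diag)


def is_diverging(initial_energy: float, new_energy: float, threshold: float = 1000.0) -> bool:
    """Flag a trajectory whose energy error exceeds the threshold."""
    return abs(new_energy - initial_energy) > threshold


inverse_mass_diag = {"a": 1.0, "b": 4.0}  # made-up values
r = {"a": 2.0, "b": 0.5}                  # made-up momentum
e0 = total_energy(3.0, r, inverse_mass_diag)

print(e0, is_diverging(e0, e0 + 10.0), is_diverging(e0, e0 + 2000.0))
```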

step_size: float

Step size to be used by the integrator in the next iteration.

inverse_mass_matrix: MetricTypes

The inverse mass matrix to be used for the next iteration.

adapt_state: window_adaptation.WindowAdaptationState

The current window adaptation state of the NUTS kernel.

rng_key: jax.Array

Random number generator key used for the iteration.
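
Because the state is a NamedTuple, it is immutable: a transition yields a new state rather than modifying the old one. The sketch below shows this with a cut-down stand-in class. ToyState is hypothetical and carries only three of the fields listed above.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    """Hypothetical stand-in with a few of the fields of BlackJAXNUTSState."""
    i: int
    z: dict
    step_size: float


state = ToyState(i=0, z={"x": 0.0}, step_size=1.0)

# Fields can be read by name or by position, since this is a tuple.
first_by_name, first_by_position = state.i, state[0]

# _replace returns a new tuple; the original is left untouched.
next_state = state._replace(i=state.i + 1, z={"x": 0.3})

# _asdict gives a plain mapping, which is handy for logging.
print(next_state._asdict())
```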

class BlackJAXNUTS(model: Callable | None = None, potential_fn: Callable | None = None, init_strategy: Callable = <function init_to_uniform>, dense_mass: bool = True, initial_step_size: float = 1.0, target_accept_prob: float = 0.8, max_tree_depth: int = 10, divergence_threshold: float = 1000.0, integrator: Callable = <function generate_euclidean_integrator.<locals>.euclidean_integrator>)

Bases: MCMCKernel

The No-U-Turn Sampler (NUTS) as implemented in BlackJAX, with automatic window adaptation.

Attributes

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

is_ensemble_kernel

Denotes whether the kernel is an ensemble kernel.

sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample.

Methods

get_diagnostics_str(state)

Given the current state, return the diagnostics string to be shown in the progress bar.

init(rng_key, num_warmup, init_params, ...)

Initialize the MCMCKernel and return an initial state to begin sampling from.

postprocess_fn(model_args, model_kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.

postprocess_fn(model_args, model_kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.

Parameters:
  • model_args – Arguments to the model.

  • model_kwargs – Keyword arguments to the model.
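
To make the idea concrete, the sketch below builds such a function by hand for a toy model with one positive site and one site on the unit interval. The site names, the transforms (exp and the logistic sigmoid), and the deterministic site are all illustrative assumptions; in practice the transforms are determined by each site's support in the model.

```python
import math


def make_postprocess_fn(model_args=(), model_kwargs=None):
    """Return a function mapping unconstrained values to constrained ones."""

    def postprocess(z: dict) -> dict:
        sigma = math.exp(z["sigma"])         # real line -> (0, inf)
        p = 1.0 / (1.0 + math.exp(-z["p"]))  # real line -> (0, 1)
        # A deterministic site computed from the constrained values.
        return {"sigma": sigma, "p": p, "odds": p / (1.0 - p)}

    return postprocess


postprocess = make_postprocess_fn()
constrained = postprocess({"sigma": 0.0, "p": 0.0})
print(constrained)
```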

init(rng_key, num_warmup, init_params, model_args, model_kwargs)

Initialize the MCMCKernel and return an initial state to begin sampling from.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.

  • num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.

  • init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.

  • model_args – Arguments provided to the model.

  • model_kwargs – Keyword arguments provided to the model.

Returns:

The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.

Parameters:
  • state – A pytree class representing the state for the kernel. For this kernel, this is a BlackJAXNUTSState. In general, this could be any class that supports getattr.

  • model_args – Arguments provided to the model.

  • model_kwargs – Keyword arguments provided to the model.

Returns:

Next state.
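
The init/sample contract can be illustrated without any sampler library. The sketch below is a random-walk Metropolis kernel in plain Python, not NUTS. ToyKernel, ToyState, and the standard-normal target are hypothetical, and Python's random module stands in for JAX random keys.

```python
import math
import random
from typing import NamedTuple


class ToyState(NamedTuple):
    i: int
    z: float
    potential_energy: float
    accept_prob: float
    rng: random.Random  # mutable generator shared across states (toy shortcut)


class ToyKernel:
    """Random-walk Metropolis kernel exposing init() and sample()."""

    def __init__(self, potential_fn, step_size=1.0):
        self.potential_fn = potential_fn
        self.step_size = step_size

    def init(self, seed, num_warmup, init_params, model_args=(), model_kwargs=None):
        # num_warmup is unused here; an adaptive kernel would use it
        # to schedule its adaptation.
        pe = self.potential_fn(init_params)
        return ToyState(0, init_params, pe, 0.0, random.Random(seed))

    def sample(self, state, model_args=(), model_kwargs=None):
        proposal = state.z + self.step_size * state.rng.gauss(0.0, 1.0)
        pe_new = self.potential_fn(proposal)
        # Metropolis acceptance probability: min(1, exp(U(z) - U(z'))).
        accept_prob = math.exp(min(0.0, state.potential_energy - pe_new))
        if state.rng.random() < accept_prob:
            z, pe = proposal, pe_new
        else:
            z, pe = state.z, state.potential_energy
        return state._replace(
            i=state.i + 1, z=z, potential_energy=pe, accept_prob=accept_prob
        )


kernel = ToyKernel(potential_fn=lambda z: 0.5 * z * z)  # standard normal target
state = kernel.init(seed=0, num_warmup=0, init_params=0.0)
draws = []
for _ in range(1000):
    state = kernel.sample(state)  # state in, next state out
    draws.append(state.z)
```

The driver loop only ever calls init once and then feeds each returned state back into sample, which is the pattern the method descriptions above imply.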

property sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

property default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
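
The role of sample_field and default_fields can be sketched as follows: a driver reads the named attributes off each state and gathers them per field. The field names used below are made up for illustration; the values that BlackJAXNUTS actually returns are not stated on this page.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    """Hypothetical state with three fields."""
    z: dict
    diverging: bool
    step_size: float


sample_field = "z"                   # hypothetical value
default_fields = ("z", "diverging")  # hypothetical value


def collect(states, fields):
    """Gather the requested attributes from a sequence of states."""
    return {name: [getattr(s, name) for s in states] for name in fields}


states = [
    ToyState({"x": 0.1}, False, 0.5),
    ToyState({"x": 0.4}, True, 0.5),
]
collected = collect(states, default_fields)
samples = collected[sample_field]
print(samples)
```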

get_diagnostics_str(state)

Given the current state, return the diagnostics string to be shown in the progress bar.
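
A diagnostics string of this kind might be assembled as in the sketch below. The format and the choice of fields are assumptions for illustration; the string that BlackJAXNUTS actually produces is not specified on this page.

```python
from types import SimpleNamespace


def diagnostics_str(state) -> str:
    """Format a few state fields for display next to a progress bar."""
    return "{} steps of size {:.2e}. acc. prob={:.2f}".format(
        state.num_steps, state.step_size, state.mean_accept_prob
    )


# SimpleNamespace stands in for a real state object with made-up values.
state = SimpleNamespace(num_steps=7, step_size=0.25, mean_accept_prob=0.8731)
message = diagnostics_str(state)
print(message)
```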