elisa.infer.samplers.blackjax.nuts
- class BlackJAXNUTSState(i: int, z: dict, z_grad: dict, potential_energy: float, energy: float, r: dict, num_steps: int, tree_depth: int, accept_prob: float, mean_accept_prob: float, diverging: bool, step_size: float, inverse_mass_matrix: Any, adapt_state: Any, rng_key: jax.Array)
Bases: NamedTuple

State of the BlackJAX NUTS sampler.
Methods
count(value, /): Return the number of occurrences of value.
index(value[, start, stop]): Return the first index of value.
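Because the state is a NamedTuple, count() and index() are ordinary tuple methods that compare field values by equality, and updated states are produced with _replace() rather than by mutation. The sketch below uses a hypothetical three-field stand-in (ToyNUTSState is not part of elisa) to show this behaviour, including the easy-to-miss fact that 0 == False.

```python
from typing import NamedTuple


class ToyNUTSState(NamedTuple):
    """Hypothetical, trimmed-down stand-in for BlackJAXNUTSState (3 of its fields)."""

    i: int
    step_size: float
    diverging: bool


state = ToyNUTSState(i=0, step_size=0.5, diverging=False)

# count() compares every field value with ==, so `i == 0` also matches False.
n_false = state.count(False)
# index() returns the position of the first field equal to the value.
pos = state.index(0.5)

# NamedTuples are immutable: _replace returns an updated copy.
next_state = state._replace(i=state.i + 1)
```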
- z: dict
Python collection representing values (unconstrained samples from the posterior) at latent sites.
- r: dict
The current momentum variable. If this is None, a new momentum variable will be drawn at the beginning of each sampling step.
- accept_prob: float
Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
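To make the relationship between z, r, energy, step_size, inverse_mass_matrix and accept_prob concrete, here is a minimal sketch of one fixed-length HMC transition with a leapfrog integrator and a Metropolis correction. It is an illustration only: NUTS grows its trajectory adaptively and the accept_prob it reports is typically an averaged statistic over that trajectory, so this is not BlackJAX's or elisa's implementation. The Gaussian target and all function names are hypothetical.

```python
import numpy as np


def potential(z):
    # Hypothetical target: standard Gaussian, U(z) = 0.5 * z^T z.
    return 0.5 * z @ z


def grad_potential(z):
    # Gradient of the Gaussian potential above.
    return z


def kinetic(r, inv_mass):
    # K(r) = 0.5 * r^T M^{-1} r with a dense inverse mass matrix M^{-1}.
    return 0.5 * r @ inv_mass @ r


def leapfrog(z, r, step_size, inv_mass, num_steps):
    """Half momentum step, alternating full position/momentum steps, half momentum step."""
    r = r - 0.5 * step_size * grad_potential(z)
    for k in range(num_steps):
        z = z + step_size * (inv_mass @ r)
        if k < num_steps - 1:
            r = r - step_size * grad_potential(z)
    r = r - 0.5 * step_size * grad_potential(z)
    return z, r


def hmc_step(z, rng, step_size, inv_mass, num_steps):
    """One fixed-length HMC transition; returns (next z, accept_prob)."""
    # Fresh momentum r ~ N(0, M), where M = inverse(inv_mass).
    mass = np.linalg.inv(inv_mass)
    r = rng.multivariate_normal(np.zeros(z.shape[0]), mass)
    energy_old = potential(z) + kinetic(r, inv_mass)
    z_new, r_new = leapfrog(z, r, step_size, inv_mass, num_steps)
    energy_new = potential(z_new) + kinetic(r_new, inv_mass)
    # Metropolis correction, computed in log space to avoid overflow.
    accept_prob = float(np.exp(min(0.0, energy_old - energy_new)))
    if rng.uniform() < accept_prob:
        return z_new, accept_prob
    # Rejected: z is returned unchanged, so it is not the proposal.
    return z, accept_prob
```

For example, `hmc_step(np.array([1.0, -0.5]), np.random.default_rng(0), 0.1, np.eye(2), 10)` returns the next position together with the acceptance probability of the proposal that was tested.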
- class BlackJAXNUTS(model: Callable | None = None, potential_fn: Callable | None = None, init_strategy: Callable = <function init_to_uniform>, dense_mass: bool = True, initial_step_size: float = 1.0, target_accept_prob: float = 0.8, max_tree_depth: int = 10, divergence_threshold: float = 1000.0, integrator: Callable | None = None)
Bases: MCMCKernel

NUTS sampler backed by BlackJAX's implementation, with automatic window adaptation.
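Window adaptation in the Stan style, which BlackJAX's window adaptation follows as far as I know, tunes the step size during warmup with Nesterov dual averaging so that the average acceptance probability approaches target_accept_prob. The sketch below follows the scheme of Hoffman & Gelman (2014) with their suggested constants gamma=0.05, t0=10, kappa=0.75; whether BlackJAX or this class uses exactly these values is an assumption. toy_accept_prob is a hypothetical deterministic stand-in for the acceptance statistic a real sampler would return.

```python
import math


def dual_averaging(accept_prob_fn, target=0.8, init_step_size=1.0, num_steps=2000,
                   gamma=0.05, t0=10.0, kappa=0.75):
    """Adapt log(step_size) by dual averaging; return the averaged step size."""
    mu = math.log(10.0 * init_step_size)  # shrinkage point for log(step_size)
    h_bar = 0.0                           # running average of (target - accept_prob)
    log_eps = math.log(init_step_size)
    log_eps_bar = 0.0                     # averaged iterate, kept after warmup
    for t in range(1, num_steps + 1):
        alpha = accept_prob_fn(math.exp(log_eps))
        w = 1.0 / (t + t0)
        h_bar = (1.0 - w) * h_bar + w * (target - alpha)
        # Acceptance below target -> h_bar grows -> step size shrinks, and vice versa.
        log_eps = mu - math.sqrt(t) / gamma * h_bar
        eta = t ** (-kappa)
        log_eps_bar = eta * log_eps + (1.0 - eta) * log_eps_bar
    return math.exp(log_eps_bar)


def toy_accept_prob(step_size):
    # Hypothetical: acceptance decays monotonically as the step size grows.
    return math.exp(-step_size)
```

With toy_accept_prob the adapted step size should settle near -ln(0.8), about 0.223, the value at which the toy acceptance equals the target.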
Attributes
default_fields: The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
is_ensemble_kernel: Denotes whether the kernel is an ensemble kernel.
sample_field: The attribute of the state object passed to sample() that denotes the MCMC sample.
Methods
get_diagnostics_str(state): Given the current state, returns the diagnostics string to be added to the progress bar for diagnostic purposes.
init(rng_key, num_warmup, init_params, ...): Initialize the MCMCKernel and return an initial state to begin sampling from.
postprocess_fn(model_args, model_kwargs): Get a function that transforms unconstrained values at sample sites to values constrained to the site's support, in addition to returning deterministic sites in the model.
sample(state, model_args, model_kwargs): Given the current state, return the next state using the given transition kernel.
- postprocess_fn(model_args, model_kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
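The sampler works in an unconstrained space, so each site needs a transform back to its support, and deterministic sites are recomputed from the constrained values. The sketch below is a hypothetical toy: the sites "norm" (positive, via exp) and "frac" (in (0, 1), via the logistic function), the deterministic site "log10_norm", and the TRANSFORMS table are illustrative assumptions, not what elisa derives from a real model.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function, R -> (0, 1).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


# Hypothetical sites: "norm" must be positive, "frac" must lie in (0, 1).
TRANSFORMS = {
    "norm": math.exp,
    "frac": sigmoid,
}


def postprocess_fn(model_args, model_kwargs):
    """Return a function mapping unconstrained site values to constrained ones."""
    # model_args / model_kwargs are unused by this toy model.

    def constrain(unconstrained):
        values = {name: TRANSFORMS[name](x) for name, x in unconstrained.items()}
        # Deterministic sites are recomputed from the constrained values.
        values["log10_norm"] = math.log10(values["norm"])
        return values

    return constrain
```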
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (Array) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
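One use of num_warmup is mass-matrix adaptation: with dense_mass=True, the inverse_mass_matrix carried in the state is typically set to an estimate of the posterior covariance accumulated over warmup samples. The sketch below shows an online (Welford) covariance estimator; the shrinkage toward a small multiple of the identity follows Stan's convention, and whether BlackJAX or elisa regularizes in exactly this way is an assumption. The class name WelfordCovariance is hypothetical.

```python
import numpy as np


class WelfordCovariance:
    """Online (Welford) estimate of the mean and covariance of warmup samples."""

    def __init__(self, dim):
        self.n = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros((dim, dim))  # running sum of outer products of deviations

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.n
        # Pre-update deviation times post-update deviation (numerically stable).
        self.m2 = self.m2 + np.outer(delta, x - self.mean)

    def covariance(self):
        return self.m2 / (self.n - 1)  # unbiased sample covariance

    def regularized_inverse_mass_matrix(self):
        # Stan-style shrinkage toward 1e-3 * I (an assumption for this sketch).
        n = self.n
        cov = self.covariance()
        return (n / (n + 5.0)) * cov + 1e-3 * (5.0 / (n + 5.0)) * np.eye(cov.shape[0])
```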
- sample(state, model_args, model_kwargs)
Given the current state, return the next state using the given transition kernel.
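The init()/sample() contract is a functional loop: init() builds a state once, and each sample() call consumes a state and returns a new one, with randomness threaded through the state (a JAX PRNG key in the real class). The sketch below mimics that contract with a hypothetical random-walk Metropolis kernel in plain Python; ToyRandomWalkKernel, RWState and the integer seed standing in for rng_key are illustrative assumptions, not elisa's MCMCKernel.

```python
import math
import random
from typing import NamedTuple


class RWState(NamedTuple):
    i: int             # iteration counter
    z: float           # current position
    accept_prob: float
    rng_seed: int      # stand-in for a JAX PRNG key


class ToyRandomWalkKernel:
    """Hypothetical random-walk Metropolis kernel with an init/sample interface."""

    sample_field = "z"

    def __init__(self, potential_fn, step_size=1.0):
        self.potential_fn = potential_fn
        self.step_size = step_size

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        # No adaptation here, so num_warmup is unused.
        return RWState(i=0, z=init_params, accept_prob=1.0, rng_seed=rng_key)

    def sample(self, state, model_args, model_kwargs):
        rng = random.Random(state.rng_seed)
        proposal = state.z + rng.gauss(0.0, self.step_size)
        log_ratio = self.potential_fn(state.z) - self.potential_fn(proposal)
        accept_prob = math.exp(min(0.0, log_ratio))
        z = proposal if rng.random() < accept_prob else state.z
        # Derive a fresh seed so the next call uses new randomness.
        return RWState(state.i + 1, z, accept_prob, rng.getrandbits(32))


# Driver loop: init once, then repeatedly feed the state back into sample().
kernel = ToyRandomWalkKernel(lambda z: 0.5 * z * z, step_size=2.4)
state = kernel.init(0, num_warmup=0, init_params=3.0, model_args=(), model_kwargs={})
draws = []
for _ in range(20000):
    state = kernel.sample(state, (), {})
    draws.append(getattr(state, kernel.sample_field))
```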
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property default_fields
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
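The two properties are attribute names looked up on each state. A runner can gather the default_fields from every state into per-field traces and treat the sample_field trace as the posterior draws. The sketch below assumes hypothetical values for both properties (the actual values for BlackJAXNUTS are not listed here) and a hypothetical collect() helper.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    z: dict
    accept_prob: float
    diverging: bool
    step_size: float


# Hypothetical values mirroring the two properties.
SAMPLE_FIELD = "z"
DEFAULT_FIELDS = ("z", "accept_prob", "diverging")


def collect(states, fields=DEFAULT_FIELDS):
    """Gather the requested attributes from each state into per-field lists."""
    return {name: [getattr(s, name) for s in states] for name in fields}


states = [ToyState({"x": 0.1 * k}, 0.9, False, 0.2) for k in range(3)]
trace = collect(states)
samples = trace[SAMPLE_FIELD]  # the draws that postprocessing would constrain
```

Fields outside default_fields, such as step_size here, are simply not recorded unless requested explicitly.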