from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
import blackjax.adaptation.window_adaptation as window_adaptation
import blackjax.mcmc.integrators as integrators
import jax
import jax.numpy as jnp
import jax.random as random
from blackjax.mcmc.hmc import HMCState
from blackjax.mcmc.nuts import build_kernel
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import init_to_uniform, initialize_model
from numpyro.util import identity, is_prng_key
if TYPE_CHECKING:
from collections.abc import Callable
from blackjax.mcmc.metrics import MetricTypes
[docs]
class BlackJAXNUTSState(NamedTuple):
    """Carry state of the BlackJAX NUTS kernel.

    One instance is produced by ``BlackJAXNUTS.init`` and a new one is
    returned by every call to ``BlackJAXNUTS.sample``.
    """

    #: Iteration counter; ``0`` for the initial state and incremented by one
    #: on every transition.
    i: int

    #: Python collection holding the values (unconstrained samples from the
    #: posterior) at the latent sites.
    z: dict

    #: Gradient of the potential energy w.r.t. the latent sample sites.
    z_grad: dict

    #: Potential energy evaluated at the current value of `z`.
    potential_energy: float

    #: Energy reported by the most recent transition (potential plus kinetic
    #: energy); NaN in the initial state.
    energy: float

    #: Momentum reported by the most recent transition. The initial state
    #: holds NaN placeholders because no transition has run yet.
    r: dict

    #: Number of integration steps in the Hamiltonian trajectory (for
    #: diagnostics); ``-1`` in the initial state.
    num_steps: int

    #: Tree depth (number of trajectory expansions) of the most recent
    #: trajectory; ``-1`` in the initial state.
    tree_depth: int

    #: Acceptance probability of the proposal. Note that `z` does not
    #: correspond to the proposal if it is rejected.
    accept_prob: float

    #: Running mean of the acceptance probability. The mean is accumulated
    #: separately over the warmup iterations and over the sampling
    #: iterations (for diagnostics).
    mean_accept_prob: float

    #: Whether the most recent trajectory was flagged as diverging.
    diverging: bool

    #: Step size obtained from the window adaptation state. The kernel feeds
    #: it to the integrator once the warmup iterations are over.
    step_size: float

    #: Inverse mass matrix obtained from the window adaptation state. The
    #: kernel uses it once the warmup iterations are over.
    inverse_mass_matrix: MetricTypes

    #: Current window adaptation state of the NUTS kernel.
    adapt_state: window_adaptation.WindowAdaptationState

    #: PRNG key carried in the state; the randomness of the next transition
    #: is derived from it.
    rng_key: jax.Array
[docs]
class BlackJAXNUTS(MCMCKernel):
    """NUTS implementation of BlackJAX, with automatic window adaptation.

    The class exposes BlackJAX's NUTS transition together with BlackJAX's
    window adaptation through NumPyro's ``MCMCKernel`` interface. Exactly one
    of ``model`` and ``potential_fn`` has to be given.

    :param model: NumPyro model. When given, the potential function, the
        initial parameters and the postprocessing function are obtained from
        ``numpyro.infer.util.initialize_model``.
    :param potential_fn: Callable returning the potential energy for a given
        position; its negation is used as the log density.
    :param init_strategy: Initialization strategy forwarded to
        ``initialize_model`` (only used together with ``model``).
    :param dense_mass: If ``True`` the window adaptation is configured with a
        non-diagonal mass matrix, otherwise with a diagonal one.
    :param initial_step_size: Step size the window adaptation starts from.
    :param target_accept_prob: Target acceptance rate of the window
        adaptation.
    :param max_tree_depth: Maximum number of trajectory doublings per
        transition.
    :param divergence_threshold: Divergence threshold handed to the BlackJAX
        NUTS kernel builder.
    :param integrator: Integrator handed to the BlackJAX NUTS kernel builder.
    :raises ValueError: If both or neither of ``model`` and ``potential_fn``
        are given.
    """

    def __init__(
        self,
        model: Callable | None = None,
        potential_fn: Callable | None = None,
        init_strategy: Callable = init_to_uniform,
        dense_mass: bool = True,
        initial_step_size: float = 1.0,
        target_accept_prob: float = 0.8,
        max_tree_depth: int = 10,
        divergence_threshold: float = 1000.0,
        integrator: Callable = integrators.velocity_verlet,
    ):
        # Both missing or both present is an error: exactly one is required.
        if (model is None) == (potential_fn is None):
            raise ValueError(
                'Only one of `model` or `potential_fn` must be specified.'
            )
        self._model = model
        self._potential_fn = potential_fn
        self._init_strategy = init_strategy
        # Window adaptation parameters
        self._dense_mass = dense_mass
        self._initial_step_size = initial_step_size
        self._target_accept_prob = target_accept_prob
        # NUTS kernel parameters
        self._max_tree_depth = max_tree_depth
        self._divergence_threshold = divergence_threshold
        self._integrator = integrator
        # Sampling related; both are populated by `init`.
        self._postprocess_gen = None
        self._mcmc_kernel = None

    def postprocess_fn(self, model_args, model_kwargs):
        """Return the function that postprocesses collected samples.

        :param model_args: Positional arguments of the model.
        :param model_kwargs: Keyword arguments of the model; ``None`` is
            treated as no keyword arguments.
        :return: ``identity`` when no postprocess generator is available
            (e.g. the kernel was built from ``potential_fn`` or `init` has
            not run yet), otherwise the generator applied to the model
            arguments.
        """
        if self._postprocess_gen is None:
            return identity
        # Same `None` handling as in `_init_state`.
        model_kwargs = {} if model_kwargs is None else model_kwargs
        return self._postprocess_gen(*model_args, **model_kwargs)

    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        """Resolve the initial parameters and, for a model, the potential.

        When the kernel was built from a model, ``initialize_model`` is run
        and its results replace ``init_params``, ``self._potential_fn`` and
        ``self._postprocess_gen``. Otherwise ``init_params`` is returned
        unchanged.

        :param rng_key: PRNG key used for the model initialization.
        :param model_args: Positional arguments of the model.
        :param model_kwargs: Keyword arguments of the model (may be ``None``).
        :param init_params: User supplied initial parameters; ignored when a
            model is used.
        :return: The initial parameters to start the chain from.
        """
        if self._model is None:
            return init_params

        model_kwargs = {} if model_kwargs is None else model_kwargs
        model_info = initialize_model(
            rng_key,
            self._model,
            init_strategy=self._init_strategy,
            dynamic_args=True,
            model_args=model_args,
            model_kwargs=model_kwargs,
            validate_grad=True,
        )
        # With `dynamic_args=True` the returned potential is a generator that
        # still has to be applied to the model arguments.
        self._potential_fn = model_info.potential_fn(
            *model_args, **model_kwargs
        )
        self._postprocess_gen = model_info.postprocess_fn
        return model_info.param_info.z

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Build the transition kernel and return the initial state.

        :param rng_key: A single PRNG key (batched keys are rejected).
        :param num_warmup: Number of warmup iterations during which the
            window adaptation is updated.
        :param init_params: Initial parameters; required together with
            ``potential_fn`` and ignored together with ``model``.
        :param model_args: Positional arguments of the model.
        :param model_kwargs: Keyword arguments of the model.
        :return: The initial :class:`BlackJAXNUTSState`.
        :raises NotImplementedError: If ``rng_key`` is not a single PRNG key.
        :raises ValueError: If no initial parameters are available.
        """
        # TODO: support chain_method='vectorized'
        if not is_prng_key(rng_key):
            raise NotImplementedError(
                "BlackJAX's NUTS only supports chain_method='parallel' or "
                "chain_method='sequential'. Please put in a feature request "
                'if it would be useful to be used in vectorized mode.'
            )
        # The third key is currently unused; the three-way split is kept so
        # that the keys derived here do not change.
        rng_key, rng_key_init_model, _ = random.split(rng_key, 3)

        # Initial parameters
        init_params = self._init_state(
            rng_key_init_model, model_args, model_kwargs, init_params
        )
        if self._potential_fn is not None and init_params is None:
            raise ValueError(
                '`init_params` must be provided with `potential_fn`.'
            )

        # Snapshot the configuration so that the kernel built below does not
        # depend on later mutations of `self`.
        potential_fn = self._potential_fn
        max_tree_depth = self._max_tree_depth

        def log_density(z):
            """Log posterior density expected by BlackJAX."""
            return -potential_fn(z)

        # Build NUTS kernel
        kernel = build_kernel(self._integrator, self._divergence_threshold)

        # Initialize window adaptation, adapted from
        # blackjax.window_adaptation
        adapt_init, adapt_step, adapt_final = window_adaptation.base(
            is_mass_matrix_diagonal=not self._dense_mass,
            target_acceptance_rate=self._target_accept_prob,
        )
        schedule = window_adaptation.build_schedule(num_warmup)

        def wa_update(state, new_hmc_state, info):
            """Advance the window adaptation by one warmup iteration."""
            return adapt_step(
                state.adapt_state,
                # `state.i` is the zero-based index of the current iteration.
                schedule[state.i],
                new_hmc_state.position,
                info.acceptance_rate,
            )

        def keep_adapt_state(state, new_hmc_state, info):
            """Leave the adaptation state untouched (after warmup)."""
            return state.adapt_state

        def mcmc_kernel(state: BlackJAXNUTSState) -> BlackJAXNUTSState:
            """Run one NUTS transition and, during warmup, one adaptation."""
            i = state.i + 1
            in_warmup = i <= num_warmup

            # Split into a key that is carried to the next transition and a
            # key consumed by this transition, so that no key is both
            # consumed and split again.
            rng_key, rng_key_step = random.split(state.rng_key)

            hmc_state = HMCState(
                position=state.z,
                logdensity=-state.potential_energy,
                logdensity_grad=jax.tree.map(jnp.negative, state.z_grad),
            )

            # During warmup the current adaptation values drive the kernel,
            # afterwards the values stored in the state are used.
            step_size = jnp.where(
                in_warmup, state.adapt_state.step_size, state.step_size
            )
            inverse_mass_matrix = jnp.where(
                in_warmup,
                state.adapt_state.inverse_mass_matrix,
                state.inverse_mass_matrix,
            )

            new_state, info = kernel(
                rng_key_step,
                hmc_state,
                log_density,
                step_size,
                inverse_mass_matrix,
                max_num_doublings=max_tree_depth,
            )

            adapt_state = jax.lax.cond(
                in_warmup,
                wa_update,
                keep_adapt_state,
                state,
                new_state,
                info,
            )
            step_size, inverse_mass_matrix = adapt_final(adapt_state)

            # Running mean, restarted at the first iteration after warmup
            # (there `n == 1`, which discards the previous mean).
            n = jnp.where(in_warmup, i, i - num_warmup)
            new_mean_acc_prob = (
                state.mean_accept_prob
                + (info.acceptance_rate - state.mean_accept_prob) / n
            )

            return BlackJAXNUTSState(
                i=i,
                z=new_state.position,
                z_grad=jax.tree.map(jnp.negative, new_state.logdensity_grad),
                potential_energy=-new_state.logdensity,
                energy=info.energy,
                r=info.momentum,
                num_steps=info.num_integration_steps,
                tree_depth=info.num_trajectory_expansions,
                accept_prob=info.acceptance_rate,
                mean_accept_prob=new_mean_acc_prob,
                diverging=info.is_divergent,
                step_size=step_size,
                inverse_mass_matrix=inverse_mass_matrix,
                adapt_state=adapt_state,
                rng_key=rng_key,
            )

        self._mcmc_kernel = mcmc_kernel

        init_adapt_state = adapt_init(init_params, self._initial_step_size)
        pe, z_grad = jax.value_and_grad(potential_fn)(init_params)
        # NaN placeholder with the same structure and leaf shapes as `z`.
        # NOTE(review): presumably this matches the momentum reported by the
        # kernel, which keeps the state layout stable across iterations --
        # confirm against BlackJAX's NUTS info.
        init_momentum = jax.tree.map(
            lambda leaf: jnp.full_like(leaf, jnp.nan), init_params
        )
        return BlackJAXNUTSState(
            i=0,
            z=init_params,
            z_grad=z_grad,
            potential_energy=pe,
            energy=jnp.nan,
            r=init_momentum,
            num_steps=-1,
            tree_depth=-1,
            accept_prob=jnp.nan,
            mean_accept_prob=0.0,
            diverging=False,
            step_size=init_adapt_state.step_size,
            inverse_mass_matrix=init_adapt_state.inverse_mass_matrix,
            adapt_state=init_adapt_state,
            rng_key=rng_key,
        )

    def sample(self, state, model_args, model_kwargs):
        """Run one transition of the kernel built by `init`.

        :param state: Current :class:`BlackJAXNUTSState`.
        :param model_args: Unused; the model arguments are bound in `init`.
        :param model_kwargs: Unused; the model arguments are bound in `init`.
        :return: The next :class:`BlackJAXNUTSState`.
        :raises RuntimeError: If no kernel is available, i.e. `init` has not
            been called on this instance (the kernel is also dropped when the
            instance is pickled).
        """
        if self._mcmc_kernel is None:
            raise RuntimeError('`init` must be called before `sample`.')
        return self._mcmc_kernel(state)

    @property
    def sample_field(self):
        """Name of the state field that holds the sample."""
        return 'z'

    @property
    def default_fields(self):
        """Names of the state fields collected by default."""
        return 'z', 'diverging'

    def get_diagnostics_str(self, state):
        """Format a short progress string for the given state."""
        return (
            f'{state.num_steps} steps of size {state.step_size:.2e}. '
            f'acc. prob={state.mean_accept_prob:.2f}'
        )

    def __getstate__(self):
        """Return the instance dict without the transition kernel.

        The kernel is a closure created in `init`; it is replaced by ``None``
        in the pickled state and has to be rebuilt by calling `init` again.
        """
        state = self.__dict__.copy()
        state['_mcmc_kernel'] = None
        return state