Source code for elisa.infer.samplers.util

from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING, NamedTuple

import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.flatten_util import ravel_pytree
from numpyro.handlers import reparam, seed, trace
from numpyro.infer.initialization import init_to_uniform
from numpyro.infer.reparam import Reparam
from numpyro.infer.util import (
    Predictive,
    _guess_max_plate_nesting,
    _validate_model,
    initialize_model,
    log_density,
)

if TYPE_CHECKING:
    from collections.abc import Callable

    from numpy import dtype, floating
    from numpy.typing import NDArray


class ModelInfo(NamedTuple):
    """Immutable record describing a model as seen by a sampler.

    Instances are built by the factory functions in this module; the
    fields bundle the callables, the initial point and the naming / dtype
    metadata that those factories compute.
    """

    ndim: int
    """Dimension of the model (length of the raveled parameter vector)."""

    init: dict[str, NDArray[floating]]
    """Initial parameter values in unconstrained space."""

    init_ravel: NDArray[floating]
    """Initial parameter values in unconstrained space, raveled to 1-D."""

    unravel: Callable[[NDArray[floating]], dict[str, NDArray[floating]]]
    """Function turning a raveled vector back into a parameter dict."""

    log_prob_fn: Callable[[dict[str, NDArray[floating]]], float]
    """Log probability as a function of unconstrained parameters."""

    postprocess_fn: Callable[
        [dict[str, NDArray[floating]]],
        dict[str, NDArray[floating]],
    ]
    """Postprocessing function applied to unconstrained parameters."""

    params_names: list[str]
    """Parameter names."""

    params_dtype: list[tuple[str, dtype, tuple[int, ...]]]
    """``(name, dtype, shape)`` of each parameter in constrained space."""

    deterministic_names: list[str]
    """Names of the deterministic sites."""

    deterministic_dtype: list[tuple[str, dtype, tuple[int, ...]]]
    """``(name, dtype, shape)`` of each deterministic site."""
def get_model_info(
    model: Callable,
    init_strategy: Callable = init_to_uniform,
    model_args: tuple = (),
    model_kwargs: dict | None = None,
    forward_mode_differentiation: bool = False,
    validate_grad: bool = True,
    rng_seed: int = 42,
) -> ModelInfo:
    """Initialize ``model`` with numpyro and collect its :class:`ModelInfo`.

    Parameters
    ----------
    model : callable
        The numpyro model.
    init_strategy : callable, optional
        Initialization strategy forwarded to ``initialize_model``.
    model_args : tuple, optional
        Positional arguments forwarded to ``initialize_model``.
    model_kwargs : dict, optional
        Keyword arguments forwarded to ``initialize_model``.
    forward_mode_differentiation : bool, optional
        Forwarded to ``initialize_model``.
    validate_grad : bool, optional
        Forwarded to ``initialize_model``.
    rng_seed : int, optional
        Seed of the PRNG key handed to ``initialize_model``.

    Returns
    -------
    ModelInfo
        The jitted log probability (negated potential) and postprocess
        functions, the initial point (as a dict and raveled), and the
        names / dtypes / shapes of parameters and deterministic sites.
    """
    initialized = initialize_model(
        rng_key=jax.random.PRNGKey(rng_seed),
        model=model,
        init_strategy=init_strategy,
        model_args=model_args,
        model_kwargs=model_kwargs,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )
    potential_fn = initialized.potential_fn

    @jax.jit
    def log_prob_fn(z):
        # The log probability is the negated potential energy.
        return -potential_fn(z)

    postprocess_fn = jax.jit(initialized.postprocess_fn)

    init = initialized.param_info.z
    flat_init, unravel = ravel_pytree(init)
    flat_init = jax.device_get(flat_init)

    # Evaluate the postprocess function once at the initial point to learn
    # the dtype and shape of every site it returns.
    samples = postprocess_fn(init)

    def describe(names):
        return [(key, samples[key].dtype, samples[key].shape) for key in names]

    params_names = list(init)
    # Anything returned by the postprocess function that is not a parameter
    # is treated as a deterministic site.
    deterministic_names = [
        key for key in samples if key not in params_names
    ]

    return ModelInfo(
        ndim=len(flat_init),
        init=init,
        init_ravel=flat_init,
        unravel=unravel,
        log_prob_fn=log_prob_fn,
        postprocess_fn=postprocess_fn,
        params_names=params_names,
        params_dtype=describe(params_names),
        deterministic_names=deterministic_names,
        deterministic_dtype=describe(deterministic_names),
    )
def ravel_params_names(name: str, shape: tuple[int, ...]) -> list[str]:
    """Expand a parameter name into one name per array element.

    Parameters
    ----------
    name : str
        Base name of the parameter.
    shape : tuple of int
        Shape of the parameter.

    Returns
    -------
    list of str
        ``[name]`` for a scalar (``shape == ()``); otherwise one entry of
        the form ``name[i,j,...]`` per element, in row-major order.
    """
    if shape == ():
        return [str(name)]

    # np.ndindex walks the index space in row-major (C) order.
    return [
        f'{name}[{",".join(map(str, index))}]' for index in np.ndindex(*shape)
    ]
# >>> The code below is adapted from numpyro.contrib.nested_sampling >>>
class UniformReparam(Reparam):
    """Reparameterize a distribution to a Uniform over the unit hypercube.

    The target value is obtained by pushing a uniform draw through the
    transform returned by :func:`uniform_reparam_transform`, which for most
    univariate distributions is the inverse CDF.
    """

    def __call__(self, name, fn, obs):
        """Replace site ``name`` by a uniform site named ``u_{name}``.

        Raises ``ValueError`` if ``obs`` is given; otherwise returns
        ``(None, value)`` where ``value`` is the transformed uniform draw.
        """
        if obs is not None:
            raise ValueError(
                'UniformReparam does not support observe statements'
            )

        full_shape = fn.shape()
        base_fn, _, event_dim = self._unwrap(fn)
        to_target = uniform_reparam_transform(base_fn)

        # The lower bound is the smallest positive normal float rather than
        # 0 -- presumably to keep inverse CDFs away from an infinite value.
        lower = jnp.finfo(jnp.result_type(float)).tiny
        unit_prior = (
            dist.Uniform(lower, 1)
            .expand(full_shape)
            .to_event(event_dim)
            .mask(False)
        )
        u = numpyro.sample(name=f'u_{name}', fn=unit_prior)

        # Simulate a numpyro.deterministic() site.
        return None, to_target(u)
@singledispatch
def uniform_reparam_transform(d):
    """Return the map from the unit hypercube to samples of ``d``.

    A helper for :class:`UniformReparam`. This generic implementation
    handles three cases:

    * ``TransformedDistribution``: the transform of the base distribution
      followed by the composition of ``d.transforms``;
    * ``Independent`` / ``ExpandedDistribution`` / ``MaskedDistribution``:
      the transform of the wrapped base distribution;
    * anything else: the distribution's own ``icdf``.

    Other distributions are handled by overloads registered on this
    function.
    """
    if isinstance(d, dist.TransformedDistribution):
        outer = dist.transforms.ComposeTransform(d.transforms)

        def transform(q):
            # The base transform is resolved lazily, at call time.
            base_value = uniform_reparam_transform(d.base_dist)(q)
            return outer(base_value)

        return transform

    wrappers = (
        dist.Independent,
        dist.ExpandedDistribution,
        dist.MaskedDistribution,
    )
    if isinstance(d, wrappers):

        def transform(q):
            return uniform_reparam_transform(d.base_dist)(q)

        return transform

    return d.icdf
@uniform_reparam_transform.register(dist.MultivariateNormal)
def _(d):
    """Standard normal inverse CDF followed by the affine map of ``d``."""
    affine = dist.transforms.LowerCholeskyAffine(d.loc, d.scale_tril)

    def transform(q):
        standard_normal = dist.Normal(0, 1).icdf(q)
        return affine(standard_normal)

    return transform


@uniform_reparam_transform.register(dist.BernoulliLogits)
@uniform_reparam_transform.register(dist.BernoulliProbs)
def _(d):
    """Integer indicator of ``q < d.probs``."""

    def transform(q):
        success = q < d.probs
        int_dtype = jnp.result_type(success, int)
        return jnp.astype(success, int_dtype)

    return transform


@uniform_reparam_transform.register(dist.CategoricalLogits)
@uniform_reparam_transform.register(dist.CategoricalProbs)
def _(d):
    """Count of cumulative probabilities lying strictly below ``q``."""

    def transform(q):
        cumulative = jnp.cumsum(d.probs, axis=-1)
        below = cumulative < jnp.expand_dims(q, axis=-1)
        return jnp.sum(below, axis=-1)

    return transform


@uniform_reparam_transform.register(dist.Dirichlet)
def _(d):
    """Gamma draws normalized to sum to one along the last axis."""
    gamma_dist = dist.Gamma(d.concentration)

    def transform_fn(q):
        # NB: icdf is not available yet for Gamma distribution
        # so this will raise an NotImplementedError for now.
        # We will need scipy.special.gammaincinv, which is not available yet
        # in JAX, see issue: https://github.com/google/jax/issues/5350
        # TODO: consider wrap jaxns GammaPrior transform implementation
        gammas = uniform_reparam_transform(gamma_dist)(q)
        total = gammas.sum(-1, keepdims=True)
        return gammas / total

    return transform_fn
def uniform_reparam_model(
    model: Callable,
    model_args: tuple = (),
    model_kwargs: dict | None = None,
    rng_seed: int = 42,
) -> ModelInfo:
    """Reparameterize ``model`` onto the unit hypercube and describe it.

    Every non-observed sample site that is not marked for parallel
    enumeration is replaced, via :class:`UniformReparam`, by a uniform site
    named ``u_{name}``.

    Parameters
    ----------
    model : callable
        The numpyro model.
    model_args : tuple, optional
        Positional arguments passed to the model.
    model_kwargs : dict, optional
        Keyword arguments passed to the model. ``None`` means no keyword
        arguments.
    rng_seed : int, optional
        Seed of the PRNG key, which is split into one key for tracing the
        model and one key for the ``Predictive`` call in the postprocess
        function.

    Returns
    -------
    ModelInfo
        ``init`` / ``init_ravel`` hold the cube point with every coordinate
        equal to 0.5; ``log_prob_fn`` and ``postprocess_fn`` take a dict of
        cube coordinates keyed by ``u_{name}``; ``params_dtype`` describes
        the parameters as they appear in the traced model.
    """
    seed_key, pred_key = random.split(random.PRNGKey(rng_seed))

    if model_kwargs is None:
        model_kwargs = {}

    model_trace = trace(seed(model, seed_key)).get_trace(
        *model_args, **model_kwargs
    )

    # params in constrained space
    params = {
        site['name']: site['value']
        for site in model_trace.values()
        if (
            (site['type'] == 'sample')
            and (not site['is_observed'])
            and (site['infer'].get('enumerate', '') != 'parallel')
        )
    }
    params_names = list(params.keys())
    params_dtype = [(k, v.dtype, v.shape) for k, v in params.items()]

    # deterministic sites
    deterministic_names = [
        site['name']
        for site in model_trace.values()
        if site['type'] == 'deterministic'
    ]

    # reparam the model so that latent sites have Uniform(0, 1) priors
    reparam_model = reparam(
        model, config={k: UniformReparam() for k in params_names}
    )

    # hyper cube
    # Cube coordinates are always real numbers in (0, 1), even when the
    # constrained site is integer-valued (e.g. Bernoulli or Categorical).
    # Reusing an integer dtype here would truncate 0.5 to 0 and make
    # ``unravel`` cast every proposed coordinate back to an integer, so
    # non-floating dtypes fall back to the default float dtype.
    default_float = jnp.result_type(float)
    cube = {
        f'u_{name}': jnp.full(
            shape,
            0.5,
            dtype if jnp.issubdtype(dtype, jnp.floating) else default_float,
        )
        for name, dtype, shape in params_dtype
    }
    cube_ravel, unravel = ravel_pytree(cube)
    cube = jax.device_get(cube)
    cube_ravel = jax.device_get(cube_ravel)

    # enable enum if needed
    has_enum = any(
        site['type'] == 'sample'
        and site['infer'].get('enumerate', '') == 'parallel'
        for site in model_trace.values()
    )
    if has_enum:
        from numpyro.contrib.funsor import enum, log_density as log_density_fn

        max_plate_nesting = _guess_max_plate_nesting(model_trace)
        _validate_model(model_trace)
        reparam_model = enum(reparam_model, -max_plate_nesting - 1)
    else:
        log_density_fn = log_density

    @jax.jit
    def log_prob_fn(params_cube):
        log_prob, _ = log_density_fn(
            reparam_model, model_args, model_kwargs, params_cube
        )
        return log_prob

    @jax.jit
    def postprocess_fn(params_cube):
        return Predictive(
            reparam_model,
            params_cube,
            return_sites=params_names + deterministic_names,
            batch_ndims=0,
        )(pred_key, *model_args, **model_kwargs)

    # Evaluate once at the cube center to learn the dtype and shape of the
    # deterministic sites.
    samples = postprocess_fn(cube)
    deterministic = {i: samples[i] for i in deterministic_names}
    deterministic_dtype = [
        (k, v.dtype, v.shape) for k, v in deterministic.items()
    ]

    return ModelInfo(
        ndim=len(cube_ravel),
        init=cube,
        init_ravel=cube_ravel,
        unravel=unravel,
        log_prob_fn=log_prob_fn,
        postprocess_fn=postprocess_fn,
        params_names=params_names,
        params_dtype=params_dtype,
        deterministic_names=deterministic_names,
        deterministic_dtype=deterministic_dtype,
    )
# <<< The code above is adapted from numpyro.contrib.nested_sampling <<<