# Source code for fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner

from typing import (
    Callable,
    NamedTuple,
)

import jax
import jax.numpy as jnp
from optax._src.base import PyTree

from fortuna.typing import Params

# Base type for preconditioner states. It is an alias of ``typing.NamedTuple``, so the
# concrete state classes in this module that subclass it are named tuples.
PreconditionerState = NamedTuple


class Preconditioner(NamedTuple):
    r"""A sampler preconditioner class.

    A named tuple bundling the callables that make up a preconditioner: one that
    creates the state, one that updates it, and three that multiply an input by
    powers of the mass matrix :math:`M`.

    Attributes
    ----------
        init: Callable
            The state initialization function.
        update_preconditioner: Callable
            The state update function that takes gradients as an input.
        multiply_by_m_sqrt: Callable
            The function that multiplies its input by the square root of the mass matrix :math:`\sqrt{M}`.
        multiply_by_m_inv: Callable
            The function that multiplies its input by the mass matrix inverse :math:`M^{-1}`.
        multiply_by_m_sqrt_inv: Callable
            The function that multiplies its input by the square root of the mass matrix inverse
            :math:`M^{-1/2}`.

    """

    # Field order is part of the tuple layout; keep it stable.
    init: Callable[[Params], PreconditionerState]
    update_preconditioner: Callable[[PyTree, PreconditionerState], PreconditionerState]
    multiply_by_m_sqrt: Callable[[PyTree, PreconditionerState], PyTree]
    multiply_by_m_inv: Callable[[PyTree, PreconditionerState], PyTree]
    multiply_by_m_sqrt_inv: Callable[[PyTree, PreconditionerState], PyTree]


class RMSPropPreconditionerState(PreconditionerState):
    """State carried by the RMSProp preconditioner.

    Attributes
    ----------
        grad_moment_estimates: Params
            Running estimates of the squared gradients. :func:`rmsprop_preconditioner`
            creates them as zeros shaped like the parameters and replaces them on
            every preconditioner update.

    """

    grad_moment_estimates: Params


def rmsprop_preconditioner(running_average_factor: float = 0.99, eps: float = 1.0e-7):
    r"""Build the adaptive RMSProp preconditioner.

    The preconditioner tracks a moving average :math:`e` of the squared gradients
    and applies, elementwise, the mass matrix :math:`M = \epsilon + \sqrt{e}`.

    Parameters
    ----------
    running_average_factor: float
        The decay factor for the squared gradients moving average.
    eps: float
        :math:`\epsilon` constant for numerical stability.

    Returns
    -------
    preconditioner: Preconditioner
        An instance of RMSProp preconditioner.
    """

    def _mass(estimate):
        # Elementwise mass matrix entry for one leaf: eps + sqrt(e).
        return eps + jnp.sqrt(estimate)

    def _blend(estimate, grad):
        # Moving average of squared gradients for one leaf.
        return estimate * running_average_factor + grad**2 * (
            1 - running_average_factor
        )

    def _scale_by_m_inv(estimate, leaf):
        return leaf / _mass(estimate)

    def _scale_by_m_sqrt(estimate, leaf):
        return leaf * jnp.sqrt(_mass(estimate))

    def _scale_by_m_sqrt_inv(estimate, leaf):
        return leaf / jnp.sqrt(_mass(estimate))

    def _zero_state(params):
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return RMSPropPreconditionerState(grad_moment_estimates=zeros)

    def _refresh_state(gradient, preconditioner_state):
        refreshed = jax.tree_util.tree_map(
            _blend, preconditioner_state.grad_moment_estimates, gradient
        )
        return RMSPropPreconditionerState(grad_moment_estimates=refreshed)

    def _apply_m_inv(vec, preconditioner_state):
        estimates = preconditioner_state.grad_moment_estimates
        return jax.tree_util.tree_map(_scale_by_m_inv, estimates, vec)

    def _apply_m_sqrt(vec, preconditioner_state):
        estimates = preconditioner_state.grad_moment_estimates
        return jax.tree_util.tree_map(_scale_by_m_sqrt, estimates, vec)

    def _apply_m_sqrt_inv(vec, preconditioner_state):
        estimates = preconditioner_state.grad_moment_estimates
        return jax.tree_util.tree_map(_scale_by_m_sqrt_inv, estimates, vec)

    return Preconditioner(
        init=_zero_state,
        update_preconditioner=_refresh_state,
        multiply_by_m_sqrt=_apply_m_sqrt,
        multiply_by_m_inv=_apply_m_inv,
        multiply_by_m_sqrt_inv=_apply_m_sqrt_inv,
    )
class IdentityPreconditionerState(PreconditionerState):
    """Empty state of the identity preconditioner; it has no fields."""
def identity_preconditioner():
    """Build the no-op identity preconditioner.

    Every multiplication returns its input unchanged, and the state is always an
    empty :class:`IdentityPreconditionerState`.

    Returns
    -------
    preconditioner: Preconditioner
        An instance of identity preconditioner.
    """

    def _empty_state(_):
        return IdentityPreconditionerState()

    def _keep_empty_state(*args, **kwargs):
        # Nothing to update: any arguments are accepted and ignored.
        return IdentityPreconditionerState()

    def _passthrough(vec, _):
        # With M equal to the identity, every power of M leaves the input as is.
        return vec

    return Preconditioner(
        init=_empty_state,
        update_preconditioner=_keep_empty_state,
        multiply_by_m_sqrt=_passthrough,
        multiply_by_m_inv=_passthrough,
        multiply_by_m_sqrt_inv=_passthrough,
    )