Source code for fortuna.prob_model.prior.gaussian

from typing import Optional

from jax import random
from jax._src.prng import PRNGKeyArray
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

from fortuna.prob_model.prior.base import Prior
from fortuna.typing import Params


class IsotropicGaussianPrior(Prior):
    def __init__(self, log_var: Optional[float] = 0.0):
        """
        An isotropic Gaussian prior class.

        Parameters
        ----------
        log_var : Optional[float]
            Prior log-variance value. The covariance matrix of the prior
            distribution is a diagonal matrix with ``exp(log_var)`` on every
            entry of the diagonal. If ``None``, it falls back to ``0.0``,
            i.e. unit variance.
        """
        super().__init__()
        # The signature advertises ``Optional``; treat ``None`` as the default
        # instead of failing later inside ``jnp.exp``.
        if log_var is None:
            log_var = 0.0
        self.log_var = log_var
        # Cached so density evaluation and sampling do not recompute them.
        self.prec = jnp.exp(-self.log_var)  # inverse variance
        self.std = jnp.exp(0.5 * self.log_var)  # standard deviation
        self.log2pi = jnp.log(2 * jnp.pi)

    def log_joint_prob(self, params: Params) -> float:
        """
        Evaluate the prior log-density at ``params``.

        Parameters
        ----------
        params : Params
            Model parameters. All leaves are flattened into a single vector.

        Returns
        -------
        float
            ``-0.5 * (prec * ||theta||^2 + n * (log(2*pi) + log_var))``, where
            ``theta`` is the flattened parameter vector and ``n`` its length.
        """
        rav = ravel_pytree(params)[0]
        n = len(rav)
        return -0.5 * (self.prec * jnp.sum(rav**2) + n * (self.log2pi + self.log_var))

    def sample(self, params_like: Params, rng: Optional[PRNGKeyArray] = None) -> Params:
        """
        Draw one sample from the prior, shaped like ``params_like``.

        Parameters
        ----------
        params_like : Params
            Template that fixes the pytree structure and the number of
            coordinates of the sample. Its values are not used.
        rng : Optional[PRNGKeyArray]
            Random key. If ``None``, one is obtained from ``self.rng.get()``.

        Returns
        -------
        Params
            A pytree with the structure of ``params_like``, whose flattened
            entries are ``std`` times standard normal draws.
        """
        dummy_rav, unravel = ravel_pytree(params_like)
        n = len(dummy_rav)
        if rng is None:
            rng = self.rng.get()
        rav_samples = self.std * random.normal(rng, shape=(n,))
        return unravel(rav_samples)
class DiagonalGaussianPrior(Prior):
    def __init__(self, log_var: jnp.ndarray):
        """
        A Gaussian prior with a diagonal covariance matrix.

        Parameters
        ----------
        log_var : jnp.ndarray
            Element-wise logarithm of the diagonal of the prior covariance
            matrix, i.e. one log-variance per flattened parameter coordinate.
        """
        super().__init__()
        self.log_var = log_var
        self.log2pi = jnp.log(2 * jnp.pi)

    def log_joint_prob(self, params: Params) -> float:
        """
        Evaluate the prior log-density at ``params``.

        Parameters
        ----------
        params : Params
            Model parameters. All leaves are flattened into a single vector.

        Returns
        -------
        float
            Sum over coordinates of
            ``-0.5 * (exp(-log_var) * theta**2 + log(2*pi) + log_var)``.
        """
        flat, _ = ravel_pytree(params)
        inv_var = jnp.exp(-self.log_var)
        # Per-coordinate negative-twice log-density terms, summed below.
        per_coord = inv_var * flat**2 + self.log2pi + self.log_var
        return -0.5 * jnp.sum(per_coord)

    def sample(self, params_like: Params, rng: Optional[PRNGKeyArray] = None) -> Params:
        """
        Draw one sample from the prior, shaped like ``params_like``.

        Parameters
        ----------
        params_like : Params
            Template that fixes the pytree structure and the number of
            coordinates of the sample. Its values are not used.
        rng : Optional[PRNGKeyArray]
            Random key. If ``None``, one is obtained from ``self.rng.get()``.

        Returns
        -------
        Params
            A pytree with the structure of ``params_like``, whose flattened
            entries are standard normal draws scaled by ``exp(0.5 * log_var)``.
        """
        template, unravel = ravel_pytree(params_like)
        num_coords = len(template)
        key = self.rng.get() if rng is None else rng
        scale = jnp.exp(0.5 * self.log_var)
        noise = random.normal(key, shape=(num_coords,))
        return unravel(scale * noise)