Source code for fortuna.prob_model.posterior.sngp.sngp_approximator

from typing import (
    Any,
    Dict,
    Optional,
)

from fortuna.prob_model.posterior.base import PosteriorApproximator
from fortuna.prob_model.posterior.sngp import SNGP_NAME


[docs]class SNGPPosteriorApproximator(PosteriorApproximator): def __init__( self, *args, output_dim: int, gp_hidden_features: int = 1024, normalize_input: bool = False, ridge_penalty: float = 1.0, momentum: Optional[float] = None, mean_field_factor: float = 1.0, **kwargs, ): """ SNGP posterior approximator. It is responsible to define how the posterior distribution is approximated. Parameters ---------- output_dim: int The output dimension of the network. normalize_input: bool Whether to normalize the input using nn.LayerNorm. gp_hidden_features: int The number of random fourier features. ridge_penalty: float Initial Ridge penalty to weight covariance matrix. This value is used to stabilize the eigenvalues of weight covariance estimate :math:`\Sigma` so that the matrix inverse can be computed for :math:`\Sigma = (\mathbf{I}*s+\mathbf{X}^T\mathbf{X})^{-1}`. The ridge factor :math:`s` cannot be too large since otherwise it will dominate making the covariance estimate not meaningful. momentum: Optional[float] A discount factor used to compute the moving average for posterior precision matrix. Analogous to the momentum factor in batch normalization. If `None` then update covariance matrix using a naive sum without momentum, which is desirable if the goal is to compute the exact covariance matrix by passing through data once (say in the final epoch). In this case, make sure to reset the precision matrix variable between epochs to avoid double counting. mean_field_factor: float The scale factor for mean-field approximation, used to adjust (at inference time) the influence of posterior variance in posterior mean approximation. See `Zhiyun L. et al., 2020 <https://arxiv.org/abs/2006.07584>`_ for more details. mean_field_factor: float The scale factor for mean-field approximation, used to adjust (at inference time) the influence of posterior variance in posterior mean approximation. See `Zhiyun L. et al., 2020 <https://arxiv.org/abs/2006.07584>`_ for more details. """ super(SNGPPosteriorApproximator, self).__init__(*args, **kwargs) self.output_dim = output_dim self.gp_hidden_features = gp_hidden_features self.normalize_input = normalize_input self.ridge_penalty = ridge_penalty self.momentum = momentum self.mean_field_factor = mean_field_factor def __str__(self): return SNGP_NAME @property def posterior_method_kwargs(self) -> Dict[str, Any]: return { "output_dim": self.output_dim, "gp_hidden_features": self.gp_hidden_features, "normalize_input": self.normalize_input, "ridge_penalty": self.ridge_penalty, "momentum": self.momentum, "mean_field_factor": self.mean_field_factor, }