Source code for fortuna.prob_model.posterior.sngp.sngp_posterior

import logging
from typing import Optional

from fortuna.data import DataLoader
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.joint.base import Joint
from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior
from fortuna.prob_model.posterior.sngp import SNGP_NAME
from fortuna.prob_model.posterior.sngp.sngp_approximator import (
    SNGPPosteriorApproximator,
)
from fortuna.prob_model.posterior.sngp.sngp_callback import ResetCovarianceCallback
from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.typing import Status
from fortuna.utils.nested_dicts import find_one_path_to_key

logger = logging.getLogger(__name__)


[docs]class SNGPPosterior(MAPPosterior): def __init__( self, joint: Joint, posterior_approximator: SNGPPosteriorApproximator, ): """ Spectral-normalized Neural Gaussian Process (`SNGP <https://arxiv.org/abs/2006.10108>`_) approximate posterior class. Parameters ---------- joint: Joint A Joint distribution object. posterior_approximator: SNGPPosteriorApproximator An SNGP posterior approximator. """ super().__init__(joint=joint, posterior_approximator=posterior_approximator)
[docs] def fit( self, train_data_loader: DataLoader, val_data_loader: Optional[DataLoader] = None, fit_config: FitConfig = FitConfig(), **kwargs, ) -> Status: # set sngp callback to reset covariance callbacks = [ ResetCovarianceCallback( precision_matrix_key_name="precision_matrix", ridge_penalty=self.joint.likelihood.model_manager.ridge_penalty, ) ] if fit_config.callbacks is None: fit_config.callbacks = callbacks else: fit_config.callbacks = fit_config.callbacks + callbacks return super(SNGPPosterior, self).fit( train_data_loader, val_data_loader, fit_config, **kwargs )
def __str__(self): return SNGP_NAME @staticmethod def _check_state(state: PosteriorState) -> None: path = find_one_path_to_key(state.mutable, "spectral_stats") if len(path) == 0: raise ValueError( f"It looks like your deep feature extractor does not have Spectral Normalization, " f"which is required by SNGP. Please include spectral normalization in your model." f"Check out `fortuna.model.utils.spectral_norm.WithSpectralNorm` for more details." )