Source code for fortuna.prob_model.posterior.sngp.sngp_callback

from flax.core import FrozenDict
import jax.numpy as jnp

from fortuna.training.callback import Callback
from fortuna.training.train_state import TrainState
from fortuna.utils.nested_dicts import (
    find_one_path_to_key,
    nested_get,
    nested_update,
)


class ResetCovarianceCallback(Callback):
    """
    Reset, at the beginning of each epoch, the covariance matrix estimated
    while training an SNGP model.

    Concretely, the entry of ``state.mutable`` stored under
    ``precision_matrix_key_name`` is overwritten with
    ``ridge_penalty * identity``.
    """

    def __init__(self, precision_matrix_key_name: str, ridge_penalty: float):
        """
        Parameters
        ----------
        precision_matrix_key_name: str
            Key under which the precision matrix is stored in the (nested)
            mutable collection of the train state.
        ridge_penalty: float
            Scalar multiplying the identity matrix used as the reset value.
        """
        self.precision_matrix_key_name = precision_matrix_key_name
        self.ridge_penalty = ridge_penalty

    def training_epoch_start(self, state: TrainState) -> TrainState:
        """
        Replace the stored precision matrix with ``ridge_penalty * identity``.

        Parameters
        ----------
        state: TrainState
            Current training state; its ``mutable`` attribute must contain
            the key ``precision_matrix_key_name``.

        Returns
        -------
        TrainState
            A copy of ``state`` whose ``mutable`` collection holds the reset
            precision matrix.

        Raises
        ------
        ValueError
            If the stored precision matrix is neither 2-D ``(rows, cols)``
            nor 3-D ``(num_devices, rows, cols)``.
        """
        key_paths = find_one_path_to_key(state.mutable, self.precision_matrix_key_name)
        precision_matrix = nested_get(state.mutable, key_paths)

        ndim = precision_matrix.ndim
        if ndim not in (2, 3):
            # Fail loudly instead of hitting an unbound local further down.
            raise ValueError(
                f"Expected the precision matrix stored under "
                f"'{self.precision_matrix_key_name}' to have 2 or 3 dimensions, "
                f"but got shape {precision_matrix.shape}."
            )

        # The row count is the second-to-last axis in both supported layouts.
        n = precision_matrix.shape[-2]
        init_precision_matrix = (
            jnp.eye(n, dtype=precision_matrix.dtype) * self.ridge_penalty
        )
        if ndim == 3:
            # Leading axis is the device axis: replicate the reset matrix on it.
            d = precision_matrix.shape[0]
            init_precision_matrix = jnp.broadcast_to(init_precision_matrix, (d, n, n))

        # Rebuild the nested dictionary {k0: {k1: ... {kN: matrix}}} matching
        # the path at which the precision matrix lives, so that only that leaf
        # is overwritten by the nested update.
        partially_updated_mutables = init_precision_matrix
        for key in reversed(key_paths):
            partially_updated_mutables = {key: partially_updated_mutables}

        mutables = nested_update(state.mutable.unfreeze(), partially_updated_mutables)
        mutables = FrozenDict(mutables)
        return state.replace(mutable=mutables)