Spectral-normalized Neural Gaussian Process (SNGP)

class fortuna.prob_model.posterior.sngp.sngp_approximator.SNGPPosteriorApproximator(*args, output_dim, gp_hidden_features=1024, normalize_input=False, ridge_penalty=1.0, momentum=None, mean_field_factor=1.0, **kwargs)

SNGP posterior approximator. It defines how the posterior distribution is approximated.

Parameters:
  • output_dim (int) – The output dimension of the network.

  • normalize_input (bool) – Whether to normalize the input using nn.LayerNorm.

  • gp_hidden_features (int) – The number of random Fourier features.

  • ridge_penalty (float) – Initial ridge penalty added to the weight covariance estimate. It stabilizes the eigenvalues of the estimate so that the matrix inverse in \(\Sigma = (s\mathbf{I}+\mathbf{X}^T\mathbf{X})^{-1}\) can be computed. The ridge factor \(s\) should not be too large, since it would otherwise dominate \(\mathbf{X}^T\mathbf{X}\) and make the covariance estimate uninformative.

  • momentum (Optional[float]) – A discount factor used to compute the moving average of the posterior precision matrix, analogous to the momentum factor in batch normalization. If None, the precision matrix is updated with a naive sum without momentum, which is desirable when the goal is to compute the exact covariance matrix in a single pass through the data (say, in the final epoch). In this case, make sure to reset the precision matrix variable between epochs to avoid double counting.

  • mean_field_factor (float) – The scale factor for mean-field approximation, used to adjust (at inference time) the influence of posterior variance in posterior mean approximation. See Zhiyun L. et al., 2020 for more details.

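The roles of ridge_penalty, momentum and mean_field_factor described above can be illustrated with a small NumPy sketch. It is not Fortuna's implementation: the function names are hypothetical, the batch normalization in the momentum branch and the form of the mean-field scaling are assumptions based on common SNGP formulations, and any likelihood-dependent weighting of the outer products is omitted.

```python
import numpy as np


def init_precision(num_features, ridge_penalty=1.0):
    # Hypothetical helper: start from s * I so the matrix is invertible
    # before any data has been seen.
    return ridge_penalty * np.eye(num_features)


def update_precision(precision, features, momentum=None):
    # features: (batch_size, num_features) matrix X of random Fourier features.
    batch_precision = features.T @ features
    if momentum is None:
        # Naive sum: yields s * I + X^T X when each example is visited once.
        return precision + batch_precision
    # Assumed moving-average form: average the batch term over its examples,
    # then discount the running estimate by `momentum`.
    batch_precision = batch_precision / features.shape[0]
    return momentum * precision + (1.0 - momentum) * batch_precision


def mean_field_logits(logits, variances, mean_field_factor=1.0):
    # Assumed mean-field adjustment: shrink each logit vector according to
    # its predictive variance.
    scale = np.sqrt(1.0 + mean_field_factor * variances)
    return logits / scale[:, None]


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))

# One pass over the data in two batches, without momentum.
precision = init_precision(4, ridge_penalty=1.0)
for batch in (X[:4], X[4:]):
    precision = update_precision(precision, batch, momentum=None)

# Sigma = (s I + X^T X)^{-1}
covariance = np.linalg.inv(precision)

# Per-example predictive variance x^T Sigma x, then adjusted logits.
test_features = rng.normal(size=(3, 4))
variances = np.einsum("ij,jk,ik->i", test_features, covariance, test_features)
logits = rng.normal(size=(3, 2))
adjusted = mean_field_logits(logits, variances, mean_field_factor=1.0)
```

Running a second pass without resetting `precision` to `init_precision(...)` would add \(\mathbf{X}^T\mathbf{X}\) twice, which is the double counting the momentum description warns about.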

property posterior_method_kwargs: Dict[str, Any]

class fortuna.prob_model.posterior.sngp.sngp_posterior.SNGPPosterior(joint, posterior_approximator)

Bases: MAPPosterior

Spectral-normalized Neural Gaussian Process (SNGP) approximate posterior class.

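Parameters:
  • joint – The joint distribution.

  • posterior_approximator (SNGPPosteriorApproximator) – The SNGP posterior approximator.

fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), **kwargs)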

Fit the posterior distribution. A posterior state will be internally stored.

Parameters:
  • train_data_loader (DataLoader) – Training data loader.

  • val_data_loader (Optional[DataLoader]) – Validation data loader.

  • fit_config (FitConfig) – A configuration object.

Returns:

A status including metrics describing the fitting process.

Return type:

Status

class fortuna.model.model_manager.classification.SNGPClassificationModelManager(model, *args, **kwargs)

Classification model manager for SNGP models.

Parameters:

model (nn.Module) – A model describing the deterministic relation between inputs and outputs. The output of the model is a latent representation of the input, which in this case does not correspond to the logits of a softmax probability vector. The output dimension of the model therefore does not depend on the number of classes in the classification task.
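To illustrate why the latent dimension is independent of the number of classes, here is a minimal NumPy sketch of a random Fourier feature head placed on top of a latent representation. It is an assumption about how an SNGP output layer is typically structured, not Fortuna's code: the function name, the scaling constant and the randomly drawn output weights are all hypothetical.

```python
import numpy as np


def random_fourier_features(latent, projection, bias):
    # latent: (batch, latent_dim); projection: (latent_dim, num_features).
    # Standard random Fourier feature map for an RBF kernel (assumed form).
    num_features = projection.shape[1]
    return np.sqrt(2.0 / num_features) * np.cos(latent @ projection + bias)


rng = np.random.default_rng(0)
latent_dim, gp_hidden_features, output_dim = 16, 1024, 3

# Fixed random projection and phase; these are not trained in this sketch.
projection = rng.normal(size=(latent_dim, gp_hidden_features))
bias = rng.uniform(0.0, 2.0 * np.pi, size=gp_hidden_features)

# Output weights mapping features to class logits; random here for illustration.
beta = rng.normal(size=(gp_hidden_features, output_dim))

# Stand-in for the output of `model`: its width is latent_dim, not output_dim.
latent = rng.normal(size=(5, latent_dim))
features = random_fourier_features(latent, projection, bias)
logits = features @ beta
```

Only `beta` depends on `output_dim`, so the same backbone can serve tasks with different numbers of classes.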

apply(params, inputs, mutable=None, train=False, rng=None)

Apply the model's forward pass.

Parameters:
  • params (Params) – The random parameters of the probabilistic model.

  • inputs (InputData) – Input data points.

  • mutable (Optional[Mutable]) – The mutable objects used to evaluate the model.

  • train (bool) – Whether the method is called during training.

  • rng (Optional[PRNGKeyArray]) – A random number generator. If not passed, this will be taken from the attributes of this class.

Returns:

The output of the model manager for each input. Mutable objects may also be returned.

Return type:

Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]

init(input_shape, rng=None, **kwargs)

Initialize random parameters and mutable objects.

Parameters:
  • input_shape (Tuple) – The shape of the input variable.

  • rng (Optional[PRNGKeyArray]) – A random number generator. If not passed, this will be taken from the attributes of this class.

Returns:

Initialized random parameters and mutable objects.

Return type:

Dict[str, FrozenDict]

property rng: RandomNumberGenerator

Invoke the random number generator object.

Returns:

The random number generator object.

Return type:

RandomNumberGenerator

class fortuna.prob_model.posterior.sngp.sngp_callback.ResetCovarianceCallback(precision_matrix_key_name, ridge_penalty)

Reset, at the beginning of each epoch, the covariance matrix estimated while training an SNGP model.
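The effect of such a reset can be sketched as follows. This is a hypothetical illustration, not the callback's actual code: it assumes the precision matrix lives in a flat dictionary under precision_matrix_key_name and that resetting means replacing it with ridge_penalty times the identity, whereas the real TrainState layout is not shown here.

```python
import numpy as np


def reset_precision_matrix(mutable, precision_matrix_key_name, ridge_penalty):
    # Hypothetical helper: return a copy of `mutable` in which the precision
    # matrix is replaced by ridge_penalty * I of the same size and dtype.
    old = mutable[precision_matrix_key_name]
    updated = dict(mutable)
    updated[precision_matrix_key_name] = ridge_penalty * np.eye(
        old.shape[0], dtype=old.dtype
    )
    return updated


# A stand-in for mutable training state accumulated during a previous epoch.
mutable = {"precision_matrix": np.full((3, 3), 7.0), "step": 10}
mutable = reset_precision_matrix(mutable, "precision_matrix", ridge_penalty=1.0)
```

Other entries of the state are left untouched, so only the accumulated covariance information is discarded at the start of the epoch.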

training_epoch_end(state)

Called at the end of every training epoch.

Parameters:

state (TrainState) – The training state.

Returns:

The (possibly updated) training state.

Return type:

TrainState

training_epoch_start(state)

Called at the beginning of every training epoch.

Parameters:

state (TrainState) – The training state.

Returns:

The (possibly updated) training state.

Return type:

TrainState

training_step_end(state)

Called after every minibatch update.

Parameters:

state (TrainState) – The training state.

Returns:

The (possibly updated) training state.

Return type:

TrainState