Random Features

class fortuna.model.utils.random_features.RandomFeatureGaussianProcess(features, hidden_features=1024, normalize_input=False, norm_kwargs=<factory>, hidden_kwargs=<factory>, output_kwargs=<factory>, covariance_kwargs=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A Gaussian process layer using random Fourier features.

See "Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness".

features

The number of output units.

Type:

int

hidden_features

The number of hidden random Fourier features.

Type:

int

normalize_input

Whether to normalize the input using nn.LayerNorm.

Type:

bool

norm_kwargs

Optional keyword arguments to the input nn.LayerNorm layer.

Type:

Mapping[str, Any]

hidden_kwargs

Optional keyword arguments to the random feature layer.

Type:

Mapping[str, Any]

output_kwargs

Optional keyword arguments to the predictive logit layer.

Type:

Mapping[str, Any]

covariance_kwargs

Optional keyword arguments to the predictive covariance layer.

Type:

Mapping[str, Any]

class fortuna.model.utils.random_features.RandomFourierFeatures(features, kernel_scale=1.0, feature_scale=1.0, kernel_init=<function normal.<locals>.init>, bias_init=<function uniform.<locals>.init>, seed=0, dtype=<class 'jax.numpy.float32'>, collection_name='random_features', parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A random Fourier feature (RFF) layer that approximates a kernel model.

The random feature transformation is a one-hidden-layer network with non-trainable weights (see, e.g., Algorithm 1 of "Random Features for Large-Scale Kernel Machines"):

\[f(\mathbf{x}) = \gamma \cos(\mathbf{W}\mathbf{x} + \mathbf{b})\]

where \(\mathbf{W}\) is the kernel matrix, \(\mathbf{b}\) is the bias and \(\gamma\) is the output scale. The forward pass logic closely follows that of the nn.Dense layer.
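
As a rough illustration of the transformation above, here is a minimal NumPy sketch. It is not Fortuna's implementation (which is a Flax module). The function name `random_fourier_features`, the sampling choices W ~ N(0, 1) and b ~ U(0, 2π), and the scale γ = sqrt(2/D) are assumptions taken from the standard random Fourier feature construction for an RBF kernel with unit length scale.

```python
import numpy as np


def random_fourier_features(x, num_features, feature_scale=1.0, seed=0):
    """Map inputs of shape (n, d) to random features of shape (n, num_features).

    Hypothetical helper for illustration only; not part of Fortuna.
    """
    rng = np.random.default_rng(seed)
    input_dim = x.shape[-1]
    # Frozen (non-trainable) parameters, drawn once from a fixed seed.
    # Assumed distributions: W ~ N(0, 1), b ~ U(0, 2*pi).
    W = rng.normal(size=(input_dim, num_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=(num_features,))
    # f(x) = gamma * cos(W x + b), applied row-wise to a batch.
    return feature_scale * np.cos(x @ W + b)


# With gamma = sqrt(2 / D), inner products of the features approximate
# the RBF kernel k(x, y) = exp(-||x - y||^2 / 2).
D = 20000
x = np.array([[0.0, 0.0], [0.5, -0.5], [1.0, 1.0]])
z = random_fourier_features(x, num_features=D, feature_scale=np.sqrt(2.0 / D))
approx_kernel = z @ z.T
sq_dists = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
exact_kernel = np.exp(-0.5 * sq_dists)
```

Comparing `approx_kernel` with `exact_kernel` shows the two agreeing more closely as `D` grows, which is the sense in which the layer "approximates a kernel model".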

features

The number of output units.

Type:

int

feature_scale

Scale to apply to the output. When using the GP layer as the output layer of a neural network, it is recommended to set this to 1.0 to prevent it from changing the learning rate of the hidden layers.

Type:

Optional[float]

kernel_init

Initializer function for the weight matrix.

Type:

Callable[[jax.Array, Shape, Type], Array]

bias_init

Initializer function for the bias.

Type:

Callable[[jax.Array, Shape, Type], Array]

seed

Random seed for generating random features. This will override the external RNGs.

Type:

int

dtype

The dtype of the computation.

Type:

Type

class fortuna.model.utils.random_features.LaplaceRandomFeatureCovariance(hidden_features, ridge_penalty=1.0, momentum=None, collection_name='laplace_covariance', dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Computes the approximate posterior covariance using the Laplace method.

hidden_features

The number of random Fourier features.

Type:

int

ridge_penalty

Initial ridge penalty added to the weight covariance matrix. This value stabilizes the eigenvalues of the weight covariance estimate \(\Sigma\) so that the matrix inverse in \(\Sigma = (s\mathbf{I}+\mathbf{X}^\top\mathbf{X})^{-1}\) can be computed. The ridge factor \(s\) should not be too large; otherwise it dominates \(\mathbf{X}^\top\mathbf{X}\) and the covariance estimate is no longer meaningful.

Type:

float
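
The role of the ridge factor can be seen in a small NumPy sketch that transcribes the formula for \(\Sigma\) literally. The helper name `laplace_covariance` is hypothetical and the code is not Fortuna's implementation; it only illustrates why the penalty is needed and what happens when it is too large.

```python
import numpy as np


def laplace_covariance(features, ridge_penalty=1.0):
    """Return (s * I + X^T X)^{-1} for features X of shape (n, hidden_features).

    Hypothetical helper for illustration only; not part of Fortuna.
    """
    hidden_features = features.shape[-1]
    precision = ridge_penalty * np.eye(hidden_features) + features.T @ features
    return np.linalg.inv(precision)


# Two data points in a three-dimensional feature space: X^T X has rank 2,
# so it is singular and cannot be inverted without the ridge term.
X = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0]])
rank_without_ridge = np.linalg.matrix_rank(X.T @ X)

cov_small = laplace_covariance(X, ridge_penalty=1e-3)  # data term matters
cov_large = laplace_covariance(X, ridge_penalty=1e3)   # ridge term dominates
```

With the large penalty, `cov_large` is nearly `I / s` regardless of `X`, which is the "not meaningful" regime the description warns about.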

momentum

A discount factor used to compute the moving average of the posterior precision matrix, analogous to the momentum factor in batch normalization. If None, the covariance matrix is updated using a naive sum without momentum, which is desirable if the goal is to compute the exact covariance matrix by passing through the data once (say, in the final epoch). In this case, make sure to reset the precision matrix variable between epochs to avoid double counting.

Type:

Optional[float]
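
The two update modes can be sketched in NumPy as follows. This is an illustration under stated assumptions, not Fortuna's code: the helper names are hypothetical, and the moving-average form (old precision weighted by `momentum`, batch-averaged new term weighted by `1 - momentum`) is assumed by analogy with batch normalization.

```python
import numpy as np


def initial_precision(hidden_features, ridge_penalty=1.0):
    """Start from the ridge term s * I (hypothetical helper)."""
    return ridge_penalty * np.eye(hidden_features)


def update_precision(precision, batch_features, momentum=None):
    """Fold one batch of features into the precision matrix (hypothetical helper)."""
    batch_precision = batch_features.T @ batch_features
    if momentum is None:
        # Naive sum: exact after exactly one pass over the data.
        return precision + batch_precision
    # Assumed moving-average form, normalized by the batch size.
    batch_size = batch_features.shape[0]
    return momentum * precision + (1.0 - momentum) * batch_precision / batch_size


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
batches = (X[:4], X[4:])

# One pass without momentum recovers s * I + X^T X exactly.
one_pass = initial_precision(3)
for batch in batches:
    one_pass = update_precision(one_pass, batch)

# A second pass without resetting counts every data point twice.
two_passes = one_pass
for batch in batches:
    two_passes = update_precision(two_passes, batch)
```

Here `two_passes` equals `I + 2 * X.T @ X` rather than `I + X.T @ X`, which is the double counting that resetting the precision matrix between epochs avoids.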

dtype

The dtype of the computation.

Type:

Type