Random Features#

class fortuna.model.utils.random_features.RandomFeatureGaussianProcess(features, hidden_features=1024, normalize_input=False, norm_kwargs=<factory>, hidden_kwargs=<factory>, output_kwargs=<factory>, covariance_kwargs=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

A Gaussian process layer using random Fourier features.

See Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness.

features#

The number of output units.

Type:

int

hidden_features#

The number of hidden random Fourier features.

Type:

int

normalize_input#

Whether to normalize the input using nn.LayerNorm.

Type:

bool

norm_kwargs#

Optional keyword arguments to the input nn.LayerNorm layer.

Type:

Mapping[str, Any]

hidden_kwargs#

Optional keyword arguments to the random feature layer.

Type:

Mapping[str, Any]

output_kwargs#

Optional keyword arguments to the predictive logit layer.

Type:

Mapping[str, Any]

covariance_kwargs#

Optional keyword arguments to the predictive covariance layer.

Type:

Mapping[str, Any]
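
A minimal usage sketch follows, assuming the standard Flax init/apply pattern; the input shape, the mutable-collection handling, and the structure of the returned outputs are illustrative assumptions, not guarantees of the API.

import jax
import jax.numpy as jnp
from fortuna.model.utils.random_features import RandomFeatureGaussianProcess

# Hypothetical GP output head over 128-dimensional encodings, producing
# 10 output units. Shapes below are illustrative.
gp_head = RandomFeatureGaussianProcess(features=10, hidden_features=1024)

x = jnp.ones((32, 128))  # a batch of encoder outputs
variables = gp_head.init(jax.random.PRNGKey(0), x)

# The non-trainable random features and the Laplace precision matrix live
# in dedicated variable collections ('random_features' and
# 'laplace_covariance'); a collection must be marked mutable if the
# forward pass updates it.
outputs, updated_state = gp_head.apply(
    variables, x, mutable=["laplace_covariance"]
)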

class fortuna.model.utils.random_features.RandomFourierFeatures(features, kernel_scale=1.0, feature_scale=1.0, kernel_init=<function normal.<locals>.init>, bias_init=<function uniform.<locals>.init>, seed=0, dtype=<class 'jax.numpy.float32'>, collection_name='random_features', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

A random Fourier feature (RFF) layer that approximates a kernel model.

The random feature transformation is a one-hidden-layer network with non-trainable weights (see, e.g., Algorithm 1 of Random Features for Large-Scale Kernel Machines):

\[f(x) = \gamma \cos(\mathbf{W}\mathbf{x} + \mathbf{b})\]

where \(\mathbf{W}\) is the (non-trainable) kernel matrix, \(\mathbf{b}\) is the bias, and \(\gamma\) is the output scale. The forward pass logic closely follows that of the nn.Dense layer.
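
As a concrete illustration of the formula above, here is a self-contained sketch of the random feature map in JAX. It is not the layer's implementation: the sampling distributions of \(\mathbf{W}\) and \(\mathbf{b}\), and the way the kernel scale enters, follow the classic RBF construction and are assumptions here.

import jax
import jax.numpy as jnp

def rff(x, rng, hidden_features=1024, kernel_scale=1.0, feature_scale=1.0):
    # Standalone sketch of f(x) = gamma * cos(W x + b); not the library code.
    w_key, b_key = jax.random.split(rng)
    # Rows of W drawn from a normal whose scale is set by the kernel
    # length-scale; b drawn uniformly from [0, 2*pi].
    W = jax.random.normal(w_key, (x.shape[-1], hidden_features)) / kernel_scale
    b = jax.random.uniform(b_key, (hidden_features,), maxval=2 * jnp.pi)
    return feature_scale * jnp.cos(x @ W + b)

phi = rff(jnp.ones((8, 16)), jax.random.PRNGKey(0))
# With feature_scale = sqrt(2 / hidden_features), the inner product
# phi @ phi.T approximates an RBF kernel (Rahimi & Recht, 2007).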

features#

The number of output units.

Type:

int

feature_scale#

Scale applied to the output. When using the GP layer as the output layer of a neural network, it is recommended to set it to 1 to prevent it from effectively changing the learning rate of the hidden layers.

Type:

Optional[float]

kernel_init#

Initializer function for the weight matrix.

Type:

Callable[[PRNGKeyArray, Shape, Type], Array]

bias_init#

Initializer function for the bias.

Type:

Callable[[PRNGKeyArray, Shape, Type], Array]

seed#

Random seed used to generate the random features. This overrides any externally provided RNGs (see the usage sketch below).

Type:

int

dtype#

The dtype of the computation.

Type:

Type
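
A usage sketch of the layer on its own, assuming the standard Flax init/apply pattern and a read-only forward pass; per the seed attribute above, the random weights are expected to be derived from seed rather than from the RNG passed to init.

import jax
import jax.numpy as jnp
from fortuna.model.utils.random_features import RandomFourierFeatures

rff_layer = RandomFourierFeatures(features=1024, seed=0)
x = jnp.ones((8, 16))
variables = rff_layer.init(jax.random.PRNGKey(42), x)  # seed=0 governs W and b
phi = rff_layer.apply(variables, x)  # assumed output shape: (8, 1024)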

class fortuna.model.utils.random_features.LaplaceRandomFeatureCovariance(hidden_features, ridge_penalty=1.0, momentum=None, collection_name='laplace_covariance', dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Computes the approximate posterior covariance using the Laplace method.

hidden_features#

The number of random Fourier features.

Type:

int

ridge_penalty#

Initial ridge penalty for the weight covariance matrix. This value stabilizes the eigenvalues of the weight covariance estimate \(\Sigma\) so that the matrix inverse \(\Sigma = (s\mathbf{I} + \mathbf{X}^T\mathbf{X})^{-1}\) can be computed. The ridge factor \(s\) should not be too large, since otherwise it dominates the data term \(\mathbf{X}^T\mathbf{X}\) and renders the covariance estimate meaningless.

Type:

float

momentum#

A discount factor used to compute the moving average of the posterior precision matrix, analogous to the momentum factor in batch normalization. If None, the covariance matrix is updated as a naive sum without momentum, which is desirable if the goal is to compute the exact covariance matrix in a single pass over the data (say, in the final epoch); in that case, make sure to reset the precision matrix variable between epochs to avoid double counting. Both update rules are sketched at the end of this section.

Type:

Optional[float]

dtype#

The dtype of the computation.

Type:

Type
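
To make the ridge-penalty and momentum semantics above concrete, here is a minimal recreation of the precision-matrix update in JAX. Variable names and the exact accumulation rule are illustrative assumptions, not the module's internals.

import jax.numpy as jnp

def update_precision(precision, phi, momentum=None):
    # phi: batch of random features, shape (batch_size, hidden_features).
    batch_precision = phi.T @ phi  # the X^T X term for this batch
    if momentum is None:
        # Naive sum: exact precision after one full pass over the data.
        return precision + batch_precision
    # Moving average, analogous to batch-normalization momentum.
    return momentum * precision + (1.0 - momentum) * batch_precision

hidden_features, ridge_penalty = 1024, 1.0
precision = ridge_penalty * jnp.eye(hidden_features)  # the s * I term
# ... accumulate update_precision over batches, then invert once:
# covariance = jnp.linalg.inv(precision)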