# implementation adapted from https://github.com/google/edward2/blob/main/edward2/jax/nn/random_feature.py
import dataclasses
import functools
from typing import (
Any,
Callable,
Mapping,
Optional,
Tuple,
Type,
)
import flax.linen as nn
from jax import (
lax,
random,
)
import jax.numpy as jnp
from jax.random import PRNGKeyArray
from fortuna.typing import (
Array,
Shape,
)
# Shorthand for the linear-algebra routines exposed by `jax.lax`.
linalg = lax.linalg

# Default initializers for the random features.
# The bias uses a uniform initializer with scale 2 * pi.
default_rbf_bias_init = nn.initializers.uniform(scale=2.0 * jnp.pi)

# The kernel is currently drawn with `nn.initializers.normal(stddev=1.0)`.
#
# A "he_normal" style alternative (see https://arxiv.org/abs/1502.01852) is kept
# below for reference but is DISABLED. It would effectively be equivalent to
# approximating an RBF kernel with the input standardized by its dimensionality
# (i.e., input_scaled = input * sqrt(2. / dim_input)), which empirically leads to
# better performance for neural network inputs:
#
# default_rbf_kernel_init = nn.initializers.variance_scaling(
#     scale=2.0, mode="fan_in", distribution="normal"
# )
default_rbf_kernel_init = nn.initializers.normal(stddev=1.0)


def default_kwarg_dict():
    """Return a dataclass field defaulting to a fresh empty dict.

    Used as the default value of the ``*_kwargs`` attributes in the module
    (data class) declarations below.
    """
    return dataclasses.field(default_factory=dict)


# Names of the likelihoods listed as supported.
SUPPORTED_LIKELIHOOD = ("binary_logistic", "poisson", "gaussian")
class RandomFeatureGaussianProcess(nn.Module):
    """
    A Gaussian process layer using random Fourier features.

    See `Simple and Principled Uncertainty Estimation with Deterministic
    Deep Learning via Distance Awareness <https://arxiv.org/abs/2006.10108>`_

    The layer chains an optional `nn.LayerNorm`, a `RandomFourierFeatures`
    layer, an `nn.Dense` layer producing the logits and a
    `LaplaceRandomFeatureCovariance` layer producing the predictive covariance.

    Attributes
    ----------
    features: int
        The number of output units.
    hidden_features: int
        The number of hidden random Fourier features.
    normalize_input: bool
        Whether to normalize the input using nn.LayerNorm.
    norm_kwargs: Mapping[str, Any]
        Optional keyword arguments to the input nn.LayerNorm layer.
    hidden_kwargs: Mapping[str, Any]
        Optional keyword arguments to the random feature layer.
    output_kwargs: Mapping[str, Any]
        Optional keyword arguments to the predictive logit layer.
    covariance_kwargs: Mapping[str, Any]
        Optional keyword arguments to the predictive covariance layer.
    """

    features: int
    hidden_features: int = 1024
    normalize_input: bool = False
    # Optional keyword arguments.
    norm_kwargs: Mapping[str, Any] = default_kwarg_dict()
    hidden_kwargs: Mapping[str, Any] = default_kwarg_dict()
    output_kwargs: Mapping[str, Any] = default_kwarg_dict()
    covariance_kwargs: Mapping[str, Any] = default_kwarg_dict()

    def setup(self):
        """Instantiates the sub-layers. The norm layer exists only if `normalize_input` is set."""
        # pylint:disable=invalid-name,not-a-mapping
        if self.normalize_input:
            # Prefer a parameter-free version of LayerNorm by default
            # (see `Xu et al., 2019 <https://papers.nips.cc/paper/2019/file/2f4fe03d77724a7217006e5d16728874-Paper.pdf>`_)
            # Can be overwritten by passing norm_kwargs=dict(use_bias=..., use_scale=...).
            LayerNorm = functools.partial(nn.LayerNorm, use_bias=False, use_scale=False)
            self.sngp_norm_layer = LayerNorm(**self.norm_kwargs)

        self.sngp_random_features_layer = RandomFourierFeatures(
            features=self.hidden_features,
            **self.hidden_kwargs,
        )
        self.sngp_dense_layer = nn.Dense(features=self.features, **self.output_kwargs)
        self.sngp_covariance_layer = LaplaceRandomFeatureCovariance(
            hidden_features=self.hidden_features, **self.covariance_kwargs
        )
        # pylint:enable=invalid-name,not-a-mapping

    def __call__(
        self,
        inputs: Array,
        return_full_covariance: bool = False,
    ) -> Tuple[Array, Optional[Array]]:
        """
        Computes Gaussian process outputs.

        Parameters
        ----------
        inputs: Array
            The nd-array of shape (batch_size, ..., input_dim).
        return_full_covariance: bool
            Whether to return the full covariance matrix, shape
            (batch_size, batch_size), or only return the predictive variances with
            shape (batch_size, ). The covariance layer flattens all non-final
            dimensions of the features, so `batch_size` here counts every
            leading dimension.

        Returns
        -------
        Tuple[Array, Optional[Array]]
            A tuple of predictive logits and predictive covariance. The second
            element is whatever the covariance layer returns, which is `None`
            whenever that layer's variable collection is mutable (i.e., when
            it updates its precision matrix instead of computing a covariance).
        """
        gp_inputs = self.sngp_norm_layer(inputs) if self.normalize_input else inputs
        gp_features = self.sngp_random_features_layer(gp_inputs)
        gp_logits = self.sngp_dense_layer(gp_features)
        gp_covariance = self.sngp_covariance_layer(
            gp_features, gp_logits, diagonal_only=not return_full_covariance
        )
        return gp_logits, gp_covariance
class RandomFourierFeatures(nn.Module):
    r"""
    A random Fourier feature (RFF) layer that approximates a kernel model.

    The random feature transformation is a one-hidden-layer network with
    non-trainable weights (see, e.g., Algorithm 1 of
    `Random Features for Large-Scale
    Kernel Machines <https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf>`_):

    .. math::
        f(x) = \gamma * cos(\mathbf{W}\mathbf{x} / \ell + \mathbf{b})

    where :math:`\mathbf{W}` is the kernel matrix, :math:`\mathbf{b}` is the bias,
    :math:`\ell` is the kernel scale and :math:`\gamma` is the output scale.

    The kernel and the bias are stored in the `collection_name` variable
    collection (not in `params`), whereas the kernel scale is registered as a
    parameter named `rf_kernel_scale`.

    The forward pass logic closely follows that of the `nn.Dense` layer.

    Attributes
    ----------
    features: int
        The number of output units.
    kernel_scale: Optional[float]
        Initial value of the `rf_kernel_scale` parameter; the kernel matrix is
        divided by this parameter in the forward pass.
        NOTE(review): although annotated as Optional, the value is passed
        directly to `nn.initializers.constant`; `None` is presumably not
        supported - confirm.
    feature_scale: Optional[float]
        Scale to apply to the output. If `None`, `sqrt(2 / features)` is used.
        When using GP layer as the output layer of a neural network, it is recommended to set this to 1.
        to prevent it from changing the learning rate to the hidden layers.
    kernel_init: Callable[[PRNGKeyArray, Shape, Type], Array]
        Initializer function for the weight matrix.
    bias_init: Callable[[PRNGKeyArray, Shape, Type], Array]
        Initializer function for the bias.
    seed: int
        Random seed for generating random features. This will override the external RNGs.
    dtype: Type
        The dtype of the computation.
    collection_name: str
        Name of the variable collection holding the random kernel and bias.
    """

    features: int
    kernel_scale: Optional[float] = 1.0
    feature_scale: Optional[float] = 1.0
    kernel_init: Callable[[PRNGKeyArray, Shape, Type], Array] = default_rbf_kernel_init
    bias_init: Callable[[PRNGKeyArray, Shape, Type], Array] = default_rbf_bias_init
    seed: int = 0
    dtype: Type = jnp.float32
    collection_name: str = "random_features"

    def setup(self):
        """Builds the layer's own PRNG key and resolves the output scale."""
        # Defines the random number generator.
        self.rng = random.PRNGKey(self.seed)

        # Processes random feature scale.
        self._feature_scale = self.feature_scale
        if self._feature_scale is None:
            self._feature_scale = jnp.sqrt(2.0 / self.features)
        self._feature_scale = jnp.asarray(self._feature_scale, dtype=self.dtype)

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Applies random feature transformation along the last dimension of inputs.

        Parameters
        ----------
        inputs: Array
            The nd-array to be transformed.

        Returns
        -------
        Array
            The transformed input; the last dimension has size `features`.
        """
        # Initializes variables. The kernel and bias are drawn from keys derived
        # from `self.seed`, not from the RNGs supplied by the caller.
        input_dim = inputs.shape[-1]

        kernel_rng, bias_rng = random.split(self.rng, num=2)
        kernel_shape = (input_dim, self.features)

        kernel = self.variable(
            self.collection_name,
            "kernel",
            self.kernel_init,
            kernel_rng,
            kernel_shape,
            self.dtype,
        )
        kernel_scale = self.param(
            "rf_kernel_scale",
            nn.initializers.constant(self.kernel_scale),
            (1,),
            self.dtype,
        )
        bias = self.variable(
            self.collection_name,
            "bias",
            self.bias_init,
            bias_rng,
            (self.features,),
            self.dtype,
        )

        # Specifies multiplication dimension: contract the last axis of the
        # inputs with the first axis of the kernel, without batch dimensions.
        contracting_dims = ((inputs.ndim - 1,), (0,))
        batch_dims = ((), ())

        # Performs forward pass.
        inputs = jnp.asarray(inputs, self.dtype)
        outputs = lax.dot_general(
            inputs, (1.0 / kernel_scale) * kernel.value, (contracting_dims, batch_dims)
        )
        outputs = outputs + jnp.broadcast_to(bias.value, outputs.shape)

        return self._feature_scale * jnp.cos(outputs)
class LaplaceRandomFeatureCovariance(nn.Module):
    r"""
    Computes the approximated posterior covariance using Laplace method.

    Attributes
    ----------
    hidden_features: int
        The number of random Fourier features.
    ridge_penalty: float
        Initial Ridge penalty to weight covariance matrix.
        This value is used to stabilize the eigenvalues of weight covariance estimate :math:`\Sigma` so that
        the matrix inverse can be computed for :math:`\Sigma = (\mathbf{I}*s+\mathbf{X}^T\mathbf{X})^{-1}`.
        The ridge factor :math:`s` cannot be too large since otherwise it will dominate
        making the covariance estimate not meaningful.
    momentum: Optional[float]
        A discount factor used to compute the moving average for posterior
        precision matrix. Analogous to the momentum factor in batch normalization.
        If `None` then update covariance matrix using a naive sum without
        momentum, which is desirable if the goal is to compute the exact
        covariance matrix by passing through data once (say in the final epoch).
        In this case, make sure to reset the precision matrix variable between
        epochs to avoid double counting.
    collection_name: str
        Name of the variable collection holding the precision matrix.
    dtype: Type
        The dtype of the computation
    """

    hidden_features: int
    ridge_penalty: float = 1.0
    momentum: Optional[float] = None
    collection_name: str = "laplace_covariance"
    dtype: Type = jnp.float32

    def setup(self):
        """Validates `momentum`.

        Raises
        ------
        ValueError
            If `momentum` is given and is smaller than 0 or larger than 1.
        """
        if self.momentum is not None:
            if self.momentum < 0.0 or self.momentum > 1.0:
                raise ValueError(
                    f"`momentum` must be between (0, 1). " f"Got {self.momentum}."
                )

    @nn.compact
    def __call__(
        self,
        gp_features: Array,
        gp_logits: Optional[Array] = None,
        diagonal_only: bool = True,
    ) -> Optional[Array]:
        """
        Updates the precision matrix and computes the predictive covariance.

        NOTE:
        The precision matrix will be updated only during training (i.e., when
        `self.collection_name` is in the list of mutable variables and `params`
        is not). The covariance matrix will be computed only during inference
        (i.e., when `self.collection_name` is not mutable) to avoid repeated
        calls to the (expensive) linear-algebra ops.

        Parameters
        ----------
        gp_features: Array
            The nd-array of random Fourier features, shape (batch_size, ..., hidden_features).
            All non-final dimensions are flattened into a single batch dimension.
        gp_logits: Optional[Array]
            The nd-array of predictive logits, shape (batch_size, ..., logit_dim).
            May be `None`; when given it is only cast and reshaped before being
            forwarded to `update_precision_matrix`.
        diagonal_only: bool
            Whether to return only the diagonal elements of the predictive covariance matrix (i.e., the predictive variance).

        Returns
        -------
        Optional[Array]
            `None` when `self.collection_name` is mutable. Otherwise, the
            predictive variances of shape (batch_size, ) if diagonal_only=True,
            or the predictive covariance matrix of shape
            (batch_size, batch_size), where `batch_size` counts all the
            flattened leading dimensions.
        """
        gp_features = jnp.asarray(gp_features, self.dtype)

        # Flatten GP features and logits to 2-d, by doing so we treat all the
        # non-final dimensions as the batch dimensions.
        gp_features = jnp.reshape(gp_features, [-1, self.hidden_features])

        if gp_logits is not None:
            gp_logits = jnp.asarray(gp_logits, self.dtype)
            gp_logits = jnp.reshape(gp_logits, [gp_features.shape[0], -1])

        precision_matrix = self.variable(
            self.collection_name,
            "precision_matrix",
            lambda: self.initial_precision_matrix(),
        )  # pylint: disable=unnecessary-lambda

        # Updates the precision matrix during training.
        initializing = self.is_mutable_collection("params")
        training = self.is_mutable_collection(self.collection_name)

        if training and not initializing:
            precision_matrix.value = self.update_precision_matrix(
                gp_features, gp_logits, precision_matrix.value
            )

        # Computes covariance matrix during inference.
        if not training:
            return self.compute_predictive_covariance(
                gp_features, precision_matrix, diagonal_only
            )

        # No covariance is computed while the collection is mutable.
        return None

    def initial_precision_matrix(self):
        """Returns the initial diagonal precision matrix, i.e., `ridge_penalty` times the identity."""
        return jnp.eye(self.hidden_features, dtype=self.dtype) * self.ridge_penalty

    def update_precision_matrix(
        self, gp_features: Array, gp_logits: Optional[Array], precision_matrix: Array
    ) -> Array:
        """Updates precision matrix given a new batch.

        Parameters
        ----------
        gp_features: Array
            Random features from the new batch, shape (batch_size, hidden_features)
        gp_logits: Optional[Array]
            Predictive logits from the new batch, shape (batch_size, logit_dim).
            Not used by the current implementation, which applies a constant
            multiplier of 1 to the features.
        precision_matrix: Array
            The current precision matrix, shape (hidden_features, hidden_features).

        Returns
        -------
        Array
            Updated precision matrix, shape (hidden_features, hidden_features).
        """
        # Computes precision matrix within new batch.
        prob_multiplier = 1.0
        gp_features_adj = jnp.sqrt(prob_multiplier) * gp_features
        batch_prec_mat = jnp.matmul(jnp.transpose(gp_features_adj), gp_features_adj)

        # Updates precision matrix.
        if self.momentum is None:
            # Performs exact update without momentum.
            precision_matrix_updated = precision_matrix + batch_prec_mat
        else:
            # Moving average of the batch-size-normalized batch precision.
            batch_size = gp_features.shape[0]
            precision_matrix_updated = (
                self.momentum * precision_matrix
                + (1 - self.momentum) * batch_prec_mat / batch_size
            )
        return precision_matrix_updated

    def compute_predictive_covariance(
        self, gp_features: Array, precision_matrix: nn.Variable, diagonal_only: bool
    ) -> Array:
        r"""
        Computes the predictive covariance.

        Approximates the Gaussian process posterior using random features.
        Given training random feature :math:`\mathbf{\Phi_{tr}}` (num_train, num_hidden) and testing
        random feature :math:`\mathbf{\Phi_{ts}}` (batch_size, num_hidden). The predictive covariance
        matrix is computed as (assuming Gaussian likelihood):

        :math:`s * \mathbf{\Phi_{ts}}(\mathbf{I}*s + \mathbf{\Phi_{tr}}^{T}*\mathbf{\Phi_{tr}})^{-1}\mathbf{\Phi_{ts}}^{T}`

        where :math:`s` is the ridge factor to be used for stabilizing the inverse, and :math:`\mathbf{I}` is
        the identity matrix with shape (num_hidden, num_hidden). The above
        description is formal only: the actual implementation uses a Cholesky
        factorization of the precision matrix followed by a triangular solve.

        Parameters
        ----------
        gp_features: Array
            The random feature of testing data to be used for computing the covariance matrix.
            Shape (batch_size, gp_hidden_size).
        precision_matrix: nn.Variable
            The model's precision matrix.
        diagonal_only: bool
            Whether to return only the diagonal elements of the predictive covariance matrix
            (i.e., the predictive variances).

        Returns
        -------
        Array
            The predictive variances of shape (batch_size, ) if `diagonal_only=True`,
            otherwise the predictive covariance matrix of shape (batch_size, batch_size).
        """
        # With precision = L @ L.T, solve L @ Z = gp_features.T for Z,
        # shape (gp_hidden_size, batch_size).
        chol = linalg.cholesky(precision_matrix.value)
        chol_t_cov_feature_product = linalg.triangular_solve(
            chol, gp_features.T, left_side=True, lower=True
        )

        if diagonal_only:
            # Compute diagonal element only, shape (batch_size, ).
            gp_covar = jnp.square(chol_t_cov_feature_product).sum(0)
        else:
            # Compute full covariance matrix, shape (batch_size, batch_size).
            gp_covar = chol_t_cov_feature_product.T @ chol_t_cov_feature_product

        return self.ridge_penalty * gp_covar