"""
Adapted from https://github.com/omegafragger/DDU/blob/main/utils/gmm_utils.py
"""
import logging
from typing import (
Callable,
Tuple,
)
import jax
from jax import numpy as jnp
import jax.scipy as jsp
import jax.scipy.stats as jsp_stats
import numpy as np
from fortuna.ood_detection.base import (
NotFittedError,
OutOfDistributionClassifierABC,
)
from fortuna.typing import Array
# Floating-point limits of double precision; only ``tiny`` is used below.
DOUBLE_INFO = np.finfo(np.double)
# Diagonal jitter magnitudes tried in increasing order when the per-class
# covariances produce NaN log-densities: no jitter, the smallest positive
# normal double, then 1e-10, 1e-9, ..., 1e-1.
JITTERS = [0, DOUBLE_INFO.tiny, *(10**exponent for exponent in range(-10, 0))]
def _centered_cov(x: Array) -> Array:
    """
    Compute the sample covariance of already-centered data.

    Parameters
    ----------
    x: Array
        A two-dimensional array whose first axis indexes the samples. The caller is expected to have
        subtracted the mean beforehand; no centering is performed here.

    Returns
    -------
    Array
        The matrix `x.T @ x / (n - 1)`, where `n = x.shape[0]`.

    Raises
    ------
    ZeroDivisionError
        If `x` contains exactly one row, since the normalizer `n - 1` is then zero.
    """
    num_samples = x.shape[0]
    # Scale the transposed operand first so the arithmetic order matches `(1 / (n - 1) * x.T) @ x`.
    scale = 1 / (num_samples - 1)
    return jnp.matmul(scale * x.T, x)
def compute_classwise_mean_and_cov(
    embeddings: Array, labels: Array, num_classes: int
) -> Tuple[Array, Array]:
    """
    Compute class-specific means and covariance matrices from the training set embeddings
    (e.g., the last-layer representation of the model for each training example).

    Parameters
    ----------
    embeddings: Array
        The embeddings of shape `(n, d)` where `n` is the number of training samples and `d` is the
        embedding size.
    labels: Array
        An array of length `n` containing, for each input sample, its ground-truth label.
    num_classes: int
        The total number of classes available in the classification task. Labels are matched against
        the integers `0, ..., num_classes - 1`.

    Returns
    -------
    Tuple[Array, Array]
        A tuple containing:

        1) an `Array` containing the per-class mean vector of the fitted GMM.
           The shape of the array is `(num_classes, d)`.
        2) an `Array` containing the per-class covariance matrix of the fitted GMM.
           The shape of the array is `(num_classes, d, d)`.
    """
    class_ids = range(num_classes)

    # Per-class mean: average the embeddings whose label equals the class index.
    per_class_means = []
    for class_id in class_ids:
        members = embeddings[labels == class_id]
        per_class_means.append(jnp.mean(members, 0))
    classwise_mean_features = np.stack(per_class_means)

    # Per-class covariance: center each class on its own mean, then take the sample covariance.
    per_class_covs = []
    for class_id in class_ids:
        centered = embeddings[labels == class_id] - classwise_mean_features[class_id]
        per_class_covs.append(_centered_cov(centered))
    classwise_cov_features = np.stack(per_class_covs)

    return classwise_mean_features, classwise_cov_features
def _get_logpdf_fn(
    classwise_mean_features: Array, classwise_cov_features: Array
) -> Callable[[Array], Array]:
    """
    Return a function to evaluate the log-likelihood of a test sample according to the (fitted) GMM.

    The jitter values in `JITTERS` are tried in order. Each one is added to the diagonal of every
    class covariance, and the resulting log-density is evaluated at the class means. The first
    jitter for which that evaluation contains no NaN is kept. If every jitter still yields NaNs, a
    warning is logged and the function built with the last jitter is returned.

    Parameters
    ----------
    classwise_mean_features: Array
        The per-class mean vector of the fitted GMM. The shape of the array is `(num_classes, d)`.
    classwise_cov_features: Array
        The per-class covariance matrix of the fitted GMM. The shape of the array is `(num_classes, d, d)`.

    Returns
    -------
    Callable[[Array], Array]
        A function to evaluate the log-likelihood of a test sample according to the (fitted) GMM.
        Given embeddings of shape `(n, d)` it returns an array of shape `(n, num_classes)`.
    """
    # Map the Gaussian log-density over the class axis of the means and covariances while sharing
    # the evaluation points. Loop-invariant, so it is built once.
    gmm_logprob_fn_vmapped = jax.vmap(
        jsp_stats.multivariate_normal.logpdf, in_axes=(None, 0, 0)
    )
    identity = np.eye(classwise_cov_features.shape[1])

    def _make_logprob_fn(jitter_eps) -> Callable[[Array], Array]:
        # Bind this jitter explicitly so the returned closure never depends on loop state.
        jitter = np.expand_dims(jitter_eps * identity, 0)

        def gmm_logprob_fn(x: Array) -> Array:
            # The vmapped output is (num_classes, n); transpose to (n, num_classes).
            return gmm_logprob_fn_vmapped(
                x, classwise_mean_features, (classwise_cov_features + jitter)
            ).T

        return gmm_logprob_fn

    for jitter_eps in JITTERS:
        gmm_logprob_fn = _make_logprob_fn(jitter_eps)
        nans = np.isnan(gmm_logprob_fn(classwise_mean_features)).sum()
        if nans > 0:
            logging.info(f"Nans, jittering {jitter_eps}")
            continue
        break
    else:
        # Every jitter was tried and NaNs remain: keep the best-effort behaviour of returning the
        # last candidate, but do not stay silent about it.
        logging.warning(
            "Log-density still contains NaNs after trying all jitter values; "
            "the returned function may produce NaNs."
        )
    return gmm_logprob_fn
class DeepDeterministicUncertaintyOODClassifier(OutOfDistributionClassifierABC):
    r"""
    Out-of-distribution classifier based on Deep Deterministic Uncertainty (DDU).

    A Gaussian Mixture Model :math:`q(\mathbf{x}, z)` with a single Gaussian mixture component per class
    :math:`k \in \{1,...,K\}` is fit after training.
    Each class component is fit computing the empirical mean :math:`\mathbf{\hat{\mu}_k}` and covariance matrix
    :math:`\mathbf{\hat{\Sigma}_k}` of the feature vectors :math:`f(\mathbf{x})`.

    The confidence score :math:`M(\mathbf{x})` for a new test sample is the negative log of the summed
    per-class densities of its feature representation.

    See `Mukhoti, Jishnu, et al. <https://arxiv.org/abs/2102.11582>`_
    """

    def __init__(self, *args, **kwargs):
        """
        Create an unfitted classifier.

        All positional and keyword arguments are forwarded unchanged to the parent class constructor.
        """
        super().__init__(*args, **kwargs)
        # Populated by `fit`; `None` means the classifier has not been fitted yet.
        self._classwise_mean_features = None
        self._classwise_cov_features = None
        self._gmm_logpdf_fn = None

    @property
    def mean(self) -> Array:
        """
        Returns
        -------
        Array
            The per-class mean vector of the fitted GMM. The shape of the array is `(num_classes, d)`
            where `num_classes` is the number of target classes in the in-distribution classification task
            and `d` is the embedding size. It is `None` until `fit` has been called.
        """
        return self._classwise_mean_features

    @property
    def cov(self) -> Array:
        """
        Returns
        -------
        Array
            The per-class covariance matrix of the fitted GMM. The shape of the array is `(num_classes, d, d)`
            where `num_classes` is the number of target classes in the in-distribution classification task
            and `d` is the embedding size. It is `None` until `fit` has been called.
        """
        return self._classwise_cov_features

    def fit(self, embeddings: Array, targets: Array) -> None:
        """
        Fit one multivariate Gaussian per class to the training data, using class-specific means and
        covariance matrices, and build the function used by `score` to evaluate log-densities.

        Parameters
        ----------
        embeddings: Array
            The embeddings of shape `(n, d)` where `n` is the number of training samples and `d` is the
            embedding size.
        targets: Array
            An array of length `n` containing, for each input sample, its ground-truth label.
        """
        # NOTE(review): `self.num_classes` is not set in this class; presumably it is provided by the
        # parent class -- confirm against `OutOfDistributionClassifierABC`.
        (
            self._classwise_mean_features,
            self._classwise_cov_features,
        ) = compute_classwise_mean_and_cov(embeddings, targets, self.num_classes)
        self._gmm_logpdf_fn = _get_logpdf_fn(
            self._classwise_mean_features, self._classwise_cov_features
        )

    def score(self, embeddings: Array) -> Array:
        r"""
        Compute the confidence score :math:`M(\mathbf{x})` for each test sample :math:`\mathbf{x}`.

        The score is the negative log of the summed per-class densities of the feature representation,
        :math:`-\log \sum\limits_{k} q(f(\mathbf{x})|y=k)`; no class-prior weights are applied.
        NaN log-densities are replaced by `0.0` before the sum.

        A high score signals that the test sample :math:`\mathbf{x}` is identified as OOD.

        Parameters
        ----------
        embeddings: Array
            The embeddings of shape `(n, d)` where `n` is the number of test samples and `d` is the
            embedding size.

        Returns
        -------
        Array
            An array of scores with length `n`.

        Raises
        ------
        NotFittedError
            If `fit` has not been called before `score`.
        """
        if self._gmm_logpdf_fn is None:
            raise NotFittedError("You have to call fit before calling score.")
        loglik = self._gmm_logpdf_fn(embeddings)
        # Pass the replacement value by keyword: the second positional parameter of `nan_to_num`
        # is `copy`, not `nan`.
        return -jsp.special.logsumexp(jnp.nan_to_num(loglik, nan=0.0), axis=1)