Source code for fortuna.ood_detection.mahalanobis

import logging
from typing import Tuple

import jax
import jax.numpy as jnp

from fortuna.ood_detection.base import (
    NotFittedError,
    OutOfDistributionClassifierABC,
)
from fortuna.typing import Array


@jax.jit
def compute_mean_and_joint_cov(
    embeddings: jnp.ndarray, labels: jnp.ndarray, class_ids: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Computes class-specific means and a shared covariance matrix given the training set embeddings
    (e.g., the last-layer representation of the model for each training example).

    Parameters
    ----------
    embeddings: jnp.ndarray
        An array of shape `(n, d)`, where `n` is the sample size of training set,
        `d` is the dimension of the embeddings.
    labels: jnp.ndarray
        An array of shape `(n,)`
    class_ids: jnp.ndarray
        An array of the unique class ids in `labels`.

    Returns
    -------
    Tuple[jnp.ndarray, jnp.ndarray]
        A tuple containing:
        1) A `jnp.ndarray` of len n_class, and the i-th element is an np.array of size
      ` (d,)` corresponding to the mean of the fitted Gaussian distribution for the i-th class;
        2) The shared covariance matrix of shape `(d, d)`.
    """
    n_dim = embeddings.shape[1]
    cov = jnp.zeros((n_dim, n_dim))

    def f(cov, class_id):
        mask = jnp.expand_dims(labels == class_id, axis=-1)
        data = embeddings * mask
        mean = jnp.sum(data, axis=0) / jnp.sum(mask)
        diff = (data - mean) * mask
        cov += jnp.matmul(diff.T, diff)
        return cov, mean

    cov, means = jax.lax.scan(f, cov, class_ids)
    cov = cov / len(labels)
    return means, cov


@jax.jit
def compute_mahalanobis_distance(
    embeddings: jnp.ndarray, means: jnp.ndarray, cov: jnp.ndarray
) -> jnp.ndarray:
    """
    Computes Mahalanobis distance between the input and the fitted Guassians.

    Parameters
    ----------
    embeddings: jnp.ndarray
        A matrix of shape `(n, d)`, where `n` is the sample size of the test set, and `d` is the size of the embeddings.
    means: jnp.ndarray
        A matrix of shape `(c, d)`, where `c` is the number of classes in the classification task.
        The ith row of the matrix corresponds to the mean of the fitted Gaussian distribution for the i-th class.
    cov: jnp.ndarray
        The shared covariance mmatrix of the shape `(d, d)`.

    Returns
    -------
    A matrix of size `(n, c)` where the `(i, j)` element
    corresponds to the Mahalanobis distance between i-th sample to the j-th
    class Gaussian.
    """
    # NOTE: It's possible for `cov` to be singular, in part because it is
    # estimated on a sample of data. This can be exacerbated by lower precision,
    # where, for example, the matrix could be non-singular in float64, but
    # singular in float32. For our purposes in computing Mahalanobis distance,
    # using a  pseudoinverse is a reasonable approach that will be equivalent to
    # the inverse if `cov` is non-singular.
    cov_inv = jnp.linalg.pinv(cov)

    def maha_dist(x, mean):
        # NOTE: This computes the squared Mahalanobis distance.
        diff = x - mean
        return jnp.einsum("i,ij,j->", diff, cov_inv, diff)

    maha_dist_all_classes_fn = jax.vmap(maha_dist, in_axes=(None, 0))
    out = jax.lax.map(lambda x: maha_dist_all_classes_fn(x, means), embeddings)
    return out


[docs]class MalahanobisOODClassifier(OutOfDistributionClassifierABC): """ The pre-trained features of a softmax neural classifier :math:`f(\mathbf{x})` are assumed to follow a class-conditional gaussian distribution with a tied covariance matrix :math:`\mathbf{\Sigma}`: .. math:: \mathbb{P}(f(\mathbf{x})|y=k) = \mathcal{N}(f(\mathbf{x})|\mu_k, \mathbf{\Sigma}) for all :math:`k \in {1,...,K}`, where :math:`K` is the number of classes. The confidence score :math:`M(\mathbf{x})` for a new test sample :math:`\mathbf{x}` is obtained computing the max (squared) Mahalanobis distance between :math:`f(\mathbf{x})` and the fitted class-wise guassians. See `Lee, Kimin, et al. <https://proceedings.neurips.cc/paper/2018/file/abdeb6f575ac5c6676b747bca8d09cc2-Paper.pdf>`_ """ def __init__(self, *args, **kwargs): super(MalahanobisOODClassifier, self).__init__(*args, **kwargs) self._mean = None self._cov = None @property def mean(self): """ Returns ------- Array A matrix of shape `(num_classes, d)`, where `num_classes` is the number of classes in the in-distribution classification task. The ith row of the matrix corresponds to the mean of the fitted Gaussian distribution for the i-th class. """ return self._mean @property def cov(self): """ Returns ------- Array The shared covariance matrix with shape `(d, d)`, where `d` is the embedding size. """ return self._cov
[docs] def fit(self, embeddings: Array, targets: Array) -> None: """ Fits a Multivariate Gaussian to the training data using class-specific means and a shared covariance matrix. Parameters ---------- embeddings: Array The embeddings of shape `(n, d)` where `n` is the number of training samples and `d` is the embbeding's size. targets: Array An array of length `n` containing, for each input sample, its ground-truth label. """ n_labels_observed = len(jnp.unique(targets)) if n_labels_observed != self.num_classes: logging.warning( f"{self.num_classes} labels were expected but found {n_labels_observed} in the provided train set. " f"Will proceed but performance may be hurt by this." ) self._mean, self._cov = compute_mean_and_joint_cov( embeddings, targets, jnp.arange(self.num_classes) )
[docs] def score(self, embeddings: Array) -> Array: """ The confidence score :math:`M(\mathbf{x})` for a new test sample :math:`\mathbf{x}` is obtained computing the max (squared) Mahalanobis distance between :math:`f(\mathbf{x})` and the fitted class-wise Guassians. A high score signals that the test sample :math:`\mathbf{x}` is identified as OOD. Parameters ---------- embeddings: Array The embeddings of shape `(n, d)` where `n` is the number of test samples and `d` is the embbeding's size. Returns ------- Array An array of scores with length `n`. """ if self._mean is None or self._cov is None: raise NotFittedError("You have to call fit before calling score.") return compute_mahalanobis_distance(embeddings, self.mean, self.cov).min(axis=1)