Source code for fortuna.likelihood.classification

from typing import Optional

import jax
from jax import vmap
import jax.numpy as jnp

from fortuna.data.loader import InputsLoader
from fortuna.likelihood.base import Likelihood
from fortuna.model.model_manager.classification import ClassificationModelManager
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.prob_output_layer.classification import ClassificationProbOutputLayer
from fortuna.typing import (
    Array,
    CalibMutable,
    CalibParams,
    Mutable,
    Params,
)


class ClassificationLikelihood(Likelihood):
    def __init__(
        self,
        model_manager: ClassificationModelManager,
        prob_output_layer: ClassificationProbOutputLayer,
        output_calib_manager: Optional[OutputCalibManager] = None,
    ):
        """
        A classification likelihood function class. In this class, the likelihood function is
        additionally assumed to be a probability density function, i.e. positive and integrating to 1.
        The likelihood is formed by three objects applied in sequence: the model manager, the output
        calibrator and the probabilistic output layer. The model manager maps parameters and inputs
        to outputs. The output calibrator takes outputs and returns a calibrated version of them. The
        probabilistic output layer describes the probability distribution of the calibrated outputs.

        Parameters
        ----------
        model_manager : ClassificationModelManager
            A model manager. This object orchestrates the evaluation of the models.
        prob_output_layer : ClassificationProbOutputLayer
            A probabilistic output layer object. This object characterizes the probability
            distribution of the target variable given the calibrated outputs.
        output_calib_manager : Optional[OutputCalibManager]
            An output calibration manager object. It transforms outputs of the model manager into
            some calibrated version of them.
        """
        super().__init__(
            model_manager, prob_output_layer, output_calib_manager=output_calib_manager
        )
    def mean(
        self,
        params: Params,
        inputs_loader: InputsLoader,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        distribute: bool = True,
        **kwargs,
    ) -> jnp.ndarray:
        r"""
        Estimate the likelihood mean of the one-hot encoded target variable, that is

        .. math::
            \mathbb{E}_{\tilde{Y}|w, x}[\tilde{Y}],

        where:
         - :math:`x` is an observed input variable;
         - :math:`\tilde{Y}` is a one-hot encoded random target variable;
         - :math:`w` denotes the observed model parameters.

        Parameters
        ----------
        params : Params
            The random parameters of the probabilistic model.
        inputs_loader : InputsLoader
            A loader of input data points.
        mutable : Optional[Mutable]
            The mutable objects used to evaluate the models.
        calib_params : Optional[CalibParams]
            The calibration parameters of the probabilistic model.
        calib_mutable : Optional[CalibMutable]
            The calibration mutable objects used to evaluate the calibrators.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.

        Returns
        -------
        jnp.ndarray
            An estimate of the likelihood mean for each input.
        """
        return super().mean(
            params,
            inputs_loader,
            mutable,
            calib_params,
            calib_mutable,
            distribute,
            **kwargs,
        )
    def _batched_mean(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        **kwargs,
    ) -> jnp.ndarray:
        outputs = self._get_batched_calibrated_outputs(
            params, inputs, mutable, calib_params, calib_mutable, **kwargs
        )
        return jax.nn.softmax(outputs, -1)

    def _batched_mode(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        **kwargs,
    ) -> jnp.ndarray:
        outputs = self._get_batched_calibrated_outputs(
            params, inputs, mutable, calib_params, calib_mutable, **kwargs
        )
        return jnp.argmax(outputs, -1)
    def variance(
        self,
        params: Params,
        inputs_loader: InputsLoader,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        distribute: bool = True,
        **kwargs,
    ) -> jnp.ndarray:
        r"""
        Estimate the likelihood variance of the one-hot encoded target variable, that is

        .. math::
            \text{Var}_{\tilde{Y}|w, x}[\tilde{Y}],

        where:
         - :math:`x` is an observed input variable;
         - :math:`\tilde{Y}` is a one-hot encoded random target variable;
         - :math:`w` denotes the observed model parameters.

        Parameters
        ----------
        params : Params
            The random parameters of the probabilistic model.
        inputs_loader : InputsLoader
            A loader of input data points.
        mutable : Optional[Mutable]
            The mutable objects used to evaluate the models.
        calib_params : Optional[CalibParams]
            The calibration parameters of the probabilistic model.
        calib_mutable : Optional[CalibMutable]
            The calibration mutable objects used to evaluate the calibrators.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.

        Returns
        -------
        jnp.ndarray
            An estimate of the likelihood variance for each input.
        """
        return super().variance(
            params,
            inputs_loader,
            mutable,
            calib_params,
            calib_mutable,
            distribute,
            **kwargs,
        )
    def _batched_variance(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        distribute: bool = True,
        **kwargs,
    ) -> jnp.ndarray:
        means = self._batched_mean(
            params=params,
            inputs=inputs,
            mutable=mutable,
            calib_params=calib_params,
            calib_mutable=calib_mutable,
            **kwargs,
        )
        return means * (1 - means)
    def entropy(
        self,
        params: Params,
        inputs_loader: InputsLoader,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        distribute: bool = True,
        **kwargs,
    ) -> jnp.ndarray:
        outputs = super().get_calibrated_outputs(
            params, inputs_loader, mutable, calib_params, calib_mutable, distribute
        )
        n_classes = outputs.shape[-1]

        # Accumulate the entropy -sum_c p_c * log p_c one class at a time:
        # each vmapped term evaluates the calibrated log-probability of class i
        # for every input in the loader.
        @vmap
        def _entropy_term(i: int):
            targets = i * jnp.ones(outputs.shape[0])
            log_liks = self.prob_output_layer.log_prob(outputs, targets, **kwargs)
            return jnp.exp(log_liks) * log_liks

        return -jnp.sum(_entropy_term(jnp.arange(n_classes)), 0)
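The three-stage composition described in the class docstring (model manager, then output calibrator, then probabilistic output layer) can be sketched in plain JAX. Everything below is an illustrative stand-in rather than the library's API: the hypothetical apply_model plays the role of the model manager, temperature scaling stands in for the output calibrator, and the softmax step further down plays the role of the probabilistic output layer.

import jax.numpy as jnp
from jax import nn

def apply_model(params, inputs):
    # Stand-in for the model manager: maps parameters and inputs to outputs (logits).
    return inputs @ params["w"] + params["b"]

params = {"w": jnp.ones((4, 3)), "b": jnp.zeros(3)}
inputs = jnp.array([[0.1, -0.2, 0.3, 0.4]])
outputs = apply_model(params, inputs)    # shape (1, 3): one row of logits per input

temperature = 1.5                        # stand-in output calibrator: temperature scaling
calibrated = outputs / temperature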
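From the calibrated outputs, the quantities estimated by mean and variance above reduce, per input, to the softmax probabilities and their per-class Bernoulli variance, exactly as in _batched_mean and _batched_variance. Continuing the sketch:

mean = nn.softmax(calibrated, -1)        # E[Y~ | w, x]: one probability per class
variance = mean * (1 - mean)             # Var[Y~ | w, x] = p * (1 - p) per class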
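Likewise, the quantity computed by entropy is the reduction -sum_c p_c log p_c over classes; written directly from the log-probabilities it becomes:

log_probs = nn.log_softmax(calibrated, -1)
entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, -1)   # -sum_c p_c log p_c per input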