Source code for fortuna.prob_output_layer.classification

from typing import Optional

import jax
from jax import (
    random,
    vmap,
)
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
import jax.scipy as jsp

from fortuna.prob_output_layer.base import ProbOutputLayer
from fortuna.typing import Array


class ClassificationProbOutputLayer(ProbOutputLayer):
    def __init__(self):
        r"""
        Classification probabilistic output layer class. It characterizes the
        probability distribution of a target variable given calibrated output
        logits as a Categorical distribution, that is
        :math:`p(y|\omega)=\text{Categorical}(y|p=\text{softmax}(\omega))`,
        where :math:`y` denotes a target variable and :math:`\omega` a
        calibrated output.
        """
        super().__init__()
    def log_prob(self, outputs: Array, targets: Array, **kwargs) -> Array:
        # Log-probability of the target labels under Categorical(softmax(outputs)).
        n_cats = outputs.shape[-1]
        targets = jax.nn.one_hot(targets, n_cats)
        return jnp.sum(targets * outputs, -1) - jsp.special.logsumexp(outputs, -1)
    def predict(self, outputs: Array, **kwargs) -> jnp.ndarray:
        return jnp.argmax(outputs, -1)
    def sample(
        self,
        n_target_samples: int,
        outputs: Array,
        rng: Optional[PRNGKeyArray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        # Draw `n_target_samples` class labels per input from the Categorical
        # distribution induced by the outputs (or by the probabilities passed
        # via `kwargs["probs"]`, if available).
        probs = (
            jax.nn.softmax(outputs, -1)
            if "probs" not in kwargs or kwargs["probs"] is None
            else kwargs["probs"]
        )
        n_cats = probs.shape[-1]
        if rng is None:
            rng = self.rng.get()
        keys = random.split(rng, probs.shape[0])
        return vmap(
            lambda key, p: random.choice(key, n_cats, p=p, shape=(n_target_samples,)),
            out_axes=1,
        )(keys, probs)
    def mean(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the mean of the one-hot encoded target variable given the
        output with respect to the probabilistic output layer distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated mean for each output.
        """
        return jax.nn.softmax(outputs, -1)
    def mode(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the mode of the one-hot encoded target variable given the
        output with respect to the probabilistic output layer distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated mode for each output.
        """
        return jnp.argmax(outputs, -1)
    def variance(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the variance of the one-hot encoded target variable given the
        output with respect to the probabilistic output layer distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated variance for each output.
        """
        p = self.mean(outputs)
        return p * (1 - p)
    def std(self, outputs: Array, variances: Optional[Array] = None) -> jnp.ndarray:
        """
        Estimate the standard deviation of the one-hot encoded target variable
        given the output with respect to the probabilistic output layer
        distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.
        variances : Optional[Array]
            Variance for each output.

        Returns
        -------
        jnp.ndarray
            The estimated standard deviation for each output.
        """
        return super().std(outputs, variances)
    def entropy(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the entropy of the one-hot encoded target variable given the
        output with respect to the probabilistic output layer distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated entropy for each output.
        """
        n_classes = outputs.shape[-1]

        @vmap
        def _entropy_term(i: int):
            targets = i * jnp.ones(outputs.shape[0])
            log_liks = self.log_prob(outputs, targets, **kwargs)
            return jnp.exp(log_liks) * log_liks

        return -jnp.sum(_entropy_term(jnp.arange(n_classes)), 0)
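A minimal usage sketch (not part of the library source), assuming ClassificationProbOutputLayer can be instantiated with no arguments as defined above and that a JAX PRNG key is passed explicitly to sample:

import jax.numpy as jnp
from jax import random

layer = ClassificationProbOutputLayer()
outputs = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])  # logits for 2 inputs, 3 classes
targets = jnp.array([0, 2])

log_probs = layer.log_prob(outputs, targets)        # per-input categorical log-probability
preds = layer.predict(outputs)                      # argmax class per input
class_probs = layer.mean(outputs)                   # softmax class probabilities
samples = layer.sample(5, outputs, rng=random.PRNGKey(0))  # shape (5, 2): samples x inputs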
class ClassificationMaskedProbOutputLayer(ClassificationProbOutputLayer):
    def log_prob(self, outputs: Array, targets: Array, **kwargs) -> Array:
        # Same categorical log-probability as the parent class, but entries with
        # non-positive target labels are treated as masked and contribute zero.
        n_cats = outputs.shape[-1]
        targets_mask = jnp.where(targets > 0, 1.0, 0.0)
        targets = jax.nn.one_hot(targets, n_cats)
        return (
            jnp.sum(targets * outputs, -1) - jsp.special.logsumexp(outputs, -1)
        ) * targets_mask
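A short sketch of the masked variant (again not part of the library source), illustrating the convention visible in the code above that entries with non-positive target labels are zeroed out of the log-probability:

import jax.numpy as jnp

masked_layer = ClassificationMaskedProbOutputLayer()
outputs = jnp.array([[1.0, 0.0, -1.0], [0.3, 0.2, 0.1]])
targets = jnp.array([2, 0])  # second entry has label 0, so it is treated as masked

log_probs = masked_layer.log_prob(outputs, targets)
# log_probs[0] is the usual categorical log-probability; log_probs[1] == 0.0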