# Source code for fortuna.output_calibrator.classification
import flax.linen as nn
import jax.numpy as jnp
from fortuna.typing import Array
class ClassificationTemperatureScaler(nn.Module):
    r"""
    Classification temperature scaling. It scales the logits with a scalar temperature parameter. Let :math:`o` be
    the output logits and :math:`\phi` be a scalar parameter. Then the scaling can be seen as
    :math:`g(\phi, o) = \exp(-\phi) o`.
    """

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> jnp.ndarray:
        r"""
        Scale the logits by the inverse of a learned temperature.

        Parameters
        ----------
        x : Array
            Output logits to be calibrated.
        **kwargs
            Ignored.

        Returns
        -------
        jnp.ndarray
            ``x`` multiplied by :math:`\exp(-\phi)`, where :math:`\phi` is the learned ``log_temp`` parameter.
        """
        # The temperature exp(log_temp) is parameterized in log-space, so it is always positive.
        # Zero initialization gives a starting temperature of exp(0) = 1, i.e. the scaling starts as the identity.
        log_temp = self.param("log_temp", nn.initializers.zeros, (1,))
        return x * jnp.exp(-log_temp)