# Source code for fortuna.output_calibrator.regression
import flax.linen as nn
import jax.numpy as jnp
from fortuna.typing import Array
class RegressionTemperatureScaler(nn.Module):
    r"""
    Regression temperature scaling. It multiplies the variance by a scalar temperature parameter. Let :math:`v` be
    the variance outputs and :math:`\phi` be a scalar parameter. Then the scaling can be seen as
    :math:`g(\phi, v) = \exp(\phi) v`.

    The last axis of the input is split into two equal halves. The first half is treated as the mean and is
    returned unchanged. The second half is treated as the log-variance. Because the variance is represented in
    log-space, multiplying it by :math:`\exp(\phi)` amounts to adding :math:`\phi` to the log-variance.

    The parameter :math:`\phi` is stored as ``log_temp`` and initialized to zero, so the calibrator starts out
    as the identity map.
    """

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> jnp.ndarray:
        """
        Apply temperature scaling to the log-variance half of ``x``.

        Parameters
        ----------
        x : Array
            Model outputs whose last axis is the concatenation of the mean and the log-variance. The size of the
            last axis must be even; otherwise ``jnp.split`` raises an error.
        kwargs
            Not used by this method. Presumably accepted so that callers can pass the same keyword arguments to
            every output calibrator (confirm against callers).

        Returns
        -------
        jnp.ndarray
            Array with the same shape as ``x``: the unchanged mean concatenated with the shifted log-variance
            along the last axis.
        """
        # Scalar log-temperature; the zero initialization means exp(0) = 1, i.e. no rescaling at start.
        log_temp = self.param("log_temp", nn.initializers.zeros, (1,))
        mean, log_var = jnp.split(x, 2, axis=-1)
        # log(exp(phi) * v) = phi + log(v): scaling the variance is a shift of the log-variance.
        # The shape-(1,) parameter broadcasts over every log-variance entry.
        scaled_log_var = log_var + log_temp
        return jnp.concatenate((mean, scaled_log_var), axis=-1)