# Source code for fortuna.output_calib_model.classification
from typing import (
Callable,
Optional,
)
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from fortuna.loss.classification.focal_loss import focal_loss_fn
from fortuna.output_calib_model.base import OutputCalibModel
from fortuna.output_calib_model.config.base import Config
from fortuna.output_calib_model.predictive.classification import (
ClassificationPredictive,
)
from fortuna.output_calibrator.classification import ClassificationTemperatureScaler
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.prob_output_layer.classification import ClassificationProbOutputLayer
from fortuna.typing import (
Array,
Outputs,
Status,
Targets,
)
class OutputCalibClassifier(OutputCalibModel):
def __init__(
self,
output_calibrator: Optional[nn.Module] = ClassificationTemperatureScaler(),
seed: int = 0,
) -> None:
r"""
A calibration classifier class.
Parameters
----------
output_calibrator : Optional[nn.Module]
An output calibrator object. The default is temperature scaling for classification, which rescales the
logits with a scalar temperature parameter. Given outputs :math:`o`,
the output calibrator is described by a function :math:`g(\phi, o)`,
            where :math:`\phi` are the calibration parameters.
        seed : int
A random seed.
Attributes
----------
output_calibrator : nn.Module
See `output_calibrator` in `Parameters`.
output_calib_manager : OutputCalibManager
It manages the forward pass of the output calibrator.
prob_output_layer : ClassificationProbOutputLayer
            A probabilistic output layer.
It characterizes the distribution of the target variables given the outputs.
predictive : ClassificationPredictive
The predictive distribution.
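        Examples
        --------
        The default output calibrator is temperature scaling, which conventionally
        rescales logits by a learned scalar temperature: a larger temperature
        yields softer, less confident probabilities. A minimal sketch with
        placeholder logits (illustrative only, not part of the library):

        .. code-block:: python

            import numpy as np

            def softmax(z):
                e = np.exp(z - z.max(axis=-1, keepdims=True))
                return e / e.sum(axis=-1, keepdims=True)

            logits = np.array([[2.0, 0.5, -1.0]])  # placeholder logits
            softmax(logits)        # sharper probabilities, temperature 1
            softmax(logits / 2.0)  # softer probabilities, temperature 2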
"""
self.output_calibrator = output_calibrator
self.output_calib_manager = OutputCalibManager(
output_calibrator=output_calibrator
)
self.prob_output_layer = ClassificationProbOutputLayer()
self.predictive = ClassificationPredictive(
output_calib_manager=self.output_calib_manager,
prob_output_layer=self.prob_output_layer,
)
super().__init__(seed=seed)
def calibrate(
self,
calib_outputs: Array,
calib_targets: Array,
val_outputs: Optional[Array] = None,
val_targets: Optional[Array] = None,
loss_fn: Callable[[Outputs, Targets], jnp.ndarray] = focal_loss_fn,
config: Config = Config(),
) -> Status:
"""
Calibrate the model outputs.
Parameters
----------
        calib_outputs : Array
            Calibration model outputs.
        calib_targets : Array
            Calibration target variables.
        val_outputs : Optional[Array]
            Validation model outputs.
        val_targets : Optional[Array]
            Validation target variables.
        loss_fn : Callable[[Outputs, Targets], jnp.ndarray]
            The loss function to use for calibration. It defaults to a focal loss.
config : Config
An object to configure the calibration.
Returns
-------
Status
A calibration status object. It provides information about the calibration.
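        Examples
        --------
        A minimal sketch, assuming the outputs are held-out logits of shape
        ``(n, n_classes)`` and the targets are integer labels of shape ``(n,)``;
        the arrays below are synthetic placeholders:

        .. code-block:: python

            import numpy as np

            calib_outputs = np.random.randn(100, 3)
            calib_targets = np.random.randint(0, 3, 100)
            val_outputs = np.random.randn(50, 3)
            val_targets = np.random.randint(0, 3, 50)

            calib_model = OutputCalibClassifier()
            status = calib_model.calibrate(
                calib_outputs=calib_outputs,
                calib_targets=calib_targets,
                val_outputs=val_outputs,
                val_targets=val_targets,
            )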
"""
self._check_output_dim(calib_outputs, calib_targets)
if val_outputs is not None:
self._check_output_dim(val_outputs, val_targets)
return super()._calibrate(
uncertainty_fn=(
config.monitor.uncertainty_fn
if config.monitor.uncertainty_fn is not None
else self.prob_output_layer.mean
),
calib_outputs=calib_outputs,
calib_targets=calib_targets,
val_outputs=val_outputs,
val_targets=val_targets,
loss_fn=loss_fn,
config=config,
)
@staticmethod
    def _check_output_dim(outputs: jnp.ndarray, targets: jnp.ndarray) -> None:
        # The second axis of `outputs` must match the number of distinct
        # classes observed in `targets`.
        n_classes = len(np.unique(targets))
        if outputs.shape[1] != n_classes:
            raise ValueError(
                f"""`outputs.shape[1]` must match the number of distinct classes in `targets`.
                However, `outputs.shape[1]={outputs.shape[1]}` and `len(np.unique(targets))={n_classes}`."""
)
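

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. It uses
    # synthetic logits and labels, and assumes the default temperature-scaling
    # calibrator and default `Config` suffice for a quick run.
    rng = np.random.default_rng(0)
    outputs = rng.normal(size=(100, 3))     # placeholder logits for 3 classes
    targets = rng.integers(0, 3, size=100)  # placeholder labels in {0, 1, 2}

    calib_model = OutputCalibClassifier()
    status = calib_model.calibrate(calib_outputs=outputs, calib_targets=targets)

    # Passing outputs whose second axis does not match the number of observed
    # classes should raise the `ValueError` from `_check_output_dim`.
    try:
        calib_model.calibrate(calib_outputs=outputs[:, :2], calib_targets=targets)
    except ValueError as err:
        print(f"Caught expected dimension mismatch: {err}")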