How to calibrate sub-networks of pre-trained models

Fortuna’s calibration model offers a simple interface to train or fine-tune a deep learning model. You are free to choose a custom calibration loss, select an optimizer, decide which parameters to train or fine-tune, monitor calibration metrics, and more. In this example, we look at some of its main functionalities.

Download, split and process the data

First, we download the MNIST data with TensorFlow Datasets and split it into training, validation and test sets. All three sets are slices of the MNIST training split.

[1]:
import tensorflow as tf
import tensorflow_datasets as tfds


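def download(split_range, shuffle=False):
    # Load a slice of the MNIST training split as (image, label) pairs
    # and rescale pixel values from [0, 255] to [0, 1].
    ds = tfds.load(
        name="mnist",
        split=f"train[{split_range}]",
        as_supervised=True,
        shuffle_files=True,
    ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
    if shuffle:
        # Reshuffle the examples (with a buffer of 10) at every epoch.
        ds = ds.shuffle(10, reshuffle_each_iteration=True)
    # Group into batches of 128 and prefetch one batch ahead.
    return ds.batch(128).prefetch(1)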


train_data_loader, val_data_loader, test_data_loader = (
    download(":80%", shuffle=True),
    download("80%:90%"),
    download("90%:"),
)
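The split strings use TFDS percentage slicing over the 60,000-image MNIST training split. As a rough sketch of what the three ranges amount to, the snippet below computes the example counts and the number of 128-example batches per loader. The split_sizes helper is hypothetical, and it assumes the percentages divide the split exactly (TFDS applies its own rounding rule otherwise).

```python
import math

# Size of the MNIST "train" split in TFDS, and the batch size used above.
NUM_TRAIN_EXAMPLES = 60_000
BATCH_SIZE = 128


def split_sizes(n, boundaries):
    """Hypothetical helper: turn percentage boundaries into (start, end) index ranges.

    Assumes n * percent / 100 is an integer, so no rounding is needed.
    """
    return [
        (n * lo // 100, n * hi // 100)
        for lo, hi in zip(boundaries[:-1], boundaries[1:])
    ]


# ":80%", "80%:90%" and "90%:" correspond to the boundaries 0, 80, 90, 100.
ranges = split_sizes(NUM_TRAIN_EXAMPLES, [0, 80, 90, 100])
for name, (start, end) in zip(["train", "val", "test"], ranges):
    n_examples = end - start
    n_batches = math.ceil(n_examples / BATCH_SIZE)  # the last batch may be partial
    print(f"{name}: examples {start}-{end} ({n_examples}), {n_batches} batches")
```

With these numbers the training loader yields 375 batches per epoch, while the validation and test loaders yield 47 each, the last one holding fewer than 128 images.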

We then convert the TensorFlow data loaders into Fortuna DataLoader objects.

[2]:
from fortuna.data import DataLoader

train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader)
val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader)
test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader)

Define and calibrate the calibration model

We now introduce CalibClassifier, Fortuna’s calibration classifier, designed to produce calibrated predictions.

[3]:
from fortuna.model import LeNet5
from fortuna.calib_model import CalibClassifier
[4]:
calib_model = CalibClassifier(model=LeNet5(output_dim=10))

Let’s calibrate this model! Since this first run starts from scratch, it can simply be seen as training the model. By default, the calibration uses a focal loss (Mukhoti et al., 2020) with gamma=2., but custom losses may be used instead. During calibration, we enable early stopping and monitor accuracy and the Brier score. We only need to wrap the metrics so that their signature matches the one CalibClassifier expects, namely (preds, uncertainties, targets).

[5]:
from fortuna.calib_model import Config, Monitor
from fortuna.metric.classification import brier_score, accuracy


def brier(preds, uncertainties, targets):
    return brier_score(uncertainties, targets)


def acc(preds, uncertainties, targets):
    return accuracy(preds, targets)


status = calib_model.calibrate(
    train_data_loader,
    val_data_loader=val_data_loader,
    config=Config(monitor=Monitor(early_stopping_patience=2, metrics=(brier, acc))),
)
Epoch: 6 | loss: 0.09192 | brier: 0.06985 | acc: 0.96094:   5%|▌         | 5/100 [00:58<18:31, 11.70s/it]
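The loss and brier values in the progress bar can be made concrete with a small sketch. The functions below are hypothetical, pure-Python versions of a multiclass focal loss, -(1 - p_t)^gamma log(p_t) averaged over examples, where p_t is the probability of the true class, and of a Brier score taken as the squared distance between the probability vector and the one-hot target, summed over classes and averaged over examples. Fortuna’s own implementations operate on arrays and may differ in normalization, so treat this as an illustration of the definitions, not a reimplementation.

```python
import math


def focal_loss(probs, targets, gamma=2.0):
    """Mean focal loss; probs is a list of probability vectors, targets a list of class ids."""
    total = 0.0
    for p, y in zip(probs, targets):
        p_t = p[y]  # probability assigned to the true class
        # (1 - p_t) ** gamma down-weights examples that are already confidently correct
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(targets)


def brier_score(probs, targets):
    """Mean over examples of the squared error between probs and the one-hot target."""
    total = 0.0
    for p, y in zip(probs, targets):
        total += sum((p_k - (1.0 if k == y else 0.0)) ** 2 for k, p_k in enumerate(p))
    return total / len(targets)


probs = [[0.5, 0.5], [0.9, 0.1]]
targets = [0, 0]
print(focal_loss(probs, targets))             # gamma=2, the default discussed above
print(focal_loss(probs, targets, gamma=0.0))  # gamma=0 gives plain cross-entropy
print(brier_score(probs, targets))
```

With gamma=0 the focal loss reduces to the usual cross-entropy; larger values of gamma shift the weight toward examples the model gets wrong or is unsure about.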

Expected calibration error and reliability plot

In one go, let’s compute the Expected Calibration Error (ECE) and draw a reliability plot! To do so, we first compute the predicted labels (the predictive mode) and the predicted class probabilities (the predictive mean) over the test set.

[6]:
test_inputs_loader = test_data_loader.to_inputs_loader()
preds = calib_model.predictive.mode(test_inputs_loader)
probs = calib_model.predictive.mean(test_inputs_loader)
[7]:
from fortuna.metric.classification import expected_calibration_error
[8]:
test_targets = test_data_loader.to_array_targets()
ece = expected_calibration_error(
    preds, probs, test_targets, plot=True, plot_options=dict(figsize=(6, 2))
)
print(f"ECE: {ece}.")
ECE: 0.021739089861512184.
(Figure: reliability plot of the test-set predictions.)

Except for the lowest confidence bins, which usually contain too few predictions to estimate accuracy reliably, the model seems well calibrated: for most confidence bins, the difference between confidence and accuracy is close to 0.
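To make the reading of the plot concrete, here is a hedged pure-Python sketch of how ECE is commonly computed: predictions are grouped into equal-width confidence bins, and ECE is the gap between per-bin accuracy and per-bin mean confidence, averaged over bins with weights proportional to the number of predictions in each bin. The per-bin (confidence, accuracy) pairs are also what a reliability plot draws. The number of bins (10), the binning rule and the function names are assumptions; Fortuna’s expected_calibration_error may use different defaults.

```python
def reliability_bins(probs, targets, n_bins=10):
    """Return (mean confidence, accuracy, count) for each non-empty confidence bin."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, targets):
        confidence = max(p)
        pred = p.index(confidence)  # predicted class = argmax of the probabilities
        # Equal-width bins over [0, 1]; a confidence of exactly 1.0 goes into the last bin.
        b = min(int(confidence * n_bins), n_bins - 1)
        bins[b].append((confidence, pred == y))
    stats = []
    for members in bins:
        if members:
            mean_conf = sum(c for c, _ in members) / len(members)
            acc = sum(1 for _, ok in members if ok) / len(members)
            stats.append((mean_conf, acc, len(members)))
    return stats


def ece(probs, targets, n_bins=10):
    """Expected calibration error: count-weighted mean of |accuracy - confidence|."""
    n = len(targets)
    stats = reliability_bins(probs, targets, n_bins)
    return sum(count / n * abs(acc - conf) for conf, acc, count in stats)


probs = [[0.95, 0.05], [0.95, 0.05], [0.35, 0.65], [0.35, 0.65]]
targets = [0, 1, 1, 1]
print(reliability_bins(probs, targets))
print(ece(probs, targets))
```

Because each bin contributes in proportion to its count, sparsely populated low-confidence bins have little influence on the overall ECE, even when their gap looks large in the plot.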

Calibrate only a subset of model parameters

Purely to demonstrate this functionality, let us now show how you can start from a pre-trained model and fine-tune only a subset of its parameters, for example to improve calibration.

All you need to do is pass a freeze_fun to the Optimizer in the Config object. This function takes a parameter’s path and value and declares whether that parameter is "trainable" or "frozen". In this example, the parameters of the LeNet-5 model are internally organized into a deep feature extractor sub-network (dfe_subnet) and an output sub-network, so we simply freeze dfe_subnet and fine-tune only the output sub-network.
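Conceptually, a freeze function is applied to every leaf of the parameter tree, and its label decides whether that leaf gets updated. The sketch below mimics this on a toy nested dictionary. The tree layout (including the output_subnet key), the plain SGD step and the helper names are all hypothetical, and it assumes path is the tuple of keys leading to a leaf, so that "dfe_subnet" in path tests for that key. Fortuna’s actual parameter tree and optimizer wiring may differ.

```python
def label_leaves(tree, freeze_fun, path=()):
    """Return a tree of the same shape holding "frozen"/"trainable" labels."""
    if isinstance(tree, dict):
        return {k: label_leaves(v, freeze_fun, path + (k,)) for k, v in tree.items()}
    return freeze_fun(path, tree)


def masked_sgd_step(params, grads, labels, lr=0.1):
    """Plain SGD update applied only to leaves labelled "trainable"."""
    if isinstance(params, dict):
        return {k: masked_sgd_step(params[k], grads[k], labels[k], lr) for k in params}
    return params - lr * grads if labels == "trainable" else params


def freeze_fun(path, v):
    # Same rule as in the notebook: freeze everything under "dfe_subnet".
    return "frozen" if "dfe_subnet" in path else "trainable"


# Hypothetical toy parameter tree with scalar leaves, and matching gradients.
params = {
    "dfe_subnet": {"conv": {"kernel": 1.0, "bias": 0.5}},
    "output_subnet": {"dense": {"kernel": 2.0, "bias": -1.0}},
}
grads = {
    "dfe_subnet": {"conv": {"kernel": 1.0, "bias": 1.0}},
    "output_subnet": {"dense": {"kernel": 1.0, "bias": 1.0}},
}

labels = label_leaves(params, freeze_fun)
new_params = masked_sgd_step(params, grads, labels)
print(labels)
print(new_params)  # dfe_subnet leaves unchanged, output_subnet leaves moved by -lr * grad
```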

To start from the pre-trained state rather than from a fresh initialization, we enable the start_from_current_state flag in the Checkpointer.

[9]:
from fortuna.calib_model import Optimizer, Checkpointer
[10]:
status = calib_model.calibrate(
    train_data_loader,
    val_data_loader=val_data_loader,
    config=Config(
        monitor=Monitor(early_stopping_patience=2, metrics=(brier, acc)),
        checkpointer=Checkpointer(start_from_current_state=True),
        optimizer=Optimizer(
            freeze_fun=lambda path, v: "frozen" if "dfe_subnet" in path else "trainable"
        ),
    ),
)
Epoch: 4 | loss: 0.01912 | brier: 0.0237 | acc: 0.98438:   3%|▎         | 3/100 [00:18<09:50,  6.09s/it]
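Both runs stop well before the total of 100 shown in the progress bars because of early stopping with early_stopping_patience=2. The sketch below illustrates the usual patience rule on a made-up sequence of validation losses: training stops once the monitored value has failed to improve for patience consecutive epochs. Which quantity Fortuna monitors, and exactly how it counts epochs, depend on the Monitor settings and may differ from this hypothetical stop_epoch helper.

```python
def stop_epoch(val_losses, patience=2, min_delta=0.0):
    """Return the 1-based epoch at which training stops, or len(val_losses) if it never does."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss  # improvement: reset the patience counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


# Made-up validation losses: the best value so far is reached at epoch 3.
losses = [0.30, 0.20, 0.15, 0.16, 0.17, 0.14]
print(stop_epoch(losses, patience=2))
print(stop_epoch(losses, patience=3))
```

The last value (0.14) shows the trade-off: with patience=2 the sketch stops at epoch 5 and never sees the later improvement, while patience=3 lets it reach that improvement.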