How to calibrate sub-networks of pre-trained models#

Fortuna’s calibration model offers a simple interface to train or fine-tune a deep learning model. The user is free to choose a custom calibration loss, select an optimizer, decide which parameters to train or fine-tune, monitor calibration metrics, and more. In this example, we look at some of its main functionalities.

Download, split and process the data#

First, we download the MNIST data from TensorFlow Datasets and split it into training, validation, and test sets.

[1]:
import tensorflow as tf
import tensorflow_datasets as tfds


def download(split_range, shuffle=False):
    ds = tfds.load(
        name="mnist",
        split=f"train[{split_range}]",
        as_supervised=True,
        shuffle_files=True,
    ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
    if shuffle:
        ds = ds.shuffle(10, reshuffle_each_iteration=True)
    return ds.batch(128).prefetch(1)


train_data_loader, val_data_loader, test_data_loader = (
    download(":80%", shuffle=True),
    download("80%:90%"),
    download("90%:"),
)

We then convert the TensorFlow data loaders into Fortuna’s own DataLoader objects.

[2]:
from fortuna.data import DataLoader

train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader)
val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader)
test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader)
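
If you want to sanity-check the conversion, you can peek at a single batch. The snippet below is only a sketch; it assumes that Fortuna's DataLoader iterates over (inputs, targets) batches, which is also how it is consumed in the rest of this example.

# Optional sanity check: inspect one batch of the converted loader.
inputs, targets = next(iter(train_data_loader))
# With the batching above, we expect roughly (128, 28, 28, 1) inputs and (128,) targets.
print(inputs.shape, targets.shape)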

Define and calibrate the calibration model#

We now introduce CalibClassifier, Fortuna’s calibration classifier, whose purpose is to produce calibrated predictions.

[3]:
from fortuna.model import LeNet5
from fortuna.calib_model import CalibClassifier
[4]:
calib_model = CalibClassifier(model=LeNet5(output_dim=10))

Let’s calibrate this model! First, we run the calibration from scratch, so it can simply be seen as training the model. By default, the calibration uses a focal loss [Mukhoti et al., 2020] with gamma=2., but custom losses may also be used (see the sketch after the calibration cell below). During the calibration, we enable early stopping and monitor accuracy and Brier score; we just have to adjust each metric's signature so that it is compatible with the one the CalibClassifier expects.

[5]:
from fortuna.calib_model import Config, Monitor
from fortuna.metric.classification import brier_score, accuracy


def brier(preds, uncertainties, targets):
    return brier_score(uncertainties, targets)


def acc(preds, uncertainties, targets):
    return accuracy(preds, targets)


status = calib_model.calibrate(
    train_data_loader,
    val_data_loader=val_data_loader,
    config=Config(monitor=Monitor(early_stopping_patience=2, metrics=(brier, acc))),
)
Epoch: 4 | loss: 0.03045 | brier: 0.0414 | acc: 0.97656:   3%|▎         | 3/100 [02:05<1:07:50, 41.96s/it]
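
As mentioned above, the focal loss is only the default. If you prefer a different calibration objective, you can define your own loss. The following is a minimal sketch, not part of the original example: it assumes that calibrate accepts a loss_fn keyword taking model outputs (logits) and integer targets, so please check Fortuna's API reference for the exact signature before relying on it.

import jax.numpy as jnp
from jax.nn import log_softmax


def cross_entropy_loss(outputs, targets):
    # Mean negative log-likelihood over the batch, computed from the logits.
    log_probs = log_softmax(outputs, axis=-1)
    return -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=-1))


# status = calib_model.calibrate(
#     train_data_loader,
#     val_data_loader=val_data_loader,
#     loss_fn=cross_entropy_loss,  # assumed keyword argument; verify in the docs
#     config=Config(monitor=Monitor(early_stopping_patience=2, metrics=(brier, acc))),
# )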

Expected calibration error and reliability plot#

In one go, let’s compute the Expected Calibration Error (ECE) and draw a reliability plot! To do this, we first compute predictions and their probabilities over the test data set.

[6]:
test_inputs_loader = test_data_loader.to_inputs_loader()
preds = calib_model.predictive.mode(test_inputs_loader)
probs = calib_model.predictive.mean(test_inputs_loader)
[7]:
from fortuna.metric.classification import expected_calibration_error
[8]:
test_targets = test_data_loader.to_array_targets()
ece = expected_calibration_error(
    preds, probs, test_targets, plot=True, plot_options=dict(figsize=(6, 2))
)
print(f"ECE: {ece}.")
ECE: 0.026803817600011826.
[Reliability plot: accuracy vs. confidence across confidence bins]

Except for very low confidence, where we usually do not have enough information to obtain a reliable estimate, the model seems well calibrated, since the difference between confidence and accuracy is close to 0 for most confidence bins.

Calibrate only a subset of model parameters#

Purely to demonstrate the functionality, let us now show how you can start from a pre-trained model and fine-tune only a subset of its parameters, for example with the aim of achieving better calibration.

All you need to do is pass freeze_fun to the Optimizer in the Config object, declaring which parameters should be trainable and which frozen. In this example, the parameters of the LeNet-5 model are internally organized into a deep feature extractor sub-network (dfe_subnet) and an output sub-network. We simply freeze dfe_subnet and fine-tune only the output layer.

In order to start from the pre-trained state, we simply enable the flag start_from_current_state in the Checkpointer.

[9]:
from fortuna.calib_model import Optimizer, Checkpointer
[10]:
status = calib_model.calibrate(
    train_data_loader,
    val_data_loader=val_data_loader,
    config=Config(
        monitor=Monitor(early_stopping_patience=2, metrics=(brier, acc)),
        checkpointer=Checkpointer(start_from_current_state=True),
        optimizer=Optimizer(
            freeze_fun=lambda path, v: "frozen" if "dfe_subnet" in path else "trainable"
        ),
    ),
)
Epoch: 4 | loss: 0.00126 | brier: 0.00587 | acc: 1.0:   3%|▎         | 3/100 [01:24<45:40, 28.25s/it]
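
To see whether fine-tuning only the output layer changed calibration, you can repeat the ECE computation from the previous section on the updated model; the snippet below reuses only objects already defined above.

# Re-evaluate calibration on the test set after the sub-network fine-tuning.
preds = calib_model.predictive.mode(test_inputs_loader)
probs = calib_model.predictive.mean(test_inputs_loader)
ece = expected_calibration_error(preds, probs, test_targets)
print(f"ECE after fine-tuning the output layer: {ece}.")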