MNIST Classification with Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)

In this notebook we demonstrate how to use Fortuna to obtain predictions and uncertainty estimates from a simple neural network trained on the MNIST classification task, using the SGHMC method.

Download MNIST data from TensorFlow

Let us first download the MNIST data from TensorFlow Datasets. Any other source would work equally well.

[1]:
import tensorflow as tf
import tensorflow_datasets as tfds


def download(split_range, shuffle=False):
    ds = tfds.load(
        name="MNIST",
        split=f"train[{split_range}]",
        as_supervised=True,
        shuffle_files=True,
    ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
    if shuffle:
        ds = ds.shuffle(10, reshuffle_each_iteration=True)
    return ds.batch(128).prefetch(1)


train_data_loader, val_data_loader, test_data_loader = (
    download(":80%", shuffle=True),
    download("80%:90%"),
    download("90%:"),
)
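
As a quick sanity check on the split strings above, the following sketch works out how many examples and batches each loader should contain. It assumes the standard 60,000-example MNIST training split and the batch size of 128 used in download; the helper name split_sizes is hypothetical and not part of TensorFlow Datasets or Fortuna.

```python
import math

N_TRAIN = 60_000   # size of the MNIST "train" split (assumption stated above)
BATCH_SIZE = 128   # matches ds.batch(128) in download()


def split_sizes(n_total, percentages):
    """Return the number of examples in each consecutive percentage slice."""
    bounds = [0]
    for pct in percentages:
        bounds.append(bounds[-1] + pct)
    return [n_total * (hi - lo) // 100 for lo, hi in zip(bounds, bounds[1:])]


sizes = split_sizes(N_TRAIN, [80, 10, 10])
batches = [math.ceil(n / BATCH_SIZE) for n in sizes]
print(sizes)    # [48000, 6000, 6000]
print(batches)  # [375, 47, 47]
```

The last batch of the validation and test loaders is therefore smaller than 128, since 6,000 is not a multiple of the batch size.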

Convert data to a compatible data loader

Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest.

[2]:
from fortuna.data import DataLoader

train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader)
val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader)
test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader)

Build a probabilistic classifier

Let us build a probabilistic classifier. This is an interface object with several configurable attributes, such as model, prior, and posterior_approximator. In this example, we use a multilayer perceptron and an SGHMC posterior approximator. SGHMC (and SGMCMC methods more broadly) lets you configure a step size schedule function. For simplicity, we use a constant step size.

[3]:
import flax.linen as nn

from fortuna.prob_model import ProbClassifier, SGHMCPosteriorApproximator
from fortuna.model import MLP

output_dim = 10
prob_model = ProbClassifier(
    model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)),
    posterior_approximator=SGHMCPosteriorApproximator(
        burnin_length=300, step_schedule=4e-6
    ),
)
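
A step schedule can be thought of as a function from the iteration count to a step size. The sketch below is plain Python and is not Fortuna's API: constant_schedule and polynomial_schedule are hypothetical helpers, and the decay constants are arbitrary. It contrasts a constant schedule, like the one used above, with a decaying one of the form a * (b + t) ** (-gamma).

```python
def constant_schedule(step_size):
    """Return a schedule that ignores the step count (hypothetical helper)."""
    return lambda step: step_size


def polynomial_schedule(a, b, gamma):
    """Return a decaying schedule a * (b + step) ** (-gamma) (hypothetical helper)."""
    return lambda step: a * (b + step) ** (-gamma)


constant = constant_schedule(4e-6)
decaying = polynomial_schedule(a=4e-6, b=1.0, gamma=0.55)

# Compare both schedules at a few iteration counts.
for step in (0, 100, 10_000):
    print(step, constant(step), decaying(step))
```

A constant step size is the simplest choice; a decaying schedule trades slower exploration late in the chain for smaller discretization error.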

Train the probabilistic model: posterior fitting and calibration

We can now train the probabilistic model. This includes fitting the posterior distribution and calibrating the probabilistic model. The Markov chain first goes through the burn-in phase configured on the posterior approximator above (burnin_length=300), after which samples are drawn from the approximate posterior. The fit runs for 30 epochs.

[4]:
from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer
from fortuna.metric.classification import accuracy

status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,)),
        optimizer=FitOptimizer(n_epochs=30),
    ),
)
Epoch: 30 | loss: -29047.30664 | accuracy: 0.95312: 100%|██████████| 30/30 [00:54<00:00,  1.80s/it]
Epoch: 100 | loss: 1413.66968: 100%|██████████| 100/100 [00:26<00:00,  3.84it/s]
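
To make the burn-in-then-sample pattern concrete, here is a minimal, self-contained sketch of an SGHMC-style update on a one-dimensional toy target (a standard normal). It is not Fortuna's implementation: the function name sghmc_toy, the friction value, and the use of an exact gradient in place of a minibatch estimate are all assumptions made for illustration.

```python
import random


def sghmc_toy(n_steps, burnin_length, step_size=0.01, friction=0.1, seed=0):
    """Draw samples from N(0, 1) with an SGHMC-style update (toy sketch)."""
    rng = random.Random(seed)
    theta, momentum = 3.0, 0.0  # deliberately poor starting point
    # Injected noise balances the friction term so the chain targets N(0, 1).
    noise_std = (2.0 * friction * step_size) ** 0.5
    samples = []
    for step in range(n_steps):
        grad = theta  # gradient of the potential U(theta) = theta**2 / 2
        momentum = (
            (1.0 - friction) * momentum
            - step_size * grad
            + rng.gauss(0.0, noise_std)
        )
        theta += momentum
        if step >= burnin_length:  # keep only post-burn-in iterates
            samples.append(theta)
    return samples


samples = sghmc_toy(n_steps=25_000, burnin_length=5_000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}, variance={var:.3f}")
```

The iterates produced during burn-in are discarded because they still reflect the starting point rather than the target distribution.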

Estimate predictive statistics

We can now compute predictive statistics by calling the method of interest on the predictive attribute of the probabilistic classifier. Most predictive statistics, e.g. mean or mode, require a loader of input data points, which you can obtain by calling the data loader's to_inputs_loader method.

[5]:
test_log_probs = prob_model.predictive.log_prob(data_loader=test_data_loader)
test_inputs_loader = test_data_loader.to_inputs_loader()
test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
test_modes = prob_model.predictive.mode(
    inputs_loader=test_inputs_loader, means=test_means
)
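
Conceptually, a predictive mean for classification can be seen as class probabilities averaged over posterior samples, and the mode as the most probable class under that mean. The sketch below illustrates this for a single input in plain Python. The logits are made up, the helper names are hypothetical, and this is an illustration of the idea rather than a description of Fortuna's internals.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predictive_mean(logits_per_sample):
    """Average the class probabilities over posterior samples."""
    probs = [softmax(logits) for logits in logits_per_sample]
    n_classes = len(probs[0])
    return [sum(p[k] for p in probs) / len(probs) for k in range(n_classes)]


def predictive_mode(mean):
    """Return the index of the most probable class."""
    return max(range(len(mean)), key=lambda k: mean[k])


# Made-up logits for one input under three posterior samples (3 classes).
logits = [[2.0, 0.5, 0.1], [1.5, 1.0, 0.2], [0.3, 1.8, 0.1]]
mean = predictive_mean(logits)
mode = predictive_mode(mean)
print(mean, mode)
```

Although the third sample favors class 1, the averaged probabilities favor class 0, so the mode is 0.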

Compute metrics

In classification, the predictive mode is a prediction for labels, while the predictive mean is a prediction for the probability of each label. As such, we can use these to compute several metrics, e.g. the accuracy, the Brier score, the expected calibration error (ECE), etc.

[6]:
from fortuna.metric.classification import (
    accuracy,
    expected_calibration_error,
    brier_score,
)

test_targets = test_data_loader.to_array_targets()
acc = accuracy(preds=test_modes, targets=test_targets)
brier = brier_score(probs=test_means, targets=test_targets)
ece = expected_calibration_error(
    preds=test_modes,
    probs=test_means,
    targets=test_targets,
    plot=True,
    plot_options=dict(figsize=(10, 2)),
)
print(f"Test accuracy: {acc}")
print(f"Brier score: {brier}")
print(f"ECE: {ece}")
Test accuracy: 0.9108332991600037
Brier score: 0.12804414331912994
ECE: 0.017098067328333855
[Figure: plot produced by expected_calibration_error with plot=True]
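
For intuition about what these three metrics measure, the following plain-Python sketch computes them on four made-up predictions. The toy_ functions are hypothetical helpers that follow common textbook definitions (multi-class Brier score as the mean squared distance to the one-hot label, ECE with equal-width confidence bins); Fortuna's own implementations may differ in details such as binning.

```python
def toy_accuracy(preds, targets):
    """Fraction of predictions that match the targets."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def toy_brier_score(probs, targets):
    """Mean squared distance between probabilities and one-hot labels."""
    total = 0.0
    for p, t in zip(probs, targets):
        total += sum((pk - (1.0 if k == t else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(targets)


def toy_ece(probs, targets, n_bins=10):
    """Expected calibration error with equal-width confidence bins."""
    bins = [[] for _ in range(n_bins)]
    for p, t in zip(probs, targets):
        conf = max(p)
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, p.index(conf) == t))
    ece = 0.0
    for b in bins:
        if b:  # skip empty bins
            avg_conf = sum(c for c, _ in b) / len(b)
            avg_acc = sum(hit for _, hit in b) / len(b)
            ece += len(b) / len(targets) * abs(avg_acc - avg_conf)
    return ece


probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.55, 0.45]]
targets = [0, 1, 1, 0]
preds = [p.index(max(p)) for p in probs]

acc = toy_accuracy(preds, targets)
brier = toy_brier_score(probs, targets)
ece = toy_ece(probs, targets)
print(acc, brier, ece)
```

Lower is better for both the Brier score and the ECE, while higher is better for accuracy.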

Conformal prediction sets

Fortuna can produce conformal prediction sets, which are sets of likely labels up to some coverage probability threshold. These can be computed from probability estimates obtained with or without Fortuna.

[7]:
from fortuna.conformal import AdaptivePredictionConformalClassifier

val_means = prob_model.predictive.mean(inputs_loader=val_data_loader.to_inputs_loader())
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
    val_probs=val_means,
    test_probs=test_means,
    val_targets=val_data_loader.to_array_targets(),
    error=0.05,
)
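
The idea behind adaptive prediction sets can be sketched in a few lines of plain Python. The version below is a simplified, non-randomized variant written for illustration only. It is not Fortuna's implementation, and the helper names aps_score and toy_conformal_sets are hypothetical. Each validation example is scored by the probability mass needed to reach its true label when classes are taken in decreasing order of probability; a quantile of these scores then serves as the mass threshold for test inputs.

```python
import math


def aps_score(probs, label):
    """Cumulative mass of classes, most probable first, up to the true label."""
    order = sorted(range(len(probs)), key=lambda k: -probs[k])
    total = 0.0
    for k in order:
        total += probs[k]
        if k == label:
            return total


def toy_conformal_sets(val_probs, val_targets, test_probs, error):
    """Build prediction sets whose mass reaches the calibrated threshold."""
    scores = sorted(aps_score(p, y) for p, y in zip(val_probs, val_targets))
    n = len(scores)
    # Finite-sample corrected rank, clipped to n for very small error levels.
    rank = min(n, math.ceil((n + 1) * (1 - error)))
    threshold = scores[rank - 1]
    sets = []
    for probs in test_probs:
        order = sorted(range(len(probs)), key=lambda k: -probs[k])
        total, chosen = 0.0, []
        for k in order:
            chosen.append(k)
            total += probs[k]
            if total >= threshold:
                break
        sets.append(chosen)
    return sets


val_probs = [[0.7, 0.2, 0.1], [0.5, 0.3, 0.2], [0.6, 0.3, 0.1], [0.4, 0.4, 0.2]]
val_targets = [0, 1, 0, 2]
test_probs = [[0.9, 0.05, 0.05], [0.4, 0.35, 0.25]]

toy_sets = toy_conformal_sets(val_probs, val_targets, test_probs, error=0.5)
print(toy_sets)
```

In this toy run the confident test input gets a single-label set, while the ambiguous one gets all three labels.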

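We can compare the average size of the conformal sets for misclassified and well-classified inputs. One would typically expect larger sets for misclassified inputs. In the run shown below, however, the sets contain almost all labels (about 9.9 out of 10 on average) and are slightly smaller for misclassified inputs.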

[8]:
import numpy as np

# Size of each conformal set, and a mask of correctly classified test inputs.
set_sizes = np.array([len(s) for s in conformal_sets])
correct = np.asarray(test_modes == test_targets)

avg_size = np.mean(set_sizes)
avg_size_wellclassified = np.mean(set_sizes[correct])
avg_size_misclassified = np.mean(set_sizes[~correct])
print(f"Average conformal set size: {avg_size}")
print(
    f"Average conformal set size over well classified input: {avg_size_wellclassified}"
)
print(f"Average conformal set size over misclassified input: {avg_size_misclassified}")
Average conformal set size: 9.909666666666666
Average conformal set size over well classified input: 9.956816102470265
Average conformal set size over misclassified input: 9.42803738317757
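
Set size is only half of the picture: with error=0.05 one would also want roughly 95% of the true labels to fall inside their conformal sets. A minimal sketch of such a coverage check, using a hypothetical helper empirical_coverage and made-up sets, could look as follows.

```python
def empirical_coverage(sets, targets):
    """Fraction of targets that are contained in their prediction set."""
    hits = sum(t in s for s, t in zip(sets, targets))
    return hits / len(targets)


# Made-up prediction sets and true labels for four inputs.
toy_sets = [[0], [1, 2], [0, 1, 2], [2]]
toy_targets = [0, 2, 1, 1]

coverage = empirical_coverage(toy_sets, toy_targets)
print(f"Empirical coverage: {coverage}")  # Empirical coverage: 0.75
```

Assuming conformal_sets holds one list of integer labels per test input, the same helper could be applied to conformal_sets and test_targets from this notebook.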

Furthermore, we visualize some of the examples with the largest and the smallest conformal sets. Intuitively, these correspond to the inputs where the model is most uncertain and most certain about its predictions, respectively.

[9]:
from matplotlib import pyplot as plt

N_EXAMPLES = 10
images = test_data_loader.to_array_inputs()


def visualize_examples(indices, n_examples=N_EXAMPLES):
    # Show at most n_examples images side by side.
    n_cols = min(len(indices), n_examples)
    # squeeze=False keeps axs a 2D array even when there is a single image.
    _, axs = plt.subplots(1, n_cols, figsize=(10, 2), squeeze=False)
    for i, ax in enumerate(axs.flatten()):
        ax.imshow(images[indices[i]], cmap="gray")
        ax.axis("off")
    plt.show()
[10]:
indices = np.argsort([len(s) for s in np.array(conformal_sets, dtype="object")])
[11]:
print("Examples with the smallest conformal sets:")
visualize_examples(indices[:N_EXAMPLES])
Examples with the smallest conformal sets:
[Figure: test images with the smallest conformal sets]
[12]:
print("Examples with the largest conformal sets:")
visualize_examples(np.flip(indices[-N_EXAMPLES:]))
Examples with the largest conformal sets:
[Figure: test images with the largest conformal sets]