from typing import (
Dict,
Optional,
Tuple,
Union,
)
import jax.nn
import jax.numpy as jnp
from fortuna.data.loader import TargetsLoader
from fortuna.plot import plot_reliability_diagram
from fortuna.typing import Array
def accuracy(preds: Array, targets: Array) -> jnp.ndarray:
    """
    Compute the accuracy of the predictions, i.e. the fraction of data points
    for which the predicted class matches the target class.

    Parameters
    ----------
    preds: Array
        A one-dimensional array of predictions over the data points.
    targets: Array
        A one-dimensional array of target variables.

    Returns
    -------
    jnp.ndarray
        The computed accuracy.

    Raises
    ------
    ValueError
        If either `preds` or `targets` is not one-dimensional.
    """
    # Guard clauses: both inputs must be flat arrays of class labels.
    if preds.ndim > 1:
        raise ValueError(
            "`preds` must be a one-dimensional array of predicted classes."
        )
    if targets.ndim > 1:
        raise ValueError(
            "`targets` must be a one-dimensional array of target classes."
        )
    # Mean of the boolean match mask equals the fraction of correct predictions.
    return (preds == targets).mean()
def compute_counts_confs_accs(
    preds: Array,
    probs: Array,
    targets: Array,
    plot: bool = False,
    plot_options: Optional[Dict] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Bin the confidence scores (maximum probability) and for each of them compute:

    - the number of inputs;
    - the average confidence score for each bin;
    - the average accuracy over each bin.

    Parameters
    ----------
    preds: Array
        A one-dimensional array of predictions over the data points.
    probs: Array
        A two-dimensional array of class probabilities for each data point.
    targets: Array
        A one-dimensional array of target variables.
    plot: bool
        Whether to plot a reliability diagram.
    plot_options: dict
        Options for the reliability diagram plot; see :func:`~fortuna.plot.plot_reliability_diagram`.

    Returns
    -------
    Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
        Number of inputs per bin, average confidence score per bin and average accuracy per bin.
    """
    # Reduce a 2-D array of class probabilities to per-point confidence scores.
    if probs.ndim == 2:
        probs = probs.max(-1)
    probs = jnp.array(probs)
    # Ten upper bin edges, from 1/n up to 1 inclusive; the first bin also
    # captures everything at or below its edge.
    edges = jnp.linspace(1 / len(probs), 1, 10)
    bin_members = []
    for i, upper in enumerate(edges):
        mask = probs <= upper
        if i > 0:
            # Interior bins are half-open: (edges[i-1], edges[i]].
            mask = mask & (probs > edges[i - 1])
        bin_members.append(jnp.where(mask)[0])
    counts = jnp.array([idx.shape[0] for idx in bin_members])
    hits = preds == targets
    # Empty bins produce NaN means; nan_to_num maps those to 0.
    accs = jnp.array([jnp.nan_to_num(jnp.mean(hits[idx])) for idx in bin_members])
    confs = jnp.array([jnp.nan_to_num(jnp.mean(probs[idx])) for idx in bin_members])
    if plot:
        # Only bins with a non-zero average confidence are shown.
        shown = confs != 0
        plot_reliability_diagram(accs[shown], confs[shown], **(plot_options or {}))
    return counts, confs, accs
def expected_calibration_error(
    preds: Array,
    probs: Array,
    targets: Array,
    plot: bool = False,
    plot_options: Optional[Dict] = None,
) -> jnp.ndarray:
    """
    Compute the Expected Calibration Error (ECE)
    (see `Naeini et al., 2015 <https://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf>`__ and
    `Guo et al., 2017 <http://proceedings.mlr.press/v70/guo17a/guo17a.pdf>`__). Optionally, plot and save a reliability
    diagram.

    Parameters
    ----------
    preds: Array
        A one-dimensional array of predictions over the data points.
    probs: Array
        A two-dimensional array of class probabilities for each data point.
    targets: Array
        A one-dimensional array of target variables.
    plot: bool
        Whether to plot a reliability diagram.
    plot_options: dict
        Options for the reliability diagram plot; see :func:`~fortuna.plot.plot_reliability_diagram`.

    Returns
    -------
    jnp.ndarray
        The value of the ECE.
    """
    counts, confs, accs = compute_counts_confs_accs(
        preds, probs, targets, plot, plot_options
    )
    # ECE is the count-weighted average of the per-bin |accuracy - confidence|
    # gaps, normalized by the total number of data points.
    gaps = jnp.abs(accs - confs)
    n_points = preds.shape[0]
    return jnp.sum(counts * gaps) / n_points
def ece(
    preds: Array,
    probs: Array,
    targets: Array,
    plot: bool = False,
    plot_options: Optional[Dict] = None,
) -> jnp.ndarray:
    """Shorthand alias; see :func:`.expected_calibration_error`."""
    return expected_calibration_error(
        preds=preds,
        probs=probs,
        targets=targets,
        plot=plot,
        plot_options=plot_options,
    )
def maximum_calibration_error(
    preds: Array,
    probs: Array,
    targets: Array,
    plot: bool = False,
    plot_options: Optional[Dict] = None,
) -> jnp.ndarray:
    """
    Compute the Maximum Calibration Error (MCE)
    (see `Naeini et al., 2015 <https://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf>`__). Optionally, plot
    and save a reliability diagram.

    Parameters
    ----------
    preds: Array
        A one-dimensional array of predictions over the data points.
    probs: Array
        A two-dimensional array of class probabilities for each data point.
    targets: Array
        A one-dimensional array of target variables.
    plot: bool
        Whether to plot a reliability diagram.
    plot_options: dict
        Options for the reliability diagram plot; see :func:`~fortuna.plot.plot_reliability_diagram`.

    Returns
    -------
    jnp.ndarray
        The value of the MCE.
    """
    _, confs, accs = compute_counts_confs_accs(
        preds, probs, targets, plot, plot_options
    )
    # MCE is defined (Naeini et al., 2015) as the worst-case, *unweighted* gap
    # between accuracy and confidence across bins: max_m |acc_m - conf_m|.
    # The previous implementation multiplied the gap by the per-bin counts,
    # which does not match the definition and could yield values far outside
    # the [0, 1] range. Empty bins contribute 0 because both their accuracy
    # and confidence are zeroed out upstream via nan_to_num.
    mce = jnp.max(jnp.abs(accs - confs))
    return mce
def mce(
    preds: Array,
    probs: Array,
    targets: Array,
    plot: bool = False,
    plot_options: Optional[Dict] = None,
) -> jnp.ndarray:
    """Shorthand alias; see :func:`.maximum_calibration_error`."""
    return maximum_calibration_error(
        preds=preds,
        probs=probs,
        targets=targets,
        plot=plot,
        plot_options=plot_options,
    )
def brier_score(probs: Array, targets: Union[TargetsLoader, Array]) -> jnp.ndarray:
    """
    Brier score (see `Brier, 1950 <https://web.archive.org/web/20171023012737/
    https://docs.lib.noaa.gov/rescue/mwr/078/mwr-078-01-0001.pdf>`__). This can be used for both binary and multi-class
    classification.

    Parameters
    ----------
    probs: Array
        A one- or two-dimensional array of class probabilities for each data point.
        If one-dimensional, the entries are interpreted as probabilities of the
        positive class; if two-dimensional, each row holds per-class probabilities.
    targets: Union[TargetsLoader, Array]
        A one-dimensional array of target variables, or a loader of targets.

    Returns
    -------
    jnp.ndarray
        The Brier score.

    Raises
    ------
    ValueError
        If `probs` has more than two dimensions.
    """
    if probs.ndim > 2:
        raise ValueError("`probs` can be at most 2 dimensional.")
    # Duck-typed loader detection. The previous `type(targets) == TargetsLoader`
    # check rejected subclasses of TargetsLoader; any loader exposing
    # `to_array_targets` is materialized here, while plain arrays pass through.
    if hasattr(targets, "to_array_targets"):
        targets = targets.to_array_targets()
    if probs.ndim > 1:
        # Multi-class case: compare per-class probabilities against one-hot
        # encoded targets and sum the squared errors over classes.
        targets = jax.nn.one_hot(targets, probs.shape[-1])
        return jnp.mean(jnp.sum((probs - targets) ** 2, axis=-1))
    # Binary case: mean squared error between positive-class probabilities
    # and 0/1 targets.
    return jnp.mean((probs - targets) ** 2)