Source code for fortuna.conformal.classification.simple_prediction
from typing import (
List,
Optional,
)
from jax import vmap
import jax.numpy as jnp
from fortuna.conformal.classification.base import ConformalClassifier
from fortuna.typing import Array
[docs]class SimplePredictionConformalClassifier(ConformalClassifier):
[docs] def score(
self,
val_probs: Array,
val_targets: Array,
) -> jnp.ndarray:
"""
Compute score function.
Parameters
----------
val_probs: Array
A two-dimensional array of class probabilities for each validation data point.
val_targets: Array
A one-dimensional array of validation target variables.
Returns
-------
jnp.ndarray
The conformal scores.
"""
if val_probs.ndim != 2:
raise ValueError(
"""`val_probs` must be a two-dimensional array. The first dimension is over the validation
inputs. The second is over the classes."""
)
@vmap
def score_fn(prob, target):
return 1 - prob[target]
return score_fn(val_probs, val_targets)
[docs] def quantile(
self,
val_probs: Array,
val_targets: Array,
error: float = 0.05,
scores: Optional[Array] = None,
) -> Array:
"""
Compute a quantile of the scores.
Parameters
----------
val_probs: Array
A two-dimensional array of class probabilities for each validation data point.
val_targets: Array
A one-dimensional array of validation target variables.
error: float
Coverage error. This must be a scalar between 0 and 1, extremes included.
scores: Optional[Array]
The conformal scores. This should be the output of
:meth:`~fortuna.conformal.classification.simple_prediction.SimplePredictionConformalClassifier.score`.
Returns
-------
float
The conformal quantiles.
"""
if error < 0 or error > 1:
raise ValueError("""`error` must be a scalar between 0 and 1.""")
if scores is None:
scores = self.score(val_probs, val_targets)
n = scores.shape[0]
return jnp.quantile(scores, jnp.ceil((n + 1) * (1 - error)) / n)
[docs] def conformal_set(
self,
val_probs: Array,
test_probs: Array,
val_targets: Array,
error: float = 0.05,
quantile: Optional[float] = None,
) -> List[List[int]]:
"""
Coverage set of each of the test inputs, at the desired coverage error.
Parameters
----------
val_probs: Array
A two-dimensional array of class probabilities for each validation data point.
test_probs: Array
A two-dimensional array of class probabilities for each test data point.
val_targets: Array
A one-dimensional array of validation target variables.
error: float
The coverage error. This must be a scalar between 0 and 1, extremes included.
quantile: Optional[float]
Conformal quantiles. This should be the output of
:meth:`~fortuna.conformal.classification.simple_prediction.SimplePredictionConformalClassifier.quantile`.
Returns
-------
List[List[int, ...]]
The coverage sets.
"""
if test_probs.ndim != 2:
raise ValueError(
"""`test_probs` must be a two-dimensional array. The first dimension is over the validation
inputs. The second is over the classes."""
)
if quantile is None:
quantile = self.quantile(val_probs, val_targets, error)
return [jnp.where(prob > 1 - quantile)[0].tolist() for prob in test_probs]