# Source code for fortuna.conformal.classification.simple_prediction

from jax import vmap

from fortuna.conformal.classification.base import (
    CVPlusConformalClassifier,
    SplitConformalClassifier,
)
from fortuna.typing import Array


@vmap
def _score_fn(probs: Array, target: Array):
    """Return the simple-prediction score for each example in a batch.

    ``vmap`` maps this function over the leading axis of both arguments. The
    body therefore sees one example at a time.

    Args:
        probs: The score vector of a single example. Presumably these are
            class probabilities; confirm this against the callers.
        target: The index into ``probs`` of that example's target class.

    Returns:
        ``1 - probs[target]``. The value is small when the target entry of
        ``probs`` is large.
    """
    target_prob = probs[target]
    return 1 - target_prob


def score_fn(
    probs: Array,
    targets: Array,
):
    """Compute the simple-prediction nonconformity score of each example.

    For each example ``i`` the score is ``1 - probs[i][targets[i]]``. It is
    computed by the vmapped ``_score_fn``, which maps over the leading axis of
    both arguments.

    Args:
        probs: A batch of per-example score vectors with a leading batch axis.
            Presumably this is ``(n, num_classes)`` probabilities; confirm with
            the callers.
        targets: The target index of each example. Its leading axis must match
            that of ``probs``.

    Returns:
        One score per example, obtained by stacking the per-example outputs of
        ``_score_fn`` along the batch axis.
    """
    return _score_fn(probs, targets)


class SimplePredictionConformalClassifier(SplitConformalClassifier):
    """Split conformal classifier using the simple-prediction score.

    The score is ``1 - probs[target]`` per example. It is supplied through
    ``score_fn``; all other behavior comes from ``SplitConformalClassifier``.
    """

    def score_fn(
        self,
        probs: Array,
        targets: Array,
    ):
        """Return ``1 - probs[i][targets[i]]`` for each example ``i``.

        Args:
            probs: A batch of per-example score vectors with a leading batch
                axis.
            targets: The target index of each example, aligned with the
                leading axis of ``probs``.

        Returns:
            One nonconformity score per example.
        """
        # A bare ``score_fn`` here resolves to the module-level function, not
        # this method: a class body does not form an enclosing scope for the
        # functions defined inside it.
        return score_fn(probs=probs, targets=targets)


class CVPlusSimplePredictionConformalClassifier(CVPlusConformalClassifier):
    """CV+ conformal classifier using the simple-prediction score.

    The score is ``1 - probs[target]`` per example. It is supplied through
    ``score_fn``; all other behavior comes from ``CVPlusConformalClassifier``.
    """

    def score_fn(
        self,
        probs: Array,
        targets: Array,
    ):
        """Return ``1 - probs[i][targets[i]]`` for each example ``i``.

        Args:
            probs: A batch of per-example score vectors with a leading batch
                axis.
            targets: The target index of each example, aligned with the
                leading axis of ``probs``.

        Returns:
            One nonconformity score per example.
        """
        # A bare ``score_fn`` here resolves to the module-level function, not
        # this method: a class body does not form an enclosing scope for the
        # functions defined inside it.
        return score_fn(probs=probs, targets=targets)