Source code for fortuna.conformal.classification.adaptive_conformal_classifier

import inspect
from typing import (
    List,
    Optional,
)

import jax.numpy as jnp

from fortuna.conformal.classification.base import ConformalClassifier
from fortuna.typing import Array


class AdaptiveConformalClassifier(ConformalClassifier):
    def __init__(self, conformal_classifier: ConformalClassifier):
        """
        An adaptive conformal classifier class (see `Gibbs & Candes, 2021
        <https://proceedings.neurips.cc/paper/2021/hash/0d441de75945e5acbc865406fc9a2559-Abstract.html>`_).
        It takes any conformal classifier and adds the functionality to update the coverage error to take into
        account distributional shifts in the data.

        Parameters
        ----------
        conformal_classifier: ConformalClassifier
            A conformal method for classification.
        """
        # Copy every non-dunder member of the wrapped classifier onto this instance. Instance attributes
        # shadow class attributes, so calls such as `self.is_in(...)` below resolve to the wrapped
        # classifier's bound method whenever the wrapped object provides one.
        for s, m in inspect.getmembers(conformal_classifier):
            if not s.startswith("__"):
                setattr(self, s, m)

    def update_error(
        self,
        conformal_set: List[int],
        error: float,
        target: Array,
        target_error: float,
        gamma: float = 0.005,
        weights: Optional[Array] = None,
        were_in: Optional[Array] = None,
        return_were_in: bool = False,
    ) -> Array:
        """
        Update the coverage error based on the test target variable belonging or not to the conformal set.

        Parameters
        ----------
        conformal_set: List[int]
            A conformal set for the current test target variable.
        error: float
            The current coverage error to update.
        target: Array
            The observed test target variable.
        target_error: float
            The target coverage error.
        gamma: float
            The step size for the coverage error update. It must be strictly positive.
        weights: Optional[Array]
            Weights over the considered past time steps and the current one. This must be a one-dimensional
            array of increasing components between 0 and 1, summing up to 1. It must be passed together with
            `were_in`.
        were_in: Optional[Array]
            It indicates whether the target variables of the considered past time steps fell within the
            respective conformal sets. This must be a one-dimensional array of 1's and 0's. Its length must be
            the length of `weights` minus one, as it refers to all the past time steps but not the current one.
            It must be passed together with `weights`.
        return_were_in: bool
            It returns an updated `were_in`, which includes whether the current test target variable falls
            within its conformal set.

        Returns
        -------
        Array
            The updated coverage error, clamped to the interval [0, 1]. If `return_were_in` is True, a tuple
            is returned instead: the updated coverage error converted to a `float`, followed by the array of
            inclusion indicators, i.e. `were_in` (when given) with the indicator of the current target
            appended.

        Raises
        ------
        ValueError
            If `gamma` is not strictly positive; if only one of `weights` and `were_in` is given; if
            `weights` is not one-dimensional, not sorted in ascending order, has components outside [0, 1]
            or does not sum up to 1; if `were_in` contains values other than 0 and 1, is not one-dimensional,
            or its length differs from `len(weights) - 1`.
        """
        if gamma <= 0:
            raise ValueError(
                f"`gamma` must be a value greater than 0, but {gamma} was found."
            )

        # `weights` and `were_in` describe the same time window, so they only make sense as a pair.
        if weights is not None and were_in is None:
            raise ValueError(
                "If `weights` is available, `were_in` must be available too."
            )
        if weights is None and were_in is not None:
            raise ValueError(
                "If `were_in` is available, `weights` must be available too."
            )

        if weights is not None:
            # Require exactly one dimension (the previous `ndim > 1` check let 0-d arrays through, which
            # then failed on the slicing below instead of raising this descriptive error).
            if weights.ndim != 1:
                raise ValueError(
                    "`weights` must be a one-dimensional array over the considered times in the time "
                    "series."
                )
            if (
                jnp.any(weights[:-1] > weights[1:])
                or jnp.any(weights < 0)
                or jnp.any(weights > 1)
                or not jnp.allclose(jnp.sum(weights), 1.0)
            ):
                raise ValueError(
                    "`weights` must be a vector of weights sorted in ascending order, with all "
                    "elements between 0 and 1, summing up to 1."
                )

        if were_in is not None:
            # An entry is invalid when it is simultaneously different from 0 and from 1.
            if jnp.any((were_in != 0) * (were_in != 1)):
                raise ValueError("`were_in` must be a vector of 0's and 1's.")
            if were_in.ndim != 1:
                raise ValueError(
                    "`were_in` must a be one-dimensional array over the considered times in the time "
                    "series."
                )
            # `weights` is guaranteed to be set here by the pairing checks above.
            if len(were_in) != len(weights) - 1:
                raise ValueError(
                    "`len(weights)-1` and `len(were_in)` must be the same. "
                    f"However, {len(weights) - 1} and {len(were_in)} were found, respectively."
                )

        # `target[None]` adds a leading axis and the set is wrapped in a list, so `is_in` is queried on a
        # batch containing only the current test point.
        is_in = self.is_in(target[None], [conformal_set])

        if were_in is not None:
            # Append the current indicator to the history and step against the weighted miscoverage
            # `sum_t weights[t] * (1 - is_in[t])` over the whole window.
            is_in = jnp.concatenate((were_in, is_in))
            error += gamma * (target_error - jnp.dot(weights, 1 - is_in))
        else:
            # Unweighted update: the miscoverage of the current step alone is `1 - is_in`.
            error += gamma * (target_error - 1 + is_in.squeeze())

        # Keep the coverage error a valid probability.
        if error > 1:
            error = 1
        if error < 0:
            error = 0

        if return_were_in:
            return float(error), is_in
        return error