# Source code for fortuna.conformal.regression.adaptive_conformal_regressor
import inspect
from typing import Optional, Tuple, Union

import jax.numpy as jnp

from fortuna.conformal.regression.base import ConformalRegressor
from fortuna.typing import Array
class AdaptiveConformalRegressor:
    def __init__(self, conformal_regressor: ConformalRegressor):
        """
        An adaptive conformal regressor class
        (see `Gibbs & Candes, 2021 <https://proceedings.neurips.cc/paper/2021/hash/0d441de75945e5acbc865406fc9a2559-Abstract.html>`_).
        It takes any conformal regressor and adds the functionality to update the coverage error to take into account
        distributional shifts in the data.

        Parameters
        ----------
        conformal_regressor: ConformalRegressor
            A conformal method for regression. All of its members whose names do not start with a double
            underscore (attributes and bound methods, e.g. `is_in`) are copied onto this instance at construction
            time, so they can be used directly on the adaptive regressor.
        """
        # Delegation by copying: a copied bound method still operates on `conformal_regressor`.
        # This is a snapshot; members rebound on the wrapped object later are not reflected here.
        for name, member in inspect.getmembers(conformal_regressor):
            if not name.startswith("__"):
                setattr(self, name, member)

    @staticmethod
    def _validate_update_inputs(
        gamma: float, weights: Optional[Array], were_in: Optional[Array]
    ) -> Tuple[Optional[Array], Optional[Array]]:
        """Validate `update_error` arguments; return `weights` and `were_in` as JAX arrays (or both `None`)."""
        if gamma <= 0:
            raise ValueError(
                f"`gamma` must be a value greater than 0, but {gamma} was found."
            )
        if weights is not None and were_in is None:
            raise ValueError(
                "If `weights` is available, `were_in` must be available too."
            )
        if weights is None and were_in is not None:
            raise ValueError(
                "If `were_in` is available, `weights` must be available too."
            )
        if weights is None:
            return None, None

        # Accept any array-like (lists, NumPy arrays) in addition to JAX arrays.
        weights = jnp.asarray(weights)
        were_in = jnp.asarray(were_in)

        # `!= 1` (not `> 1`) so that a 0-d array is rejected here instead of failing on `weights[:-1]` below.
        if weights.ndim != 1:
            raise ValueError(
                "`weights` must be a one-dimensional array over the considered times in the time "
                "series."
            )
        if (
            jnp.any(weights[:-1] > weights[1:])
            or jnp.any(weights < 0)
            or jnp.any(weights > 1)
            or not jnp.allclose(jnp.sum(weights), 1.0)
        ):
            raise ValueError(
                "`weights` must be a vector of weights sorted in ascending order, with all elements "
                "between 0 and 1, summing up to 1."
            )

        if jnp.any((were_in != 0) * (were_in != 1)):
            raise ValueError("`were_in` must be a vector of 0's and 1's.")
        if were_in.ndim != 1:
            raise ValueError(
                "`were_in` must be a one-dimensional array over the considered times in the time "
                "series."
            )
        # `weights` covers the past steps plus the current one; `were_in` covers only the past steps.
        if len(were_in) != len(weights) - 1:
            raise ValueError(
                "`len(weights)-1` and `len(were_in)` must be the same. "
                f"However, {len(weights) - 1} and {len(were_in)} were found, respectively."
            )
        return weights, were_in

    def update_error(
        self,
        conformal_interval: Array,
        error: float,
        target: Array,
        target_error: float,
        gamma: float = 0.005,
        weights: Optional[Array] = None,
        were_in: Optional[Array] = None,
        return_were_in: bool = False,
    ) -> Union[Array, Tuple[float, Array]]:
        """
        Update the coverage error based on the test target variable belonging or not to the conformal interval.

        The update is `error + gamma * (target_error - miss)`, where `miss = 1 - is_in` is 1 when the target
        falls outside the conformal interval and 0 otherwise. When `weights` and `were_in` are given, `miss` is
        replaced by the `weights`-weighted sum of the misses over the past time steps and the current one.
        The updated error is clipped to `[0, 1]`.

        Parameters
        ----------
        conformal_interval: Array
            A conformal interval for the current test target variable. A leading axis is added before it is
            passed to `is_in`.
        error: float
            The current coverage error to update.
        target: Array
            The observed test target variable. A leading axis is added before it is passed to `is_in`.
        target_error: float
            The target coverage error.
        gamma: float
            The step size for the coverage error update. It must be greater than 0.
        weights: Optional[Array]
            Weights over the considered past time steps and the current one.
            This must be a one-dimensional array of increasing components between 0 and 1, summing up to 1.
        were_in: Optional[Array]
            It indicates whether the target variables of the considered past time steps fell within the respective
            conformal intervals. This must be a one-dimensional array of 1's and 0's. Its length must be the length of
            `weights` minus one, as it refers to all the past time steps but not the current one.
        return_were_in: bool
            If `True`, also return an updated `were_in`, which includes whether the current test target variable
            falls within its conformal interval.

        Returns
        -------
        Union[Array, Tuple[float, Array]]
            The updated coverage error as a scalar array. If `return_were_in` is `True`, a tuple
            `(error, were_in)` is returned instead. In that tuple, `error` is a Python float and `were_in` is a
            one-dimensional array. It holds the past indicators (if any) followed by the current one, so it can be
            passed back as `were_in` on the next call.

        Raises
        ------
        ValueError
            If `gamma <= 0`, if exactly one of `weights` and `were_in` is given, or if they are malformed.
        """
        weights, were_in = self._validate_update_inputs(gamma, weights, were_in)

        # `is_in` is called on a batch of one. Flattening its single result to 1-D lets it be concatenated with
        # `were_in` and returned as a valid `were_in` even when the per-sample result is 0-d.
        is_in = jnp.ravel(self.is_in(target[None], conformal_interval[None])[0])
        if were_in is not None:
            is_in = jnp.concatenate((were_in, is_in))
            error += gamma * (target_error - jnp.dot(weights, 1 - is_in))
        else:
            error += gamma * (target_error - 1 + is_in.squeeze())

        # Keep the error a valid probability. `jnp.clip` returns an array in every case, so the return type
        # does not depend on whether clipping happened.
        error = jnp.clip(error, 0.0, 1.0)
        if return_were_in:
            return float(error), is_in
        return error