# Source code for fortuna.conformal.regression.cvplus
from typing import List
import jax.numpy as jnp
from fortuna.conformal.regression.base import ConformalRegressor
from fortuna.typing import Array
class CVPlusConformalRegressor(ConformalRegressor):
    """
    This class implements the CV+ method introduced in
    `Barber et al., 2021 <https://www.stat.cmu.edu/~ryantibs/papers/jackknife.pdf>`__. It is an extension of the
    jackknife+ method, introduced in the same work, that considers a K-fold instead of a leave-one-out strategy. If
    :code:`K=n`, where :code:`n` is the total number of training data, then CV+ reduces to jackknife+.
    """

    @staticmethod
    def _as_column(x: Array, name: str, what: str, i: int) -> jnp.ndarray:
        """
        Return ``x`` as a column array of shape ``(n, 1)``.

        Parameters
        ----------
        x: Array
            An array of shape ``(n,)`` or ``(n, 1)``.
        name: str
            Name of the list argument ``x`` was taken from; only used in error messages.
        what: str
            Human-readable description of the content of ``x``; only used in error messages.
        i: int
            Index of ``x`` within its list; only used in error messages.

        Returns
        -------
        jnp.ndarray
            ``x`` reshaped to ``(n, 1)``. A new array is returned; ``x`` itself is not modified.

        Raises
        ------
        ValueError
            If ``x`` is neither one-dimensional nor two-dimensional with a second dimension of size 1.
        """
        x = jnp.asarray(x)
        if x.ndim == 1:
            return x[:, None]
        if x.ndim != 2 or x.shape[1] != 1:
            raise ValueError(
                f"This method is supported only for scalar {what}. However, the element at index {i} of "
                f"`{name}` has shape {x.shape}, while a shape `(n,)` or `(n, 1)` is expected."
            )
        return x

    def conformal_interval(
        self,
        cross_val_outputs: List[Array],
        cross_val_targets: List[Array],
        cross_test_outputs: List[Array],
        error: float,
    ) -> jnp.ndarray:
        """
        Coverage interval of each of the test inputs, at the desired coverage error. This is supported only for
        one-dimensional target variables. The input lists are not modified.

        Parameters
        ----------
        cross_val_outputs: List[Array]
            Outputs of the models used during cross validation, evaluated at their respective validation inputs.
            More precisely, we assume the training data has been jointly partitioned into :code:`K` subsets. The
            i-th element of :code:`cross_val_outputs` contains the outputs of a model trained on all data but the
            i-th partition, evaluated at the inputs of the i-th partition itself, for :code:`i=1, 2, ..., K`. Each
            element must have shape :code:`(n_i,)` or :code:`(n_i, 1)`.
        cross_val_targets: List[Array]
            Target variables organized in the same partitions used for :code:`cross_val_outputs`. More precisely,
            the i-th element of :code:`cross_val_targets` includes the array of target variables of the i-th
            partition of the training data, for :code:`i=1, 2, ..., K`. Each element must have the same first
            dimension as the corresponding element of :code:`cross_val_outputs`.
        cross_test_outputs: List[Array]
            Outputs of the models used during cross validation, evaluated at the test inputs. More precisely,
            consider the same partition of data as the one used for :code:`cross_val_outputs`. Then the i-th
            element of :code:`cross_test_outputs` represents the outputs of the model that has been trained on all
            the training data but the i-th partition, evaluated at the test inputs, for :code:`i=1, 2, ..., K`.
            All elements must have shape :code:`(n_test,)` or :code:`(n_test, 1)`, with the same :code:`n_test`.
        error: float
            The desired coverage error. This must be a scalar between 0 and 1, extremes included.

        Returns
        -------
        jnp.ndarray
            The conformal intervals, of shape :code:`(n_test, 2)`. The two components of the second axis
            correspond to the left and right interval bounds.

        Raises
        ------
        TypeError
            If :code:`cross_val_outputs`, :code:`cross_val_targets` or :code:`cross_test_outputs` is not a list.
        ValueError
            If the lists are empty or of different lengths, if an element has an unsupported shape, if
            corresponding validation outputs and targets have different sizes, if the test outputs of different
            folds have different sizes, or if :code:`error` lies outside :math:`[0, 1]`.
        """
        for name, value in (
            ("cross_val_outputs", cross_val_outputs),
            ("cross_val_targets", cross_val_targets),
            ("cross_test_outputs", cross_test_outputs),
        ):
            if not isinstance(value, list):
                raise TypeError(f"`{name}` must be a list of arrays.")

        n_folds = len(cross_val_outputs)
        if n_folds == 0:
            raise ValueError(
                "`cross_val_outputs`, `cross_val_targets` and `cross_test_outputs` must not be empty."
            )
        # `zip` would silently drop folds if the lists had different lengths.
        if len(cross_val_targets) != n_folds or len(cross_test_outputs) != n_folds:
            raise ValueError(
                "`cross_val_outputs`, `cross_val_targets` and `cross_test_outputs` must have the same length, "
                f"i.e. one element per cross-validation fold. Got lengths {len(cross_val_outputs)}, "
                f"{len(cross_val_targets)} and {len(cross_test_outputs)}."
            )
        # Only concrete Python scalars are range-checked, so traced values are not forced to a Python bool.
        if isinstance(error, (int, float)) and not 0 <= error <= 1:
            raise ValueError(
                f"`error` must be a scalar between 0 and 1, extremes included. Got {error}."
            )

        # Normalise everything to column arrays in fresh lists, leaving the caller's lists untouched.
        val_outputs, val_targets, test_outputs = [], [], []
        for i, (mu, y, mu_test) in enumerate(
            zip(cross_val_outputs, cross_val_targets, cross_test_outputs)
        ):
            mu = self._as_column(mu, "cross_val_outputs", "model outputs", i)
            y = self._as_column(y, "cross_val_targets", "target variables", i)
            mu_test = self._as_column(mu_test, "cross_test_outputs", "model outputs", i)
            if mu.shape[0] != y.shape[0]:
                raise ValueError(
                    f"The first dimension of the element at index {i} in `cross_val_outputs` must be the same "
                    f"as the one of the element at index {i} in `cross_val_targets`. "
                    f"Got {mu.shape[0]} and {y.shape[0]}."
                )
            if test_outputs and mu_test.shape[0] != test_outputs[0].shape[0]:
                raise ValueError(
                    "All elements of `cross_test_outputs` must have the same first dimension, i.e. the number "
                    f"of test inputs. The element at index {i} has {mu_test.shape[0]}, while the element at "
                    f"index 0 has {test_outputs[0].shape[0]}."
                )
            val_outputs.append(mu)
            val_targets.append(y)
            test_outputs.append(mu_test)

        # Absolute residuals of each fold's model on its own held-out partition, shape (n_i, 1) per fold.
        residuals = [jnp.abs(y - mu) for y, mu in zip(val_targets, val_outputs)]
        # Broadcast each fold's test outputs (1, n_test, 1) against that fold's residuals (n_i, 1, 1), then
        # stack the folds along axis 0, giving (n, n_test, 1) with n the total number of training points.
        left = jnp.concatenate(
            [mu_test[None] - r[:, None] for mu_test, r in zip(test_outputs, residuals)], axis=0
        )
        right = jnp.concatenate(
            [mu_test[None] + r[:, None] for mu_test, r in zip(test_outputs, residuals)], axis=0
        )
        # NOTE(review): `jnp.quantile` interpolates linearly by default instead of taking the exact
        # floor(error * (n + 1))-th / ceil((1 - error) * (n + 1))-th order statistics used in Barber et al.;
        # kept as-is to preserve existing outputs — confirm this is intended.
        qleft = jnp.quantile(left, q=error, axis=0)  # (n_test, 1)
        qright = jnp.quantile(right, q=1 - error, axis=0)  # (n_test, 1)
        # Vectorised assembly of (n_test, 2); also works when there are no test inputs.
        return jnp.concatenate((qleft, qright), axis=1)