Source code for fortuna.output_calib_model.predictive.regression

from typing import (
    List,
    Optional,
    Union,
)

from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp

from fortuna.output_calib_model.predictive.base import Predictive
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.prob_output_layer.regression import RegressionProbOutputLayer
from fortuna.typing import Array


class RegressionPredictive(Predictive):
    """
    Predictive object for regression backed by an output calibrator.

    Model outputs are optionally passed through the output calibrator and then
    handed to the regression probabilistic output layer, which estimates
    quantiles and credible intervals of the target variable.
    """

    def __init__(
        self,
        output_calib_manager: OutputCalibManager,
        prob_output_layer: RegressionProbOutputLayer,
    ):
        """
        Store the calibration manager and the probabilistic output layer.

        Parameters
        ----------
        output_calib_manager: OutputCalibManager
            Manager used to apply the output calibrator to model outputs.
        prob_output_layer: RegressionProbOutputLayer
            Probabilistic output layer that quantile and credible-interval
            estimation is delegated to.
        """
        super().__init__(
            output_calib_manager=output_calib_manager,
            prob_output_layer=prob_output_layer,
        )

    def quantile(
        self,
        q: Union[float, Array, List],
        outputs: Array,
        n_samples: Optional[int] = 30,
        rng: Optional[PRNGKeyArray] = None,
        calibrated: bool = True,
    ) -> jnp.ndarray:
        """
        Estimate quantiles of the target variable given the outputs, with
        respect to the predictive distribution.

        Parameters
        ----------
        q: Union[float, Array, List]
            Quantile(s) to estimate.
        outputs: Array
            Model outputs.
        n_samples: Optional[int]
            Number of target samples to draw when computing quantiles.
        rng: Optional[PRNGKeyArray]
            A random number generator.
        calibrated: bool
            Whether the outputs should be calibrated before estimating the
            quantiles. If `calibrated` is set to True, the model must have
            been calibrated beforehand.

        Returns
        -------
        jnp.ndarray
            The estimated quantiles for each output.
        """
        # Uncalibrated path: hand the raw outputs straight to the output layer.
        if not calibrated:
            return self.prob_output_layer.quantile(q, outputs, n_samples, rng)

        # Calibrated path: make sure a calibration exists, then transform the
        # outputs with the stored calibrator state before delegating.
        self._check_calibrated()
        calib_state = self.state.get()
        calibrated_outputs = self.output_calib_manager.apply(
            params=calib_state.params["output_calibrator"],
            outputs=outputs,
            mutable=calib_state.mutable["output_calibrator"],
        )
        return self.prob_output_layer.quantile(
            q, calibrated_outputs, n_samples, rng
        )

    def credible_interval(
        self,
        outputs: Array,
        n_samples: int = 30,
        error: float = 0.05,
        interval_type: str = "two-tailed",
        rng: Optional[PRNGKeyArray] = None,
        calibrated: bool = True,
    ) -> jnp.ndarray:
        """
        Estimate a credible interval of the target variable given the outputs,
        with respect to the predictive distribution.

        Parameters
        ----------
        outputs: Array
            Model outputs.
        n_samples: int
            Number of target samples to draw for each output.
        error: float
            The interval error. This must be a number between 0 and 1,
            extremes included. For example, `error=0.05` corresponds to a 95%
            level of credibility.
        interval_type: str
            The interval type. We support "two-tailed" (default),
            "right-tailed" and "left-tailed".
        rng: Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        calibrated: bool
            Whether the outputs should be calibrated before estimating the
            interval. If `calibrated` is set to True, the model must have been
            calibrated beforehand.

        Returns
        -------
        jnp.ndarray
            The estimated credible interval for each output.
        """
        # Uncalibrated path: hand the raw outputs straight to the output layer.
        if not calibrated:
            return self.prob_output_layer.credible_interval(
                outputs, n_samples, error, interval_type, rng
            )

        # Calibrated path: make sure a calibration exists, then transform the
        # outputs with the stored calibrator state before delegating.
        self._check_calibrated()
        calib_state = self.state.get()
        calibrated_outputs = self.output_calib_manager.apply(
            params=calib_state.params["output_calibrator"],
            outputs=outputs,
            mutable=calib_state.mutable["output_calibrator"],
        )
        return self.prob_output_layer.credible_interval(
            calibrated_outputs, n_samples, error, interval_type, rng
        )