Source code for fortuna.output_calib_model.predictive.classification
from typing import Optional
import jax.numpy as jnp
from fortuna.output_calib_model.predictive.base import Predictive
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.prob_output_layer.classification import ClassificationProbOutputLayer
from fortuna.typing import Array
[docs]class ClassificationPredictive(Predictive):
def __init__(
self,
output_calib_manager: OutputCalibManager,
prob_output_layer: ClassificationProbOutputLayer,
):
super().__init__(
output_calib_manager=output_calib_manager,
prob_output_layer=prob_output_layer,
)
[docs] def mean(self, outputs: Array, calibrated: bool = True, **kwargs) -> jnp.ndarray:
"""
Estimate the mean of the one-hot encoded target variable given the output, with respect to the predictive
distribution.
Parameters
----------
outputs : Array
Model outputs.
calibrated : bool
Whether the outputs should be calibrated when computing this method. If `calibrated` is set to True, the
model must have been calibrated beforehand.
Returns
-------
jnp.ndarray
The estimated mean for each output.
"""
return super().mean(outputs, calibrated, **kwargs)
[docs] def mode(self, outputs: Array, calibrated: bool = True, **kwargs) -> jnp.ndarray:
"""
Estimate the mode of the one-hot encoded target variable given the output, with respect to the predictive
distribution.
Parameters
----------
outputs : Array
Model outputs.
calibrated : bool
Whether the outputs should be calibrated when computing this method. If `calibrated` is set to True, the
model must have been calibrated beforehand.
Returns
-------
jnp.ndarray
The estimated mode for each output.
"""
return super().mode(outputs, calibrated, **kwargs)
[docs] def variance(
self, outputs: jnp.ndarray, calibrated: bool = True, **kwargs
) -> jnp.ndarray:
"""
Estimate the variance of the one-hot encoded target variable given the output, with respect to the predictive
distribution.
Parameters
----------
outputs : jnp.ndarray
Model outputs.
calibrated : bool
Whether the outputs should be calibrated when computing this method. If `calibrated` is set to True, the
model must have been calibrated beforehand.
Returns
-------
jnp.ndarray
The estimated variance for each output.
"""
return super().variance(outputs, calibrated, **kwargs)
[docs] def std(
self,
outputs: jnp.ndarray,
variances: Optional[jnp.ndarray] = None,
calibrated: bool = True,
) -> jnp.ndarray:
"""
Estimate the standard deviation of the one-hot encoded target variable given the output, with respect to the
predictive distribution.
Parameters
----------
outputs : jnp.ndarray
Model outputs.
variances: Optional[jnp.ndarray]
Variance for each output.
calibrated : bool
Whether the outputs should be calibrated when computing this method. If `calibrated` is set to True, the
model must have been calibrated beforehand.
Returns
-------
jnp.ndarray
The estimated standard deviation for each output.
"""
return super().std(outputs, variances, calibrated)