Source code for fortuna.output_calib_model.regression
from typing import (
Callable,
Optional,
)
import flax.linen as nn
import jax.numpy as jnp
from fortuna.loss.regression.scaled_mse import scaled_mse_fn
from fortuna.output_calib_model.base import OutputCalibModel
from fortuna.output_calib_model.config.base import Config
from fortuna.output_calib_model.predictive.regression import RegressionPredictive
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.output_calibrator.regression import RegressionTemperatureScaler
from fortuna.prob_output_layer.regression import RegressionProbOutputLayer
from fortuna.typing import (
Array,
Outputs,
Status,
Targets,
)
[docs]class OutputCalibRegressor(OutputCalibModel):
def __init__(
self,
output_calibrator: Optional[nn.Module] = RegressionTemperatureScaler(),
seed: int = 0,
) -> None:
r"""
A calibration regressor class.
Parameters
----------
output_calibrator : Optional[nn.Module]
An output calibrator object. The default is temperature scaling for regression, which inflates the variance
of the likelihood with a scalar temperature parameter. Given outputs :math:`o` of the model manager, the
output calibrator is described by a function :math:`g(\phi, o)`, where `phi` are
calibration parameters.
seed: int
A random seed.
Attributes
----------
output_calibrator : nn.Module
See `output_calibrator` in `Parameters`.
output_calib_manager : OutputCalibManager
It manages the forward pass of the output calibrator.
prob_output_layer : RegressionProbOutputLayer
A probabilistic output payer.
It characterizes the distribution of the target variables given the outputs.
predictive : RegressionPredictive
The predictive distribution.
"""
self.output_calibrator = output_calibrator
self.output_calib_manager = OutputCalibManager(
output_calibrator=output_calibrator
)
self.prob_output_layer = RegressionProbOutputLayer()
self.predictive = RegressionPredictive(
output_calib_manager=self.output_calib_manager,
prob_output_layer=self.prob_output_layer,
)
super().__init__(seed=seed)
[docs] def calibrate(
self,
calib_outputs: Array,
calib_targets: Array,
val_outputs: Optional[Array] = None,
val_targets: Optional[Array] = None,
loss_fn: Callable[[Outputs, Targets], jnp.ndarray] = scaled_mse_fn,
config: Config = Config(),
) -> Status:
"""
Calibrate the model outputs.
Parameters
----------
calib_outputs: Array
Calibration model outputs.
calib_targets: Array
Calibration target variables.
val_outputs: Optional[Array]
Validation model outputs.
val_targets: Optional[Array]
Validation target variables.
loss_fn: Callable[[Outputs, Targets], jnp.ndarray]
The loss function to use for calibration.
config : Config
An object to configure the calibration.
Returns
-------
Status
A calibration status object. It provides information about the calibration.
"""
self._check_output_dim(calib_outputs, calib_targets)
if val_outputs is not None:
self._check_output_dim(val_outputs, val_targets)
return super()._calibrate(
uncertainty_fn=(
config.monitor.uncertainty_fn
if config.monitor.uncertainty_fn is not None
else self.prob_output_layer.variance
),
calib_outputs=calib_outputs,
calib_targets=calib_targets,
val_outputs=val_outputs,
val_targets=val_targets,
loss_fn=loss_fn,
config=config,
)
@staticmethod
def _check_output_dim(outputs: jnp.ndarray, targets: jnp.array):
if outputs.shape[1] != 2 * targets.shape[1]:
raise ValueError(
f"""`outputs.shape[1]` must be twice the dimension of the target variables in `targets`, with
first and second halves corresponding to the mean and log-variance of the likelihood, respectively.
However, `outputs.shape[1]={outputs.shape[1]}` and `targets.shape[1]={targets.shape[1]}`."""
)