Output calibration model
We support a calibration classifier for classification and a calibration regressor for regression. Please find their references below.
- class fortuna.output_calib_model.classification.OutputCalibClassifier(output_calibrator=ClassificationTemperatureScaler(), seed=0)
A calibration classifier class.
- Parameters:
output_calibrator (Optional[nn.Module]) – An output calibrator object. The default is temperature scaling for classification, which rescales the logits with a scalar temperature parameter. Given outputs \(o\), the output calibrator is described by a function \(g(\phi, o)\), where \(\phi\) are the calibration parameters.
seed (int) – A random seed.
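A minimal sketch of what the default calibrator does, assuming the common form \(g(\phi, o) = o / \tau\) with \(\phi = \tau > 0\); the exact parameterization inside ClassificationTemperatureScaler is not stated here and may differ (for example, it may store a log-temperature). The helpers softmax and temperature_scale are illustrative, not Fortuna functions. A temperature above 1 flattens the softmax probabilities without changing the predicted class.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax of one logit vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_scale(logits: List[float], temperature: float) -> List[float]:
    """Hypothetical calibrator g(phi, o) = o / tau, with phi = tau > 0."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [x / temperature for x in logits]


logits = [2.0, 1.0, 0.1]
probs = softmax(logits)
probs_soft = softmax(temperature_scale(logits, 2.0))  # tau > 1 flattens the distribution

# The predicted class (argmax) is the same; only the confidence drops.
print(max(probs), max(probs_soft))
```

Because dividing by a positive constant preserves the ordering of the logits, temperature scaling changes calibration but never accuracy.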
- output_calibrator
See output_calibrator in Parameters.
- Type:
nn.Module
- output_calib_manager
It manages the forward pass of the output calibrator.
- Type:
OutputCalibManager
- prob_output_layer
A probabilistic output layer. It characterizes the distribution of the target variables given the outputs.
- predictive
The predictive distribution.
- calibrate(calib_outputs, calib_targets, val_outputs=None, val_targets=None, loss_fn=focal_loss_fn, config=Config())
Calibrate the model outputs.
- Parameters:
calib_outputs (Array) – Calibration model outputs.
calib_targets (Array) – Calibration target variables.
val_outputs (Optional[Array]) – Validation model outputs.
val_targets (Optional[Array]) – Validation target variables.
loss_fn (Callable[[Outputs, Targets], jnp.ndarray]) – The loss function to use for calibration.
config (Config) – An object to configure the calibration.
- Returns:
A calibration status object. It provides information about the calibration.
- Return type:
Status
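The calibrate step can be pictured as choosing \(\phi\) to minimize loss_fn on the calibration outputs and targets. The sketch below is a simplified stand-in, not Fortuna's procedure: it fits a single temperature \(\tau\) by grid search, whereas Fortuna presumably uses an optimizer configured through config. The focal loss \(-(1 - p_t)^\gamma \log p_t\) is written out with an assumed \(\gamma = 2\); the definition and defaults of Fortuna's focal_loss_fn may differ. All helper names are hypothetical, and when validation data are passed the sketch only reports the validation loss at the chosen temperature.

```python
import math
from typing import List, Optional, Sequence, Tuple

Outputs = Sequence[Sequence[float]]  # one row of logits per example


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def focal_loss(outputs: Outputs, targets: Sequence[int], gamma: float = 2.0) -> float:
    """Mean focal loss -(1 - p_t)^gamma * log(p_t); gamma = 2.0 is an assumption."""
    total = 0.0
    for logits, y in zip(outputs, targets):
        log_p = log_softmax(logits)[y]
        total += -((1.0 - math.exp(log_p)) ** gamma) * log_p
    return total / len(targets)


def scale(outputs: Outputs, temperature: float) -> List[List[float]]:
    """Temperature scaling of every logit row: o / tau."""
    return [[x / temperature for x in row] for row in outputs]


def calibrate_temperature(
    calib_outputs: Outputs,
    calib_targets: Sequence[int],
    val_outputs: Optional[Outputs] = None,
    val_targets: Optional[Sequence[int]] = None,
) -> Tuple[float, Optional[float]]:
    """Hypothetical helper: grid-search tau in [0.25, 10] minimizing the focal loss."""
    grid = [0.25 * k for k in range(1, 41)]  # includes tau = 1.0 (no scaling)
    best_tau = min(grid, key=lambda t: focal_loss(scale(calib_outputs, t), calib_targets))
    val_loss = None
    if val_outputs is not None and val_targets is not None:
        val_loss = focal_loss(scale(val_outputs, best_tau), val_targets)
    return best_tau, val_loss


# Three confident correct predictions and one confident mistake.
outputs = [[4.0, 0.0], [4.0, 0.0], [4.0, 0.0], [0.0, 4.0]]
targets = [0, 0, 1, 1]
best_tau, val_loss = calibrate_temperature(
    outputs, targets, val_outputs=[[3.0, 0.0], [0.0, 3.0]], val_targets=[0, 0]
)
```

On this overconfident toy data the fitted temperature is above 1, so the calibrated probabilities are softer than the raw ones.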
- load_state(checkpoint_path)
Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model.
- Parameters:
checkpoint_path (Path) – Path to a checkpoint file or directory to restore.
- Return type:
None
- save_state(checkpoint_path, keep_top_n_checkpoints=1)
Save the calibration state as a checkpoint.
- Parameters:
checkpoint_path (Path) – Path to file or directory where to save the current state.
keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.
- Return type:
None
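To illustrate what keep_top_n_checkpoints controls, here is a standard-library sketch. The base class wraps flax.training.checkpoints for the real checkpointing, so the JSON files, the checkpoint_<step>.json naming, and the helpers save_checkpoint and load_checkpoint are all hypothetical. Each save writes a new step-numbered file and deletes all but the n most recent ones; loading accepts either a file or a directory, in which case the latest checkpoint is restored.

```python
import json
import re
import tempfile
from pathlib import Path
from typing import Any, Dict, List

# Hypothetical naming scheme: checkpoint_<step>.json
_PATTERN = re.compile(r"^checkpoint_(\d+)\.json$")


def _step_of(path: Path) -> int:
    """Extract the step number from a checkpoint file name."""
    return int(_PATTERN.match(path.name).group(1))


def _checkpoints(directory: Path) -> List[Path]:
    """All checkpoint files in a directory, oldest step first."""
    return sorted((p for p in directory.iterdir() if _PATTERN.match(p.name)), key=_step_of)


def save_checkpoint(
    state: Dict[str, Any], directory: Path, step: int, keep_top_n_checkpoints: int = 1
) -> Path:
    """Write state for this step, then delete all but the n most recent checkpoints."""
    if keep_top_n_checkpoints < 1:
        raise ValueError("keep_top_n_checkpoints must be at least 1")
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / f"checkpoint_{step}.json"
    path.write_text(json.dumps(state))
    for old in _checkpoints(directory)[:-keep_top_n_checkpoints]:
        old.unlink()
    return path


def load_checkpoint(checkpoint_path: Path) -> Dict[str, Any]:
    """Restore from a checkpoint file, or from the latest checkpoint in a directory."""
    if checkpoint_path.is_dir():
        candidates = _checkpoints(checkpoint_path)
        if not candidates:
            raise FileNotFoundError(f"no checkpoint found in {checkpoint_path}")
        checkpoint_path = candidates[-1]
    return json.loads(checkpoint_path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step, temperature in enumerate([1.0, 1.5, 1.8]):
        save_checkpoint({"temperature": temperature}, ckpt_dir, step, keep_top_n_checkpoints=2)
    remaining = sorted(p.name for p in ckpt_dir.iterdir())
    restored = load_checkpoint(ckpt_dir)
```

With keep_top_n_checkpoints=2, the checkpoint from step 0 is pruned after the third save, and loading the directory restores the state from step 2.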
- class fortuna.output_calib_model.regression.OutputCalibRegressor(output_calibrator=RegressionTemperatureScaler(), seed=0)
A calibration regressor class.
- Parameters:
output_calibrator (Optional[nn.Module]) – An output calibrator object. The default is temperature scaling for regression, which inflates the variance of the likelihood with a scalar temperature parameter. Given outputs \(o\) of the model manager, the output calibrator is described by a function \(g(\phi, o)\), where \(\phi\) are the calibration parameters.
seed (int) – A random seed.
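The variance inflation can be sketched for a single Gaussian prediction, assuming the likelihood is \(\mathcal{N}(y; \mu, \sigma^2)\) and the calibrator keeps \(\mu\) while mapping \(\sigma^2 \mapsto \tau\sigma^2\) with \(\tau > 0\). This parameterization is an assumption, not necessarily how RegressionTemperatureScaler stores its parameter, and the helpers are hypothetical. Because the standard deviation scales with \(\sqrt{\tau}\), a temperature of 4 doubles the width of every central predictive interval.

```python
import math
from statistics import NormalDist
from typing import Tuple


def scale_variance(mean: float, variance: float, temperature: float) -> Tuple[float, float]:
    """Hypothetical calibrator g(phi, o): keep the mean, multiply the variance by tau > 0."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return mean, temperature * variance


def central_interval(mean: float, variance: float, level: float = 0.95) -> Tuple[float, float]:
    """Central predictive interval of the Gaussian likelihood N(mean, variance)."""
    dist = NormalDist(mu=mean, sigma=math.sqrt(variance))
    tail = (1.0 - level) / 2.0
    return dist.inv_cdf(tail), dist.inv_cdf(1.0 - tail)


mean, variance = 0.0, 1.0
low, high = central_interval(mean, variance)
cal_low, cal_high = central_interval(*scale_variance(mean, variance, temperature=4.0))
# Multiplying the variance by 4 doubles the standard deviation, hence the interval width.
```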
- output_calibrator
See output_calibrator in Parameters.
- Type:
nn.Module
- output_calib_manager
It manages the forward pass of the output calibrator.
- Type:
OutputCalibManager
- prob_output_layer
A probabilistic output layer. It characterizes the distribution of the target variables given the outputs.
- predictive
The predictive distribution.
- calibrate(calib_outputs, calib_targets, val_outputs=None, val_targets=None, loss_fn=scaled_mse_fn, config=Config())
Calibrate the model outputs.
- Parameters:
calib_outputs (Array) – Calibration model outputs.
calib_targets (Array) – Calibration target variables.
val_outputs (Optional[Array]) – Validation model outputs.
val_targets (Optional[Array]) – Validation target variables.
loss_fn (Callable[[Outputs, Targets], jnp.ndarray]) – The loss function to use for calibration.
config (Config) – An object to configure the calibration.
- Returns:
A calibration status object. It provides information about the calibration.
- Return type:
Status
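Under a Gaussian variance-scaling assumption (\(\mu\) kept, \(\sigma^2 \mapsto \tau\sigma^2\)), fitting \(\tau\) has a closed form if the negative log-likelihood is used as the calibration loss. This swaps in the Gaussian NLL for Fortuna's default scaled_mse_fn, whose exact definition is not given here. With standardized residuals \(r_i = (y_i - \mu_i)^2 / \sigma_i^2\), the mean NLL is \(\frac{1}{n}\sum_i \tfrac{1}{2}\left[\log(2\pi\tau\sigma_i^2) + r_i/\tau\right]\); setting its derivative \(\frac{1}{n}\sum_i \tfrac{1}{2}\left[1/\tau - r_i/\tau^2\right]\) to zero gives \(\tau^\star = \frac{1}{n}\sum_i r_i\). The helper names below are hypothetical.

```python
import math
from typing import Sequence


def gaussian_nll(
    means: Sequence[float],
    variances: Sequence[float],
    targets: Sequence[float],
    temperature: float = 1.0,
) -> float:
    """Mean Gaussian negative log-likelihood with every variance multiplied by tau."""
    total = 0.0
    for mu, var, y in zip(means, variances, targets):
        v = temperature * var
        total += 0.5 * (math.log(2.0 * math.pi * v) + (y - mu) ** 2 / v)
    return total / len(targets)


def fit_variance_temperature(
    means: Sequence[float],
    variances: Sequence[float],
    targets: Sequence[float],
    min_temperature: float = 1e-12,
) -> float:
    """Closed-form minimizer of gaussian_nll over tau: the mean squared standardized residual.

    The result is floored at min_temperature so the scaled variance stays positive
    even if every residual is zero.
    """
    residuals = [(y - mu) ** 2 / var for mu, var, y in zip(means, variances, targets)]
    return max(sum(residuals) / len(residuals), min_temperature)


means, variances = [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]
targets = [2.0, -2.0, 1.0, -1.0]  # errors larger than the predicted variance suggests
tau = fit_variance_temperature(means, variances, targets)
```

A fitted \(\tau > 1\) means the model was overconfident and its predictive variances are inflated accordingly.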
- load_state(checkpoint_path)
Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model.
- Parameters:
checkpoint_path (Path) – Path to a checkpoint file or directory to restore.
- Return type:
None
- save_state(checkpoint_path, keep_top_n_checkpoints=1)
Save the calibration state as a checkpoint.
- Parameters:
checkpoint_path (Path) – Path to file or directory where to save the current state.
keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.
- Return type:
None
- class fortuna.output_calib_model.base.OutputCalibModel(seed=0)
Abstract calibration model class.
It obtains its checkpointing capabilities from a mixin class for all trainers that need them; the mixin is a wrapper around functions in flax.training.checkpoints.
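The structure described above, an abstract base class that concrete calibration models subclass and that gains checkpointing from a mixin, can be sketched with Python's abc module. All class names below are hypothetical, and pickle stands in for the flax.training.checkpoints functions that the real mixin wraps.

```python
import abc
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


class CheckpointingMixin:
    """Hypothetical mixin: saves and restores self._state (pickle stands in for flax)."""

    _state: Optional[Any] = None

    def save_state(self, checkpoint_path: Path) -> None:
        with open(checkpoint_path, "wb") as f:
            pickle.dump(self._state, f)

    def load_state(self, checkpoint_path: Path) -> None:
        with open(checkpoint_path, "rb") as f:
            self._state = pickle.load(f)


class CalibModelBase(CheckpointingMixin, abc.ABC):
    """Hypothetical abstract calibration model: subclasses must implement calibrate."""

    def __init__(self, seed: int = 0) -> None:
        self.seed = seed

    @abc.abstractmethod
    def calibrate(self, calib_outputs: Any, calib_targets: Any) -> Any:
        ...


class ConstantTemperatureModel(CalibModelBase):
    """Toy subclass whose 'calibration' just stores a fixed temperature."""

    def calibrate(self, calib_outputs: Any, calib_targets: Any) -> Any:
        self._state = {"temperature": 1.5, "seed": self.seed}
        return self._state


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "state.pkl"
    trained = ConstantTemperatureModel(seed=0)
    trained.calibrate(calib_outputs=[], calib_targets=[])
    trained.save_state(path)
    restored = ConstantTemperatureModel(seed=1)
    restored.load_state(path)  # restored now holds the state saved by `trained`
```

Keeping checkpointing in a mixin lets the same save and load logic serve several model and trainer classes, while the abstract base only fixes the interface that subclasses must implement.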