Output calibration model

We support a calibration classifier for classification and a calibration regressor for regression. Please find their references below.

class fortuna.output_calib_model.classification.OutputCalibClassifier(output_calibrator=ClassificationTemperatureScaler(), seed=0)

A calibration classifier class.

Parameters:
  • output_calibrator (Optional[nn.Module]) – An output calibrator object. The default is temperature scaling for classification, which rescales the logits with a scalar temperature parameter. Given outputs \(o\), the output calibrator is described by a function \(g(\phi, o)\), where \(\phi\) denotes the calibration parameters.

  • seed (int) – A random seed.
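Temperature scaling for classification can be sketched in a few lines. This is a minimal sketch, assuming the calibrator computes \(g(\phi, o) = o / T\) with \(\phi = T > 0\) and that class probabilities come from a softmax over the rescaled logits; the library may parametrize the temperature differently (for example on a log scale) and works on JAX arrays. The helpers softmax and temperature_scale are hypothetical, written in plain Python for illustration.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single vector of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_scale(logits: Sequence[float], temperature: float) -> List[float]:
    """Assumed calibrator g(phi, o) = o / T, with phi = T > 0."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [z / temperature for z in logits]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax(temperature_scale(logits, t))
    print(t, [round(p, 3) for p in probs])
```

A temperature above 1 flattens the predicted distribution and one below 1 sharpens it, while the argmax, and hence the predicted class, never changes. Temperature scaling therefore adjusts confidence without affecting accuracy.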

output_calibrator

See output_calibrator in Parameters.

Type:

nn.Module

output_calib_manager

It manages the forward pass of the output calibrator.

Type:

OutputCalibManager

prob_output_layer

A probabilistic output layer. It characterizes the distribution of the target variables given the outputs.

Type:

ClassificationProbOutputLayer

predictive

The predictive distribution.

Type:

ClassificationPredictive

calibrate(calib_outputs, calib_targets, val_outputs=None, val_targets=None, loss_fn=focal_loss_fn, config=Config())

Calibrate the model outputs.

Parameters:
  • calib_outputs (Array) – Calibration model outputs.

  • calib_targets (Array) – Calibration target variables.

  • val_outputs (Optional[Array]) – Validation model outputs.

  • val_targets (Optional[Array]) – Validation target variables.

  • loss_fn (Callable[[Outputs, Targets], jnp.ndarray]) – The loss function to use for calibration.

  • config (Config) – An object to configure the calibration.

Returns:

A calibration status object. It provides information about the calibration.

Return type:

Status
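The default loss_fn for the classifier is focal_loss_fn. As a hedged illustration, the standard focal loss is \(-(1 - p_y)^\gamma \log p_y\), where \(p_y\) is the predicted probability of the true class; with \(\gamma = 0\) it reduces to the cross-entropy. The sketch below follows that formula in plain Python and assumes \(\gamma = 2\) as a default. The library's focal_loss_fn operates on JAX arrays, and its exact \(\gamma\) and reduction may differ; log_softmax and focal_loss are hypothetical names.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a single vector of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def focal_loss(
    batch_logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    gamma: float = 2.0,  # assumed default, not taken from the library
) -> float:
    """Mean focal loss -(1 - p_y)^gamma * log p_y over a batch."""
    total = 0.0
    for logits, y in zip(batch_logits, targets):
        log_p_y = log_softmax(logits)[y]
        p_y = math.exp(log_p_y)
        # The (1 - p_y)^gamma factor down-weights confident, correct predictions.
        total += -((1.0 - p_y) ** gamma) * log_p_y
    return total / len(targets)


outputs = [[2.0, 0.5, -1.0], [0.2, 0.1, 0.3]]  # logits for two examples
targets = [0, 2]
print(focal_loss(outputs, targets, gamma=0.0))  # plain cross-entropy
print(focal_loss(outputs, targets, gamma=2.0))  # focal loss
```

Any callable with the same (outputs, targets) contract can be passed as loss_fn instead of the default.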

load_state(checkpoint_path)

Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model.

Parameters:

checkpoint_path (Path) – Path to a checkpoint file or directory to restore.

Return type:

None

save_state(checkpoint_path, keep_top_n_checkpoints=1)

Save the calibration state as a checkpoint.

Parameters:
  • checkpoint_path (Path) – Path to the file or directory where the current state will be saved.

  • keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.

Return type:

None
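The keep_top_n_checkpoints parameter bounds how many past checkpoint files survive a save. As a hedged sketch, assuming "top" means the most recent steps (the convention of the keep argument in flax.training.checkpoints, which the base class wraps), the pruning can be simulated with plain files. The helper save_with_pruning and the checkpoint_<step> file naming are hypothetical and stand in for the real checkpointing code.

```python
import os
import tempfile
from typing import List


def save_with_pruning(ckpt_dir: str, step: int, payload: bytes, keep_top_n: int = 1) -> List[str]:
    """Write checkpoint_<step>, then delete all but the keep_top_n newest checkpoints.

    Hypothetical helper: "newest" means highest step number, mirroring the
    convention assumed for keep_top_n_checkpoints.
    """
    if keep_top_n < 1:
        raise ValueError("keep_top_n must be at least 1")
    with open(os.path.join(ckpt_dir, f"checkpoint_{step}"), "wb") as f:
        f.write(payload)
    # List existing checkpoints and sort them by step number, newest first.
    names = [n for n in os.listdir(ckpt_dir) if n.startswith("checkpoint_")]
    names.sort(key=lambda n: int(n[len("checkpoint_"):]), reverse=True)
    # Delete everything beyond the first keep_top_n entries.
    for stale in names[keep_top_n:]:
        os.remove(os.path.join(ckpt_dir, stale))
    return names[:keep_top_n]


with tempfile.TemporaryDirectory() as tmp_dir:
    for step in range(5):
        kept = save_with_pruning(tmp_dir, step, b"calibration state", keep_top_n=2)
print(kept)  # ['checkpoint_4', 'checkpoint_3']
```

With the default keep_top_n_checkpoints=1, only the latest checkpoint would survive each save under this assumption.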

class fortuna.output_calib_model.regression.OutputCalibRegressor(output_calibrator=RegressionTemperatureScaler(), seed=0)

A calibration regressor class.

Parameters:
  • output_calibrator (Optional[nn.Module]) – An output calibrator object. The default is temperature scaling for regression, which inflates the variance of the likelihood with a scalar temperature parameter. Given outputs \(o\) of the model manager, the output calibrator is described by a function \(g(\phi, o)\), where \(\phi\) denotes the calibration parameters.

  • seed (int) – A random seed.
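Variance inflation by a scalar temperature has a convenient closed form under a Gaussian likelihood. Assume each output is a pair \((\mu_i, \log\sigma_i^2)\) and that the calibrated variance is \(T\sigma_i^2\). Minimizing the mean negative log-likelihood \(\frac{1}{2n}\sum_i [\log(2\pi T\sigma_i^2) + (y_i-\mu_i)^2/(T\sigma_i^2)]\) over \(T\) and setting the derivative to zero gives \(T^\star = \frac{1}{n}\sum_i (y_i-\mu_i)^2/\sigma_i^2\), the mean squared standardized residual. The sketch below is illustrative only: the output layout is an assumption, the default calibration loss is scaled_mse_fn rather than this NLL, and gaussian_nll, scaled_nll and optimal_temperature are hypothetical names.

```python
import math
from typing import Sequence, Tuple

Output = Tuple[float, float]  # assumed layout: (mean, log-variance)


def gaussian_nll(mean: float, var: float, y: float) -> float:
    """Negative log-density of y under N(mean, var)."""
    return 0.5 * (math.log(2.0 * math.pi * var) + (y - mean) ** 2 / var)


def scaled_nll(outputs: Sequence[Output], targets: Sequence[float], temperature: float) -> float:
    """Mean NLL after inflating each predicted variance to temperature * variance."""
    return sum(
        gaussian_nll(m, temperature * math.exp(lv), y) for (m, lv), y in zip(outputs, targets)
    ) / len(targets)


def optimal_temperature(outputs: Sequence[Output], targets: Sequence[float]) -> float:
    """Closed-form minimizer of scaled_nll: the mean squared standardized residual."""
    return sum((y - m) ** 2 / math.exp(lv) for (m, lv), y in zip(outputs, targets)) / len(targets)


# An overconfident model: unit predicted variance, but residuals of size 2.
outputs = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
targets = [2.0, -1.0, 4.0]
t_star = optimal_temperature(outputs, targets)
print(t_star)  # 4.0
print(scaled_nll(outputs, targets, 1.0), scaled_nll(outputs, targets, t_star))
```

Here \(T^\star = 4\) quadruples the predicted variances, doubling the predicted standard deviations to match the observed spread of the residuals.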

output_calibrator

See output_calibrator in Parameters.

Type:

nn.Module

output_calib_manager

It manages the forward pass of the output calibrator.

Type:

OutputCalibManager

prob_output_layer

A probabilistic output layer. It characterizes the distribution of the target variables given the outputs.

Type:

RegressionProbOutputLayer

predictive

The predictive distribution.

Type:

RegressionPredictive

calibrate(calib_outputs, calib_targets, val_outputs=None, val_targets=None, loss_fn=scaled_mse_fn, config=Config())

Calibrate the model outputs.

Parameters:
  • calib_outputs (Array) – Calibration model outputs.

  • calib_targets (Array) – Calibration target variables.

  • val_outputs (Optional[Array]) – Validation model outputs.

  • val_targets (Optional[Array]) – Validation target variables.

  • loss_fn (Callable[[Outputs, Targets], jnp.ndarray]) – The loss function to use for calibration.

  • config (Config) – An object to configure the calibration.

Returns:

A calibration status object. It provides information about the calibration.

Return type:

Status
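The optional val_outputs and val_targets hold out data on which a calibration can be checked. One common regression diagnostic, added here for illustration and not claimed to be what calibrate reports, is the empirical coverage of central Gaussian prediction intervals: a well-calibrated model's 95% intervals should contain roughly 95% of the validation targets. The sketch assumes each output is a (mean, log-variance) pair and that a temperature T multiplies the predicted variance; interval_coverage is a hypothetical helper.

```python
import math
from statistics import NormalDist
from typing import Sequence, Tuple


def interval_coverage(
    outputs: Sequence[Tuple[float, float]],  # assumed layout: (mean, log-variance)
    targets: Sequence[float],
    temperature: float = 1.0,
    level: float = 0.95,
) -> float:
    """Fraction of targets inside the central `level` interval of N(mean, T * var)."""
    z = NormalDist().inv_cdf(0.5 + level / 2)  # about 1.96 for level = 0.95
    hits = 0
    for (mean, log_var), y in zip(outputs, targets):
        half_width = z * math.sqrt(temperature * math.exp(log_var))
        if abs(y - mean) <= half_width:
            hits += 1
    return hits / len(targets)


val_outputs = [(0.0, 0.0)] * 4  # unit predicted variance everywhere
val_targets = [0.5, -1.0, 2.5, -3.0]
print(interval_coverage(val_outputs, val_targets, temperature=1.0))  # 0.5
print(interval_coverage(val_outputs, val_targets, temperature=4.0))  # 1.0
```

Coverage well below the nominal level signals overconfidence, and a larger temperature widens the intervals to compensate.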

load_state(checkpoint_path)

Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model.

Parameters:

checkpoint_path (Path) – Path to a checkpoint file or directory to restore.

Return type:

None

save_state(checkpoint_path, keep_top_n_checkpoints=1)

Save the calibration state as a checkpoint.

Parameters:
  • checkpoint_path (Path) – Path to the file or directory where the current state will be saved.

  • keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.

Return type:

None

class fortuna.output_calib_model.base.OutputCalibModel(seed=0)

Abstract calibration model class.

It inherits its checkpointing capabilities from a mixin class, shared by all trainers that need checkpointing, which wraps functions in flax.training.checkpoints.
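The base class combines an abstract interface with checkpointing inherited from a mixin. A minimal sketch of that structure follows; the class names CheckpointingMixinSketch, CalibModelSketch and ToyCalibModel are hypothetical, and pickle stands in for flax.training.checkpoints so the example runs with the standard library alone.

```python
import abc
import os
import pickle
import tempfile
from typing import Any, Optional, Sequence


class CheckpointingMixinSketch:
    """Hypothetical stand-in for the checkpointing mixin (pickle instead of flax)."""

    _state: Optional[Any] = None

    def save_state(self, checkpoint_path: str, keep_top_n_checkpoints: int = 1) -> None:
        # Single-file sketch: keep_top_n_checkpoints is accepted for signature
        # parity only; no older checkpoints are pruned here.
        with open(checkpoint_path, "wb") as f:
            pickle.dump(self._state, f)

    def load_state(self, checkpoint_path: str) -> None:
        # No compatibility check: the sketch trusts whatever the file contains.
        with open(checkpoint_path, "rb") as f:
            self._state = pickle.load(f)


class CalibModelSketch(CheckpointingMixinSketch, abc.ABC):
    """Abstract base: subclasses must implement calibrate()."""

    def __init__(self, seed: int = 0) -> None:
        self.seed = seed
        self._state = None

    @abc.abstractmethod
    def calibrate(self, calib_outputs: Sequence[float], calib_targets: Sequence[float]) -> None:
        """Fit calibration parameters and store them in self._state."""


class ToyCalibModel(CalibModelSketch):
    """Toy subclass: fits a variance temperature assuming unit predicted variance."""

    def calibrate(self, calib_outputs, calib_targets):
        residuals = [y - o for o, y in zip(calib_outputs, calib_targets)]
        self._state = {"temperature": sum(r * r for r in residuals) / len(residuals)}


model = ToyCalibModel(seed=0)
model.calibrate([0.0, 1.0], [2.0, -1.0])  # residuals 2 and -2, so temperature 4.0
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "calib_state.pkl")
    model.save_state(path)
    restored = ToyCalibModel(seed=0)
    restored.load_state(path)
print(restored._state)  # {'temperature': 4.0}
```

Keeping checkpointing in a mixin lets every concrete calibration model share save_state and load_state while only implementing its own calibrate.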

load_state(checkpoint_path)

Load a calibration state from a checkpoint path. The checkpoint must be compatible with the calibration model.

Parameters:

checkpoint_path (Path) – Path to a checkpoint file or directory to restore.

Return type:

None

save_state(checkpoint_path, keep_top_n_checkpoints=1)

Save the calibration state as a checkpoint.

Parameters:
  • checkpoint_path (Path) – Path to the file or directory where the current state will be saved.

  • keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.

Return type:

None