Calibration configuration

This section describes CalibConfig, the object that configures the calibration process of the probabilistic model. It is composed of four objects, each described below:

class fortuna.prob_model.calib_config.base.CalibConfig(optimizer=CalibOptimizer(), checkpointer=CalibCheckpointer(), monitor=CalibMonitor(), processor=CalibProcessor())

Configure the probabilistic model calibration.

Parameters:
  • optimizer (CalibOptimizer) – It defines the optimization specifics.

  • checkpointer (CalibCheckpointer) – It handles saving and restoring checkpoints.

  • monitor (CalibMonitor) – It monitors training progress and might induce early stopping.

  • processor (CalibProcessor) – It configures where and how the computation takes place.
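As an illustration, the four components could be assembled as follows, using the module paths listed on this page. This sketch assumes Fortuna and Optax are installed; the specific values (learning rate, number of epochs, the checkpoint directory `calib_checkpoints`, patience, number of samples) are arbitrary examples, not recommendations.

```python
from pathlib import Path

import optax

from fortuna.prob_model.calib_config.base import CalibConfig
from fortuna.prob_model.calib_config.checkpointer import CalibCheckpointer
from fortuna.prob_model.calib_config.monitor import CalibMonitor
from fortuna.prob_model.calib_config.optimizer import CalibOptimizer
from fortuna.prob_model.calib_config.processor import CalibProcessor

# Illustrative values only; every argument left out keeps its documented default.
calib_config = CalibConfig(
    # Adam with a smaller learning rate than the default, for at most 200 epochs.
    optimizer=CalibOptimizer(method=optax.adam(1e-3), n_epochs=200),
    # Save checkpoints in a local directory (hypothetical path); keep one past file.
    checkpointer=CalibCheckpointer(
        save_checkpoint_dir=Path("calib_checkpoints"), keep_top_n_checkpoints=1
    ),
    # Stop after 5 consecutive epochs without a validation-loss improvement.
    monitor=CalibMonitor(early_stopping_patience=5, early_stopping_monitor="val_loss"),
    # Use all devices and draw 50 posterior samples.
    processor=CalibProcessor(devices=-1, n_posterior_samples=50),
)
```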

class fortuna.prob_model.calib_config.optimizer.CalibOptimizer(method=optax.adam(1e-2), n_epochs=100)

An object to configure the optimization in the calibration process.

Parameters:
  • method (OptaxOptimizer) – An Optax optimizer.

  • n_epochs (int) – Maximum number of epochs to run the calibration for.
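To make the default `method=optax.adam(1e-2)` more concrete, here is a pure-Python sketch of the Adam update rule applied to a single scalar parameter, using Optax's default hyperparameters (`b1=0.9`, `b2=0.999`, `eps=1e-8`). It is neither Fortuna's nor Optax's code: the objective is hypothetical, and performing exactly one update per epoch is a simplification, since a real calibration epoch usually iterates over several batches.

```python
import math


def adam_minimize(grad_fn, x0, learning_rate=1e-2, n_epochs=100,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Minimize a scalar function with the Adam update rule.

    Defaults mirror optax.adam; one update per epoch is a simplification.
    """
    x, m, v = x0, 0.0, 0.0
    for t in range(1, n_epochs + 1):
        g = grad_fn(x)
        m = b1 * m + (1 - b1) * g        # running mean of gradients
        v = b2 * v + (1 - b2) * g * g    # running mean of squared gradients
        m_hat = m / (1 - b1 ** t)        # bias corrections for the zero init
        v_hat = v / (1 - b2 ** t)
        x -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Hypothetical objective (x - 2)^2, whose gradient is 2 (x - 2).
x_final = adam_minimize(lambda x: 2 * (x - 2), x0=1.0)
print(x_final)
```

In typical cases each Adam step moves the parameter by at most roughly the learning rate, so the learning rate and `n_epochs` together limit how far the calibration parameters can travel.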

class fortuna.prob_model.calib_config.checkpointer.CalibCheckpointer(save_checkpoint_dir=None, restore_checkpoint_path=None, save_every_n_steps=None, keep_top_n_checkpoints=2, dump_state=False)

An object to configure saving and restoring of checkpoints during the calibration process.

Parameters:
  • save_checkpoint_dir (Optional[Path]) – Directory where checkpoints are saved.

  • restore_checkpoint_path (Optional[Path]) – Path to checkpoint file or directory to restore.

  • save_every_n_steps (Optional[int]) – Number of calibration steps between checkpoints. To disable, set save_every_n_steps to None or 0 (no checkpoint will then be saved during calibration).

  • keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.

  • dump_state (bool) – If True, dump the fitted calibration state as a checkpoint in save_checkpoint_dir. Any future access to the state will then internally involve restoring it from that checkpoint.
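The interplay between save_every_n_steps and keep_top_n_checkpoints can be illustrated with a small stdlib simulation. The helper `simulate_checkpointing` is hypothetical and assumes that the retained files are the most recent ones; it only models which steps would be saved and kept, not Fortuna's actual file handling.

```python
from collections import deque
from typing import List, Optional


def simulate_checkpointing(n_steps: int,
                           save_every_n_steps: Optional[int],
                           keep_top_n_checkpoints: int) -> List[int]:
    """Return the steps whose checkpoints would survive at the end of a run.

    Assumption for illustration: the most recent checkpoints are the ones kept.
    """
    if not save_every_n_steps:  # None or 0 disables periodic checkpointing
        return []
    kept = deque(maxlen=keep_top_n_checkpoints)  # oldest entries drop out
    for step in range(1, n_steps + 1):
        if step % save_every_n_steps == 0:
            kept.append(step)
    return list(kept)


print(simulate_checkpointing(n_steps=100, save_every_n_steps=25,
                             keep_top_n_checkpoints=2))  # [75, 100]
```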

class fortuna.prob_model.calib_config.monitor.CalibMonitor(metrics=None, uncertainty_fn=None, early_stopping_patience=0, early_stopping_monitor='val_loss', early_stopping_min_delta=0.0, eval_every_n_epochs=1, disable_calibration_metrics_computation=False, verbose=True)

An object to configure the monitoring of the calibration process.

Parameters:
  • metrics (Optional[Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Union[float, Array]], ...]]) – Metrics to monitor during calibration. Each metric must take three arguments: predictions, uncertainty estimates and target variables. In classification, expected_calibration_error() is an example of a valid metric.

  • uncertainty_fn (Optional[Callable[[jnp.ndarray, jnp.ndarray, Array], Union[float, Array]]]) – A function that maps (calibrated) outputs into uncertainty estimates, which are then passed to metrics. In classification, the default is mean(). In regression, the default is variance().

  • early_stopping_patience (int) – Number of consecutive epochs without an improvement in the performance on the validation set before stopping the calibration.

  • early_stopping_monitor (str) – Validation metric to be monitored for early stopping.

  • early_stopping_min_delta (float) – Minimum change in the monitored metric that counts as an improvement: if the absolute change is smaller than early_stopping_min_delta, it is not considered an improvement and counts toward a potential early stop.

  • eval_every_n_epochs (Optional[int]) – Number of calibration epochs between validation evaluations. To disable, set eval_every_n_epochs to None or 0 (no validation metrics will then be computed during calibration).

  • disable_calibration_metrics_computation (bool) – If True, the only metric computed during calibration is the objective function. Otherwise, all the metrics provided by the user are also computed for the training step.

  • verbose (bool) – Whether to log the calibration progress.
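The early-stopping parameters can be illustrated with the hypothetical stdlib helper below, which replays a list of validation losses (as with early_stopping_monitor='val_loss', lower is better). It is a sketch, not Fortuna's implementation: it assumes that a patience of 0 disables early stopping, consistent with the default, and that when eval_every_n_epochs is larger than 1 the patience counts evaluations rather than raw epochs.

```python
from typing import List, Optional


def early_stopping_epoch(val_losses: List[float],
                         patience: int,
                         min_delta: float = 0.0,
                         eval_every_n_epochs: Optional[int] = 1) -> Optional[int]:
    """Return the 1-based epoch at which calibration would stop early, or None.

    val_losses[i] is the validation loss that would be observed at epoch i + 1.
    """
    if not eval_every_n_epochs or patience <= 0:
        return None  # no validation, or early stopping disabled (assumption)
    best = float("inf")
    n_bad = 0  # consecutive evaluations without a sufficient improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if epoch % eval_every_n_epochs != 0:
            continue  # no validation at this epoch
        if best - loss > min_delta:  # improved by more than min_delta
            best, n_bad = loss, 0
        else:
            n_bad += 1
            if n_bad >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.79, 0.795, 0.79, 0.8]
print(early_stopping_epoch(losses, patience=2, min_delta=0.05))  # 4
```

With min_delta=0.05, the small gain at epoch 3 does not count as an improvement, so the run stops one epoch earlier than it would with min_delta=0.0.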

class fortuna.prob_model.calib_config.processor.CalibProcessor(devices=-1, disable_jit=False, n_posterior_samples=30)

An object to configure computational aspects of the calibration process.

Parameters:
  • devices (int) – Devices to use during calibration. At the moment, two options are supported: use all devices (devices=-1) or use no device (devices=0).

  • disable_jit (bool) – If True, no function within the calibration loop is JIT-compiled.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution during the calibration process.
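The role of n_posterior_samples can be illustrated with a Monte Carlo estimate of a predictive probability, where the model output is averaged over draws from the posterior. The stdlib sketch below uses a toy logistic model with a hypothetical Gaussian posterior over a single weight; the function `mc_predictive_probability` is illustrative only and shows the general technique, not how Fortuna computes its calibration outputs.

```python
import math
import random
from typing import Optional


def mc_predictive_probability(x: float, n_posterior_samples: int = 30,
                              mean: float = 0.0, std: float = 1.0,
                              seed: Optional[int] = 0) -> float:
    """Monte Carlo estimate of p(y=1 | x) for a toy logistic model.

    The weight w has a hypothetical Gaussian posterior N(mean, std^2); the
    estimate is the average of sigmoid(w * x) over the posterior samples.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_posterior_samples):
        w = rng.gauss(mean, std)                   # one posterior sample
        total += 1.0 / (1.0 + math.exp(-w * x))    # model output for that sample
    return total / n_posterior_samples


print(mc_predictive_probability(2.0))                            # 30 samples
print(mc_predictive_probability(2.0, n_posterior_samples=5000))  # lower variance
```

More samples reduce the Monte Carlo error of the average, at a computational cost that grows linearly with n_posterior_samples.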