Calibration configuration
This section describes Config, an object that configures the calibration process of the probabilistic model. It is made of several objects:
- Optimizer: to configure the optimization process;
- Checkpointer: to save and restore checkpoints;
- Monitor: to monitor the process and trigger early stopping;
- Processor: to decide how and where the computation is processed.
- class fortuna.output_calib_model.config.base.Config(optimizer=Optimizer(), checkpointer=Checkpointer(), monitor=Monitor(), processor=Processor())
Configure the calibration of the output calibration model.
- Parameters:
optimizer (Optimizer) – It defines the optimization specifics.
checkpointer (Checkpointer) – It handles saving and restoring checkpoints.
monitor (Monitor) – It monitors training progress and might induce early stopping.
processor (Processor) – It decides how and where the computation takes place.
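As a minimal sketch of how the four objects fit together, the helper below assembles a Config from the constructors documented in this section. The import paths are the module paths shown in the class signatures here. The helper name build_calibration_config, its arguments, and the specific values (learning rate, patience, save interval) are illustrative assumptions, not recommended settings. Imports are deferred into the function body so that defining it does not require Fortuna or Optax to be installed.

```python
def build_calibration_config(checkpoint_dir, n_epochs=50):
    """Assemble a calibration Config (hypothetical helper, illustrative values)."""
    # Imports are local so this module can be loaded without Fortuna installed.
    import optax
    from fortuna.output_calib_model.config.base import Config
    from fortuna.output_calib_model.config.checkpointer import Checkpointer
    from fortuna.output_calib_model.config.monitor import Monitor
    from fortuna.output_calib_model.config.optimizer import Optimizer
    from fortuna.output_calib_model.config.processor import Processor

    return Config(
        # Any Optax optimizer can be passed as `method`.
        optimizer=Optimizer(method=optax.adam(1e-3), n_epochs=n_epochs),
        # Save a checkpoint every 100 steps and keep two checkpoint files.
        checkpointer=Checkpointer(
            save_checkpoint_dir=checkpoint_dir,
            save_every_n_steps=100,
            keep_top_n_checkpoints=2,
        ),
        # Stop after 3 consecutive epochs without a val_loss improvement.
        monitor=Monitor(
            early_stopping_patience=3,
            early_stopping_monitor="val_loss",
            early_stopping_min_delta=1e-4,
        ),
        # Use all available devices and keep JIT compilation enabled.
        processor=Processor(devices=-1, disable_jit=False),
    )
```

Calling build_calibration_config("./checkpoints") would return a Config in an environment where Fortuna and Optax are installed. Any argument left out falls back to the defaults listed in the signatures in this section.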
- class fortuna.output_calib_model.config.optimizer.Optimizer(method=optax.adam(1e-2), n_epochs=100)
An object to configure the optimization in the calibration process.
- Parameters:
method (OptaxOptimizer) – An Optax optimizer.
n_epochs (int) – Maximum number of epochs to run the calibration for.
- class fortuna.output_calib_model.config.checkpointer.Checkpointer(save_checkpoint_dir=None, restore_checkpoint_path=None, save_every_n_steps=None, keep_top_n_checkpoints=2, dump_state=False)
An object to configure saving and restoring of checkpoints during the calibration process.
- Parameters:
save_checkpoint_dir (Optional[Path]) – Directory in which checkpoints are saved.
restore_checkpoint_path (Optional[Path]) – Path to checkpoint file or directory to restore.
save_every_n_steps (Optional[int]) – Number of training steps between checkpoints. To disable, set save_every_n_steps to None or 0 (no checkpoint will be saved during training).
keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.
dump_state (bool) – Dump the fitted calibration state as a checkpoint in save_checkpoint_dir. Any future call to the state will internally involve restoring it from memory.
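The interaction between save_every_n_steps and keep_top_n_checkpoints can be pictured with a small simulation. This is a sketch under stated assumptions, not Fortuna's implementation: it assumes steps are counted from 1, that a checkpoint is written whenever the step is a multiple of save_every_n_steps, and that retention keeps the most recent files. The function name simulate_checkpoints is hypothetical.

```python
from collections import deque
from typing import List, Optional


def simulate_checkpoints(
    total_steps: int,
    save_every_n_steps: Optional[int],
    keep_top_n_checkpoints: int = 2,
) -> List[int]:
    """Return the steps whose checkpoints would still be on disk at the end."""
    # None or 0 disables checkpointing, as described for save_every_n_steps.
    if not save_every_n_steps:
        return []
    # Assumption: only the most recent checkpoints are retained.
    kept = deque(maxlen=keep_top_n_checkpoints)
    for step in range(1, total_steps + 1):
        if step % save_every_n_steps == 0:
            kept.append(step)  # the oldest entry is dropped once the deque is full
    return list(kept)
```

Under these assumptions, ten steps with save_every_n_steps=3 write checkpoints at steps 3, 6 and 9, and the default keep_top_n_checkpoints=2 leaves those from steps 6 and 9.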
- class fortuna.output_calib_model.config.monitor.Monitor(metrics=None, uncertainty_fn=None, early_stopping_patience=0, early_stopping_monitor='val_loss', early_stopping_min_delta=0.0, eval_every_n_epochs=1, disable_calibration_metrics_computation=False, verbose=True)
An object to configure the monitoring of the calibration process.
- Parameters:
metrics (Optional[Callable[[jnp.ndarray, jnp.ndarray, Array], Union[float, Array]]]) – Metrics to monitor during calibration. Each metric must take three arguments: predictions, uncertainty estimates and target variables. In classification, expected_calibration_error() is an example of a valid metric.
uncertainty_fn (Optional[Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Union[float, Array]], ...]]) – A function that maps (calibrated) outputs into uncertainty estimates. These will be used in metrics. In classification, the default is mean(). In regression, the default is variance().
early_stopping_patience (int) – Number of consecutive epochs without an improvement in the performance on the validation set before stopping the calibration.
early_stopping_monitor (str) – Validation metric to be monitored for early stopping.
early_stopping_min_delta (float) – Minimum change between updates to be considered an improvement, i.e., if the absolute change is less than early_stopping_min_delta then this is not considered an improvement leading to a potential early stop.
eval_every_n_epochs (int) – Number of calibration epochs between validation. To disable, set eval_every_n_epochs to None or 0 (i.e., no validation metrics will be computed during calibration).
disable_calibration_metrics_computation (bool) – If True, the only metric computed during calibration is the objective function. Otherwise, all the metrics provided by the user at runtime will be computed for the training step.
verbose (bool) – Whether to log the calibration progress.
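The roles of early_stopping_patience and early_stopping_min_delta can be illustrated with a short, library-independent sketch. It is not Fortuna's implementation. It assumes the monitored metric is a loss (lower is better, as with the default 'val_loss'), that a non-positive patience disables early stopping, and that an epoch counts as an improvement only if the loss decreases by at least min_delta. The function name first_stopping_epoch is hypothetical.

```python
from typing import Iterable, Optional


def first_stopping_epoch(
    val_losses: Iterable[float],
    patience: int,
    min_delta: float = 0.0,
) -> Optional[int]:
    """Return the 1-based epoch at which calibration would stop, or None."""
    # Assumption: a non-positive patience means early stopping is disabled.
    if patience <= 0:
        return None
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        # An improvement must lower the loss by at least min_delta.
        if loss < best and best - loss >= min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None
```

For example, with validation losses 1.0, 0.8, 0.79, 0.795, 0.78, a patience of 2 and a min_delta of 0.05, epochs 3 and 4 both fall short of the required improvement, so the sketch stops at epoch 4. With min_delta=0.0 the same sequence runs to completion.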
- class fortuna.output_calib_model.config.processor.Processor(devices=-1, disable_jit=False)
An object to configure computational aspects of the calibration process.
- Parameters:
devices (int) – Which devices to use during calibration. At the moment two options are supported: use all devices (devices=-1) or use no device (devices=0).
disable_jit (bool) – If True, no function within the calibration loop is jitted.