Posterior fitting configuration#

This section describes FitConfig, an object that configures the posterior fitting process. It consists of several objects:

  • FitOptimizer: to configure the optimization process;

  • FitCheckpointer: to save and restore checkpoints;

  • FitMonitor: to monitor the process and trigger early stopping;

  • FitProcessor: to decide how and where the computation is processed;

  • List[Callback]: to allow the user to perform custom actions at different stages of the training process.

class fortuna.prob_model.fit_config.base.FitConfig(optimizer=FitOptimizer(), checkpointer=FitCheckpointer(), monitor=FitMonitor(), processor=FitProcessor(), hyperparameters=FitHyperparameters(), callbacks=None)[source]#

Configure the posterior distribution fitting.

Parameters:
  • optimizer (FitOptimizer) – It defines the optimization specifics.

  • checkpointer (FitCheckpointer) – It handles saving and restoring checkpoints.

  • monitor (FitMonitor) – It monitors training progress and might induce early stopping.

  • processor (FitProcessor) – It defines how and where computation takes place.

  • hyperparameters (FitHyperparameters) – It defines other hyperparameters that may be needed during the model's training.

  • callbacks (Optional[List[FitCallback]]) – A list of user-defined callbacks to be called during training. Callbacks run sequentially in the order defined by the user.
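Putting the pieces together, a FitConfig might be assembled as follows. This is a minimal sketch using the import paths listed in this section; the optimizer, number of epochs, patience, and checkpoint directory values are purely illustrative:

```python
import optax

from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.fit_config.optimizer import FitOptimizer
from fortuna.prob_model.fit_config.monitor import FitMonitor
from fortuna.prob_model.fit_config.checkpointer import FitCheckpointer

# Illustrative values: tune the learning rate, epochs, patience,
# and checkpoint directory to your problem.
fit_config = FitConfig(
    optimizer=FitOptimizer(method=optax.adam(1e-3), n_epochs=50),
    monitor=FitMonitor(early_stopping_patience=3),
    checkpointer=FitCheckpointer(save_checkpoint_dir="./checkpoints"),
)
```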

class fortuna.prob_model.fit_config.optimizer.FitOptimizer(method=optax.adam(1e-3), n_epochs=100, freeze_fun=None)[source]#

An object to configure the optimization in the posterior fitting.

Parameters:
  • method (OptaxOptimizer) – An Optax optimizer.

  • n_epochs (int) – Maximum number of epochs to run the training for.

  • freeze_fun (Optional[Callable[[Tuple[AnyKey, ...], Array], str]]) – A callable taking as input a path in the nested dictionary of parameters, as well as the corresponding array of parameters, and returning “trainable” or “freeze”, according to whether the corresponding parameter should be optimized or not.
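As a sketch, a freeze function can be written as a plain Python function. The parameter-tree key "output_subnet" below is hypothetical; inspect your own model's nested parameter dictionary for the actual path keys:

```python
def freeze_fun(path, array):
    # `path` is a tuple of keys into the nested parameter dictionary;
    # `array` holds the corresponding parameters (unused here).
    # Keep parameters under the hypothetical "output_subnet" key trainable
    # and freeze everything else.
    return "trainable" if "output_subnet" in path else "freeze"
```

It can then be passed to the optimizer configuration, e.g. FitOptimizer(freeze_fun=freeze_fun).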

class fortuna.prob_model.fit_config.monitor.FitMonitor(metrics=None, early_stopping_patience=0, early_stopping_monitor='val_loss', early_stopping_min_delta=0.0, eval_every_n_epochs=1, disable_training_metrics_computation=False, verbose=True)[source]#

An object to configure the monitoring of the posterior fitting.

Parameters:
  • metrics (Optional[Callable[[jnp.ndarray, Array], Union[float, Array]]]) – Metrics to monitor during training.

  • early_stopping_patience (int) – Number of consecutive epochs without an improvement in the performance on the validation set before stopping the training.

  • early_stopping_monitor (str) – Validation metric to be monitored for early stopping.

  • early_stopping_min_delta (float) – Minimum change between updates to be considered an improvement, i.e., if the absolute change is less than early_stopping_min_delta then this is not considered an improvement leading to a potential early stop.

  • eval_every_n_epochs (int) – Number of training epochs between validation. To disable, set eval_every_n_epochs to None or 0 (i.e., no validation metrics will be computed during training).

  • disable_training_metrics_computation (bool) – If True, the only metric computed during training is the objective function. Otherwise, all metrics provided by the user at runtime are also computed at each training step.

  • verbose (bool) – Whether to log the training progress.
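For illustration, a metric matching the expected signature could look like the accuracy function below. This is a sketch assuming the two array arguments are predictions and targets; it is written with plain Python sequences, whereas in practice the arguments would be JAX arrays:

```python
def accuracy(preds, targets):
    # Fraction of predictions that match the targets. Plain Python
    # sequences are used here for illustration; in practice `preds`
    # and `targets` would be JAX arrays of the same length.
    return sum(int(p == t) for p, t in zip(preds, targets)) / len(targets)
```

It could then be monitored via, e.g., FitMonitor(metrics=accuracy, early_stopping_patience=2).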

class fortuna.prob_model.fit_config.checkpointer.FitCheckpointer(save_checkpoint_dir=None, restore_checkpoint_path=None, start_from_current_state=False, save_every_n_steps=None, keep_top_n_checkpoints=2, dump_state=False)[source]#

An object to configure saving and restoring of checkpoints during the posterior fitting.

Parameters:
  • save_checkpoint_dir (Optional[Path]) – Save directory location.

  • restore_checkpoint_path (Optional[Path]) – Path to checkpoint file or directory to restore.

  • start_from_current_state (bool) – If True, the optimization will start from the current state. If restore_checkpoint_path is given, then start_from_current_state is ignored.

  • save_every_n_steps (int) – Number of training steps between checkpoints. To disable, set save_every_n_steps to None or 0 (no checkpoint will be saved during training).

  • keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.

  • dump_state (bool) – Dump the fitted posterior state as a checkpoint in save_checkpoint_dir. Any future call to the state will internally involve restoring it from the saved checkpoint.
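As a sketch, resuming training from a previously saved checkpoint might be configured like this, using the import path listed above; the directory paths and step counts are illustrative:

```python
from fortuna.prob_model.fit_config.checkpointer import FitCheckpointer

# Save a checkpoint every 1000 steps, keep the two most recent ones,
# and restore from a previous run (paths are illustrative).
checkpointer = FitCheckpointer(
    save_checkpoint_dir="./checkpoints/run2",
    restore_checkpoint_path="./checkpoints/run1",
    save_every_n_steps=1000,
    keep_top_n_checkpoints=2,
)
```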

class fortuna.prob_model.fit_config.processor.FitProcessor(devices=-1, disable_jit=False)[source]#

An object to configure computational aspects of the posterior fitting.

Parameters:
  • devices (int) – The devices to use during training. At the moment two options are supported: use all available devices (devices=-1) or use no device (devices=0).

  • disable_jit (bool) – If True, no function within the training loop is jitted.

class fortuna.prob_model.fit_config.callback.Callback[source]#

Base class to define new callback functions. To define a new callback, create a child of this class and override the relevant methods.

Example

The following is a custom callback that prints the number of the model's parameters at the start of each epoch.

import logging
from jax.flatten_util import ravel_pytree

logger = logging.getLogger(__name__)

class CountParamsCallback(Callback):
    def training_epoch_start(self, state: TrainState) -> TrainState:
        params, unravel = ravel_pytree(state.params)
        logger.info(f"num params: {len(params)}")
        return state
training_epoch_end(state)[source]#

Called at the end of every training epoch.

Parameters:

state (TrainState) – The training state

Returns:

The (possibly updated) training state

Return type:

TrainState

training_epoch_start(state)[source]#

Called at the beginning of every training epoch.

Parameters:

state (TrainState) – The training state

Returns:

The (possibly updated) training state

Return type:

TrainState

training_step_end(state)[source]#

Called after every minibatch update.

Parameters:

state (TrainState) – The training state

Returns:

The (possibly updated) training state

Return type:

TrainState
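Once defined, callbacks are registered through the callbacks argument of FitConfig. As a sketch, the CountParamsCallback from the example above could be passed in as follows, assuming the import path listed in this section:

```python
from fortuna.prob_model.fit_config.base import FitConfig

# Register the user-defined callback; callbacks run sequentially
# in the order given by the user.
fit_config = FitConfig(callbacks=[CountParamsCallback()])
```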