Training Callbacks

This section describes FitCallback, which allows users to add custom actions at different stages of the training loop. Callbacks can be used while training a ProbModel.

To use callbacks, the user has to:

  • Define their own callbacks by subclassing FitCallback and overriding the methods of interest.

  • When calling the train method of a ProbModel instance, pass the list of callbacks to the configuration object FitConfig through its callbacks argument.
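The two steps above can be pictured with a framework-free sketch. It does not use Fortuna: `ToyState`, `ToyCallback`, `ToyFitConfig`, and `run_training` are hypothetical stand-ins. The stages follow this page (epoch start, one step-end call after every minibatch update, epoch end), while the idea that the list is walked in order and each callback receives the state returned by the previous one is an illustrative assumption, not a description of Fortuna's trainer.

```python
from dataclasses import dataclass, field, replace
from typing import List


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for TrainState: an immutable record."""
    epoch: int = 0
    log: tuple = ()


class ToyCallback:
    """Hypothetical base class: every hook returns the state unchanged."""

    def training_epoch_start(self, state: ToyState) -> ToyState:
        return state

    def training_step_end(self, state: ToyState) -> ToyState:
        return state

    def training_epoch_end(self, state: ToyState) -> ToyState:
        return state


@dataclass
class ToyFitConfig:
    """Hypothetical stand-in for FitConfig: holds the list of callbacks."""
    n_epochs: int = 2
    callbacks: List[ToyCallback] = field(default_factory=list)


def run_training(config: ToyFitConfig, n_batches: int) -> ToyState:
    """Toy loop showing where each hook fires (assumed chaining order)."""
    state = ToyState()
    for epoch in range(config.n_epochs):
        state = replace(state, epoch=epoch)
        for cb in config.callbacks:
            state = cb.training_epoch_start(state)
        for _ in range(n_batches):
            # A real loop would update the parameters here.
            for cb in config.callbacks:
                state = cb.training_step_end(state)
        for cb in config.callbacks:
            state = cb.training_epoch_end(state)
    return state


class TraceCallback(ToyCallback):
    """Step 1: subclass and override only the hooks of interest."""

    def training_epoch_start(self, state):
        return replace(state, log=state.log + (f"start{state.epoch}",))

    def training_epoch_end(self, state):
        return replace(state, log=state.log + (f"end{state.epoch}",))


# Step 2: pass the callbacks through the (toy) configuration object.
final = run_training(ToyFitConfig(n_epochs=2, callbacks=[TraceCallback()]), n_batches=3)
print(final.log)  # ('start0', 'end0', 'start1', 'end1')
```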

The following example outlines the usage of FitCallback. It assumes that the user has already obtained an instance of ProbModel, as well as the data loaders train_data_loader and val_data_loader:

import logging

from jax.flatten_util import ravel_pytree
import optax

from fortuna.training.train_state import TrainState
from fortuna.prob_model.fit_config import FitConfig, FitOptimizer, FitCallback

# Show INFO-level messages so the callback output is visible
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define a custom callback
class CountParamsCallback(FitCallback):
    def training_epoch_start(self, state: TrainState) -> TrainState:
        # Flatten the parameter pytree into a single 1-D array and count its entries
        params, _ = ravel_pytree(state.params)
        logger.info(f"num params: {len(params)}")
        return state

# Add a list of callbacks containing CountParamsCallback to FitConfig
status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        optimizer=FitOptimizer(method=optax.adam(1e-4), n_epochs=100),
        callbacks=[
            CountParamsCallback()
        ]
    )
)
class fortuna.training.callback.Callback

Base class to define new callback functions. To define a new callback, create a child of this class and override the relevant methods.
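A minimal sketch of the subclassing pattern, assuming (this page does not show it) that each base-class hook returns the state it receives unchanged. `BaseCallback`, `ToyState`, and `FlagAtEpochEnd` are hypothetical names; with real Fortuna code you would subclass `Callback` and receive a `TrainState`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for TrainState."""
    step: int = 0
    flagged: bool = False


class BaseCallback:
    """Hypothetical base class: each hook is an identity by default."""

    def training_epoch_start(self, state):
        return state

    def training_step_end(self, state):
        return state

    def training_epoch_end(self, state):
        return state


class FlagAtEpochEnd(BaseCallback):
    """Overrides only the hook it needs; the others are inherited."""

    def training_epoch_end(self, state):
        return replace(state, flagged=True)


cb = FlagAtEpochEnd()
s0 = ToyState()
assert cb.training_epoch_start(s0) is s0  # inherited identity hook
assert cb.training_step_end(s0) is s0     # inherited identity hook
print(cb.training_epoch_end(s0))  # ToyState(step=0, flagged=True)
```

Under this assumption, a child class only needs to implement the stages it cares about; the inherited hooks pass the state through.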

Example

The following custom callback logs the number of the model’s parameters at the start of each epoch.

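import logging
from jax.flatten_util import ravel_pytree
from fortuna.training.callback import Callback
from fortuna.training.train_state import TrainState
logger = logging.getLogger(__name__)

class CountParamsCallback(Callback):
    def training_epoch_start(self, state: TrainState) -> TrainState:
        params, _ = ravel_pytree(state.params)
        logger.info(f"num params: {len(params)}")
        return state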
training_epoch_end(state)

Called at the end of every training epoch

Parameters:

state (TrainState) – The training state

Returns:

The (possibly updated) training state

Return type:

TrainState
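As a sketch of an end-of-epoch hook that observes rather than modifies training, the hypothetical callback below keeps its own history and returns the state untouched. The `loss` field on `ToyState` is an assumption made for this toy example; it is not a documented attribute of Fortuna's `TrainState`, and a real callback would subclass `Callback`.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for TrainState with a running loss value."""
    epoch: int
    loss: float


class LossHistory:
    """Hypothetical callback: stores the loss seen at each epoch end."""

    def __init__(self) -> None:
        self.history: List[float] = []

    def training_epoch_end(self, state: ToyState) -> ToyState:
        self.history.append(state.loss)
        return state  # the state passes through unchanged


cb = LossHistory()
for epoch, loss in enumerate([0.9, 0.5, 0.3]):
    out = cb.training_epoch_end(ToyState(epoch=epoch, loss=loss))
print(cb.history)  # [0.9, 0.5, 0.3]
```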

training_epoch_start(state)

Called at the beginning of every training epoch

Parameters:

state (TrainState) – The training state

Returns:

The (possibly updated) training state

Return type:

TrainState
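Since the hook returns the (possibly updated) state, a callback can also change it. The sketch below assumes the state is an immutable record that is updated by returning a modified copy (Python's `dataclasses.replace` here; Flax-style struct dataclasses offer a similar `.replace` method). `ToyState` and `ResetEpochCounter` are hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical immutable state with a per-epoch batch counter."""
    epoch: int
    batches_this_epoch: int


class ResetEpochCounter:
    """Hypothetical callback: zeroes the per-epoch counter at epoch start."""

    def training_epoch_start(self, state: ToyState) -> ToyState:
        # Return a modified copy instead of mutating the state in place.
        return replace(state, batches_this_epoch=0)


old = ToyState(epoch=3, batches_this_epoch=17)
new = ResetEpochCounter().training_epoch_start(old)
print(old.batches_this_epoch, new.batches_this_epoch)  # 17 0
```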

training_step_end(state)

Called after every minibatch update

Parameters:

state (TrainState) – The training state

Returns:

The (possibly updated) training state

Return type:

TrainState
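Because this hook runs after every minibatch update, it fires once per batch in every epoch. The toy loop below (hypothetical `ToyState`, `StepCounter`, and `toy_loop`, not Fortuna's trainer) makes the count concrete: 3 epochs of 4 batches trigger 12 calls.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for TrainState with a global step counter."""
    step: int = 0


class StepCounter:
    """Hypothetical callback: counts how many minibatch updates ran."""

    def __init__(self) -> None:
        self.calls = 0

    def training_step_end(self, state: ToyState) -> ToyState:
        self.calls += 1
        return state


def toy_loop(n_epochs: int, n_batches: int, cb: StepCounter) -> ToyState:
    state = ToyState()
    for _ in range(n_epochs):
        for _ in range(n_batches):
            state = replace(state, step=state.step + 1)  # stand-in for an update
            state = cb.training_step_end(state)  # hook fires after each update
    return state


counter = StepCounter()
final = toy_loop(n_epochs=3, n_batches=4, cb=counter)
print(counter.calls, final.step)  # 12 12
```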