Training Callbacks
This section describes FitCallback, which allows users to add custom actions at different stages of the training loop. Callbacks can be used while training a ProbModel.
To use callbacks, the user has to:
1. Define their own callbacks by subclassing FitCallback and overriding the methods of interest.
2. When calling the train method of a ProbModel instance, add a list of callbacks to the configuration object FitConfig.
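 
To make the two steps concrete without depending on Fortuna, here is a toy loop that mirrors the contract described in this section: each hook receives the training state and returns a (possibly updated) state. Everything in it (ToyState, ToyCallback, run_toy_training, RecordLossCallback) is hypothetical and is not Fortuna's trainer. In particular, the assumption that callbacks are invoked in list order, each receiving the previous one's returned state, is for illustration only.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToyState:
    # Hypothetical stand-in for TrainState: immutable, updated via replace().
    epoch: int = 0
    loss: float = 1.0


class ToyCallback:
    # Hypothetical base class: default hooks return the state unchanged.
    def training_epoch_start(self, state: ToyState) -> ToyState:
        return state

    def training_epoch_end(self, state: ToyState) -> ToyState:
        return state


def run_toy_training(state: ToyState, n_epochs: int, callbacks: List[ToyCallback]) -> ToyState:
    for _ in range(n_epochs):
        # Assumed ordering: each hook's return value is fed to the next callback.
        for cb in callbacks:
            state = cb.training_epoch_start(state)
        # Placeholder for one epoch of optimisation: halve the loss.
        state = replace(state, epoch=state.epoch + 1, loss=state.loss / 2)
        for cb in callbacks:
            state = cb.training_epoch_end(state)
    return state


class RecordLossCallback(ToyCallback):
    # Overrides only the hook it needs; the other hook keeps the default.
    def __init__(self) -> None:
        self.history: List[float] = []

    def training_epoch_end(self, state: ToyState) -> ToyState:
        self.history.append(state.loss)
        return state


recorder = RecordLossCallback()
final = run_toy_training(ToyState(), n_epochs=3, callbacks=[recorder])
print(final.epoch, recorder.history)  # 3 [0.5, 0.25, 0.125]
```

Overriding only the hooks of interest while inheriting no-op defaults for the rest is what keeps a subclass like RecordLossCallback short.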
The following example outlines the usage of FitCallback. It assumes that the user has already obtained an instance of ProbModel (prob_model) and the data loaders train_data_loader and val_data_loader:
import logging

import optax
from jax.flatten_util import ravel_pytree

from fortuna.training.train_state import TrainState
from fortuna.prob_model.fit_config import FitConfig, FitOptimizer, FitCallback

# Messages are shown only if logging is configured at INFO level,
# e.g. with logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define custom callback
class CountParamsCallback(FitCallback):
    def training_epoch_start(self, state: TrainState) -> TrainState:
        # Flatten all parameter arrays into one 1-D vector and count its entries
        params, _ = ravel_pytree(state.params)
        logger.info(f"num params: {len(params)}")
        # Return the state unchanged
        return state

# Add a list of callbacks containing CountParamsCallback to FitConfig
status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        optimizer=FitOptimizer(method=optax.adam(1e-4), n_epochs=100),
        callbacks=[
            CountParamsCallback()
        ]
    )
)
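ravel_pytree concatenates every leaf of state.params into a single 1-D array. The number logged by CountParamsCallback is therefore the sum of the sizes of all parameter arrays. The arithmetic can be checked without JAX using a sketch that represents a hypothetical parameter tree as nested dicts of shape tuples. The layer names and sizes are illustrative only and do not come from any Fortuna model.

```python
import math
from typing import Any


def count_params(tree: Any) -> int:
    """Sum the number of scalars over all leaves of a nested dict of shapes."""
    if isinstance(tree, dict):
        return sum(count_params(v) for v in tree.values())
    # A leaf is a shape tuple; math.prod(()) == 1 covers scalar parameters.
    return math.prod(tree)


# Hypothetical two-layer MLP: 784 -> 128 -> 10.
params_shapes = {
    "dense_0": {"kernel": (784, 128), "bias": (128,)},
    "dense_1": {"kernel": (128, 10), "bias": (10,)},
}
print(count_params(params_shapes))  # 101770
```

Here 784·128 + 128 + 128·10 + 10 = 101770, which is the length a flattened vector of these parameters would have.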
- class fortuna.training.callback.Callback
Base class for defining new callbacks. To define a new callback, subclass this class and override the relevant methods.
Example
The following custom callback logs the number of the model's parameters at the start of each training epoch.
class CountParamsCallback(Callback):
    def training_epoch_start(self, state: TrainState) -> TrainState:
        params, unravel = ravel_pytree(state.params)
        logger.info(f"num params: {len(params)}")
        return state
- training_epoch_end(state)
Called at the end of every training epoch.
- Parameters:
state (TrainState) – The training state
- Returns:
The (possibly updated) training state
- Return type:
TrainState
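Because training_epoch_end must return a training state, a callback can either pass the state through unchanged or return an updated copy. Below is a minimal, dependency-free sketch of the latter.

- StubState and its loss and best_loss fields are hypothetical stand-ins for TrainState.
- TrackBestLoss would subclass Callback in real code.
- The loop replaces the trainer and only pretends that epochs finish.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class StubState:
    # Hypothetical stand-in for TrainState; `loss` and `best_loss` are illustrative fields.
    loss: float
    best_loss: float = float("inf")


class TrackBestLoss:
    # In Fortuna this would subclass Callback; kept dependency-free for the sketch.
    def training_epoch_end(self, state: StubState) -> StubState:
        # Build and return a new state instead of mutating the frozen one.
        return replace(state, best_loss=min(state.best_loss, state.loss))


cb = TrackBestLoss()
state = StubState(loss=0.0)
for epoch_loss in [0.9, 0.4, 0.6]:
    state = replace(state, loss=epoch_loss)  # pretend an epoch just finished
    state = cb.training_epoch_end(state)
print(state.best_loss)  # 0.4
```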