Source code for fortuna.training.callback

from fortuna.training.train_state import TrainState


[docs]class Callback: """ Base class to define new callback functions. To define a new callback, create a child of this class and override the relevant methods. Example ------- The following is a custom callback that prints the number of model's parameters at the start of each epoch. .. code-block:: python class CountParamsCallback(Callback): def training_epoch_start(self, state: TrainState) -> TrainState: params, unravel = ravel_pytree(state.params) logger.info(f"num params: {len(params)}") return state """
[docs] def training_epoch_start(self, state: TrainState) -> TrainState: """ Called at the beginning of every training epoch Parameters ---------- state: TrainState The training state Returns ------- TrainState The (possibly updated) training state """ return state
[docs] def training_epoch_end(self, state: TrainState) -> TrainState: """ Called at the end of every training epoch Parameters ---------- state: TrainState The training state Returns ------- TrainState The (possibly updated) training state """ return state
[docs] def training_step_end(self, state: TrainState) -> TrainState: """ Called after every minibatch update Parameters ---------- state: TrainState The training state Returns ------- TrainState The (possibly updated) training state """ return state