Source code for fortuna.output_calib_model.base

import abc
import logging
from typing import (
    Callable,
    Optional,
)

from flax.core import FrozenDict
import jax.numpy as jnp

from fortuna.output_calib_model.config.base import Config
from fortuna.output_calib_model.loss import Loss
from fortuna.output_calib_model.output_calib_mixin import (
    WithOutputCalibCheckpointingMixin,
)
from fortuna.output_calib_model.output_calib_model_calibrator import (
    JittedOutputCalibModelCalibrator,
    MultiDeviceOutputCalibModelCalibrator,
    OutputCalibModelCalibrator,
)
from fortuna.output_calib_model.output_calib_state_repository import (
    OutputCalibStateRepository,
)
from fortuna.output_calib_model.state import OutputCalibState
from fortuna.output_calibrator.output_calib_manager.state import OutputCalibManagerState
from fortuna.typing import (
    Array,
    Outputs,
    Path,
    Status,
    Targets,
)
from fortuna.utils.device import select_trainer_given_devices
from fortuna.utils.random import RandomNumberGenerator


class OutputCalibModel(WithOutputCalibCheckpointingMixin, abc.ABC):
    """
    Abstract calibration model class.

    The class owns a random number generator and shares it with the
    ``output_calib_manager``, ``prob_output_layer`` and ``predictive``
    attributes. Those attributes are not defined in this class; they must
    already exist when ``__init__`` runs (presumably set by the concrete
    subclass before it calls ``super().__init__`` -- confirm against subclasses).
    """

    def __init__(self, seed: int = 0):
        """
        Create the random number generator and propagate it to the components.

        Parameters
        ----------
        seed : int
            Seed passed to :class:`RandomNumberGenerator`.
        """
        super().__init__()
        self.rng = RandomNumberGenerator(seed=seed)
        self.__set_rng()

    def __set_rng(self):
        """Attach the model's random number generator to its components."""
        self.output_calib_manager.rng = self.rng
        self.prob_output_layer.rng = self.rng
        self.predictive.rng = self.rng

    def _calibrate(
        self,
        uncertainty_fn: Callable[[jnp.ndarray, jnp.ndarray, Array], jnp.ndarray],
        loss_fn: Callable[[Outputs, Targets], jnp.ndarray],
        calib_outputs: Array,
        calib_targets: Array,
        val_outputs: Optional[Array] = None,
        val_targets: Optional[Array] = None,
        config: Optional[Config] = None,
    ) -> Status:
        """
        Run the calibration loop and store the resulting state.

        Parameters
        ----------
        uncertainty_fn : Callable[[jnp.ndarray, jnp.ndarray, Array], jnp.ndarray]
            Uncertainty function forwarded to the calibrator.
        loss_fn : Callable[[Outputs, Targets], jnp.ndarray]
            Loss function wrapped, together with ``self.predictive``, in a
            :class:`Loss` object that is handed to the calibrator.
        calib_outputs : Array
            Calibration model outputs. The size of its last axis is used as
            the output dimension when initializing the output calibrator.
        calib_targets : Array
            Calibration target variables.
        val_outputs : Optional[Array]
            Validation model outputs. Must be passed together with
            ``val_targets``.
        val_targets : Optional[Array]
            Validation target variables. Must be passed together with
            ``val_outputs``.
        config : Optional[Config]
            Calibration configuration. If ``None``, a fresh default
            :class:`Config` is created for this call.

        Returns
        -------
        Status
            The status returned by the calibrator's ``train`` method.

        Raises
        ------
        ValueError
            If exactly one of ``val_outputs`` and ``val_targets`` is given.
        """
        # Build the default per call: a `Config()` default in the signature is
        # evaluated once at definition time and shared by every call.
        if config is None:
            config = Config()

        # Validation data is all-or-nothing.
        if (val_targets is None) != (val_outputs is None):
            raise ValueError(
                "For validation, both `val_outputs` and `val_targets` must be passed as arguments."
            )

        trainer_cls = select_trainer_given_devices(
            devices=config.processor.devices,
            base_trainer_cls=OutputCalibModelCalibrator,
            jitted_trainer_cls=JittedOutputCalibModelCalibrator,
            multi_device_trainer_cls=MultiDeviceOutputCalibModelCalibrator,
            disable_jit=config.processor.disable_jit,
        )

        calibrator = trainer_cls(
            calib_outputs=calib_outputs,
            calib_targets=calib_targets,
            val_outputs=val_outputs,
            val_targets=val_targets,
            predict_fn=self.prob_output_layer.predict,
            uncertainty_fn=uncertainty_fn,
            save_checkpoint_dir=config.checkpointer.save_checkpoint_dir,
            save_every_n_steps=config.checkpointer.save_every_n_steps,
            keep_top_n_checkpoints=config.checkpointer.keep_top_n_checkpoints,
            disable_training_metrics_computation=config.monitor.disable_calibration_metrics_computation,
            eval_every_n_epochs=config.monitor.eval_every_n_epochs,
            early_stopping_monitor=config.monitor.early_stopping_monitor,
            early_stopping_min_delta=config.monitor.early_stopping_min_delta,
            early_stopping_patience=config.monitor.early_stopping_patience,
        )

        if config.checkpointer.restore_checkpoint_path is None:
            # No checkpoint to resume from: initialize the output calibrator
            # and wrap its parameters in a calibration state with the
            # configured optimizer.
            state = OutputCalibManagerState.init_from_dict(
                d=FrozenDict(
                    output_calibrator=self.output_calib_manager.init(
                        output_dim=calib_outputs.shape[-1]
                    )
                ),
            )
            state = OutputCalibState.init(
                params=state.params,
                mutable=state.mutable,
                optimizer=config.optimizer.method,
            )
        else:
            state = self.restore_checkpoint(
                config.checkpointer.restore_checkpoint_path,
                optimizer=config.optimizer.method,
            )

        loss = Loss(self.predictive, loss_fn=loss_fn)
        loss.rng = self.rng

        if config.monitor.verbose:
            logging.info("Start calibration.")
        state, status = calibrator.train(
            rng=self.rng.get(),
            state=state,
            loss_fun=loss,
            n_epochs=config.optimizer.n_epochs,
            metrics=config.monitor.metrics,
            verbose=config.monitor.verbose,
        )

        # The repository only receives the checkpoint directory when
        # `dump_state` is exactly True; otherwise it is created with None.
        self.predictive.state = OutputCalibStateRepository(
            config.checkpointer.save_checkpoint_dir
            if config.checkpointer.dump_state is True
            else None
        )
        self.predictive.state.put(
            state, keep=config.checkpointer.keep_top_n_checkpoints
        )
        return status

    def load_state(self, checkpoint_path: Path) -> None:
        """
        Load a calibration state from a checkpoint path. The checkpoint must be
        compatible with the calibration model.

        Parameters
        ----------
        checkpoint_path : Path
            Path to a checkpoint file or directory to restore.

        Raises
        ------
        ValueError
            If restoring the checkpoint raises a ``ValueError``; the original
            error is attached as the cause.
        """
        # The restored state is discarded on purpose: the call only checks
        # that a checkpoint can be restored from `checkpoint_path`.
        try:
            self.restore_checkpoint(checkpoint_path)
        except ValueError as err:
            raise ValueError(
                f"No checkpoint was found in `checkpoint_path={checkpoint_path}`."
            ) from err
        self.predictive.state = OutputCalibStateRepository(
            checkpoint_dir=checkpoint_path
        )

    def save_state(
        self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1
    ) -> None:
        """
        Save the calibration state as a checkpoint.

        Parameters
        ----------
        checkpoint_path : Path
            Path to file or directory where to save the current state.
        keep_top_n_checkpoints : int
            Number of past checkpoint files to keep.

        Raises
        ------
        ValueError
            If no state is available, i.e. ``self.predictive.state`` is ``None``.
        """
        if self.predictive.state is None:
            raise ValueError(
                "No state available. You must first either calibrate the model, "
                "or load a saved checkpoint."
            )
        return self.predictive.state.put(
            self.predictive.state.get(),
            checkpoint_path=checkpoint_path,
            keep=keep_top_n_checkpoints,
        )