import abc
import logging
from typing import (
Callable,
Dict,
Optional,
)
import jax
import jax.numpy as jnp
from fortuna.data.loader import DataLoader
from fortuna.output_calib_model.state import OutputCalibState
from fortuna.prob_model.calib_config.base import CalibConfig
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.prob_model_calibrator import (
JittedProbModelOutputCalibrator,
MultiDeviceProbModelOutputCalibrator,
ProbModelOutputCalibrator,
)
from fortuna.typing import (
Array,
Path,
Status,
)
from fortuna.utils.data import check_data_loader_is_not_random
from fortuna.utils.device import select_trainer_given_devices
from fortuna.utils.random import RandomNumberGenerator
class ProbModel(abc.ABC):
"""
Abstract probabilistic model class.
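Examples
--------
Concrete subclasses (e.g. :class:`~fortuna.prob_model.ProbClassifier`) are
instantiated with a model and, optionally, a seed that controls every source
of randomness in the probabilistic model. A minimal sketch, assuming a
hypothetical Flax module `net`:
>>> from fortuna.prob_model import ProbClassifier
>>> prob_model = ProbClassifier(model=net, seed=42)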
"""
def __init__(self, seed: int = 0):
self.rng = RandomNumberGenerator(seed=seed)
self.__set_rng()
def __set_rng(self):
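# Share the single random number generator with every sub-component, so that
# sampling in the model manager, output calibration manager, output layer,
# prior, likelihood, joint, posterior and predictive is coordinated and
# reproducible from one seed.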
self.model_manager.rng = self.rng
self.output_calib_manager.rng = self.rng
self.prob_output_layer.rng = self.rng
self.prior.rng = self.rng
self.likelihood.rng = self.rng
self.joint.rng = self.rng
self.posterior.rng = self.rng
self.predictive.rng = self.rng
def train(
self,
train_data_loader: DataLoader,
val_data_loader: Optional[DataLoader] = None,
calib_data_loader: Optional[DataLoader] = None,
fit_config: FitConfig = FitConfig(),
calib_config: CalibConfig = CalibConfig(),
map_fit_config: Optional[FitConfig] = None,
) -> Dict[str, Status]:
"""
Train the probabilistic model. This involves fitting the posterior distribution and calibrating the
probabilistic model. Calibration is performed only if (1) a `calib_data_loader` is passed and (2) the
probabilistic model includes an output calibrator.
Parameters
----------
train_data_loader : DataLoader
A training data loader.
val_data_loader : Optional[DataLoader]
A validation data loader. This is used to validate both posterior fitting and calibration.
calib_data_loader : Optional[DataLoader]
A calibration data loader. If this is not passed, no calibration is performed.
fit_config : FitConfig
An object to configure the posterior distribution fitting.
calib_config : CalibConfig
An object to configure the calibration.
map_fit_config : Optional[FitConfig] = None
An object to configure a preliminary posterior distribution fitting via the Maximum-A-Posteriori (MAP)
method.
The fit methods of several supported posterior approximations, such as
:class:`~fortuna.prob_model.posterior.swag.swag_posterior.SWAGPosterior` and
:class:`~fortuna.prob_model.posterior.laplace.laplace_posterior.LaplacePosterior`, start from a preliminary
MAP run, which can be configured via this object. If the chosen method does not start from MAP, this
argument is ignored.
Returns
-------
Dict[str, Status]
Status objects for both posterior fitting and calibration.
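Examples
--------
A minimal usage sketch, assuming a concrete subclass (e.g.
:class:`~fortuna.prob_model.ProbClassifier`) stored in `prob_model` and
in-memory arrays `x_train`, `y_train`, `x_val`, `y_val` (hypothetical names):
>>> train_loader = DataLoader.from_array_data((x_train, y_train), batch_size=128)
>>> val_loader = DataLoader.from_array_data((x_val, y_val), batch_size=128)
>>> status = prob_model.train(
...     train_data_loader=train_loader,
...     val_data_loader=val_loader,
...     calib_data_loader=val_loader,
... )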
"""
logging.info("Fit the posterior distribution...")
fit_status = self.posterior.fit(
train_data_loader=train_data_loader,
val_data_loader=val_data_loader,
fit_config=fit_config,
map_fit_config=map_fit_config,
)
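# Calibrate only if a calibration data loader has been passed; otherwise the
# calibration status stays None.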
calib_status = None
if calib_data_loader:
calib_status = self.calibrate(
calib_data_loader=calib_data_loader,
val_data_loader=val_data_loader,
calib_config=calib_config,
)
logging.info("Calibration completed.")
return dict(fit_status=fit_status, calib_status=calib_status)
def _calibrate(
self,
calib_data_loader: DataLoader,
uncertainty_fn: Callable[[jnp.ndarray, jnp.ndarray, Array], jnp.ndarray],
val_data_loader: Optional[DataLoader] = None,
calib_config: CalibConfig = CalibConfig(),
) -> Status:
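# Calibration reuses ensemble outputs pre-computed below, so both loaders must
# traverse the data in a fixed, deterministic order to keep outputs and
# targets aligned.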
check_data_loader_is_not_random(calib_data_loader)
if val_data_loader is not None:
check_data_loader_is_not_random(val_data_loader)
if (
self.output_calib_manager is None
or self.output_calib_manager.output_calibrator is None
):
logging.warning(
"""Nothing to calibrate. No calibrator was passed to the probabilistic model."""
)
else:
if self.posterior.state is None:
raise ValueError(
"""Before calibration, you must either train the probabilistic model (see
:meth:`~fortuna.prob_model.base.ProbModel.train`),
or load a state from an existing checkpoint
(see :meth:`~fortuna.prob_model.base.ProbModel.load_state`)."""
)
if calib_config.monitor.verbose:
logging.info(
"Pre-compute ensemble of outputs on the calibration data loader."
)
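# Distribute the ensemble output sampling across devices whenever the default
# JAX backend is not the CPU.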
distribute = jax.local_devices()[0].platform != "cpu"
(
calib_ensemble_outputs_loader,
calib_size,
) = self.predictive._sample_outputs_loader(
inputs_loader=calib_data_loader.to_inputs_loader(),
n_output_samples=calib_config.processor.n_posterior_samples,
return_size=True,
distribute=distribute,
)
if calib_config.monitor.verbose:
logging.info(
"Pre-compute ensemble of outputs on the validation data loader."
)
val_ensemble_outputs_loader, val_size = (
self.predictive._sample_outputs_loader(
inputs_loader=val_data_loader.to_inputs_loader(),
n_output_samples=calib_config.processor.n_posterior_samples,
return_size=True,
distribute=distribute,
)
if val_data_loader is not None
else (None, None)
)
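# Choose the plain, jitted or multi-device calibration trainer depending on
# the devices requested in the calibration configuration.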
trainer_cls = select_trainer_given_devices(
devices=calib_config.processor.devices,
base_trainer_cls=ProbModelOutputCalibrator,
jitted_trainer_cls=JittedProbModelOutputCalibrator,
multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator,
disable_jit=calib_config.processor.disable_jit,
)
calibrator = trainer_cls(
calib_outputs_loader=calib_ensemble_outputs_loader,
val_outputs_loader=val_ensemble_outputs_loader,
predict_fn=self.prob_output_layer.predict,
uncertainty_fn=uncertainty_fn,
save_checkpoint_dir=calib_config.checkpointer.save_checkpoint_dir,
save_every_n_steps=calib_config.checkpointer.save_every_n_steps,
keep_top_n_checkpoints=calib_config.checkpointer.keep_top_n_checkpoints,
disable_training_metrics_computation=calib_config.monitor.disable_calibration_metrics_computation,
eval_every_n_epochs=calib_config.monitor.eval_every_n_epochs,
early_stopping_monitor=calib_config.monitor.early_stopping_monitor,
early_stopping_min_delta=calib_config.monitor.early_stopping_min_delta,
early_stopping_patience=calib_config.monitor.early_stopping_patience,
)
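# Initialize the calibration state from the calibration parameters stored in
# the posterior state, unless a checkpoint to restore is explicitly provided.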
if calib_config.checkpointer.restore_checkpoint_path is None:
calib_dict = self.posterior.state.extract_calib_keys()
state = OutputCalibState.init(
params=calib_dict["calib_params"],
mutable=calib_dict["calib_mutable"],
optimizer=calib_config.optimizer.method,
)
else:
state = self.posterior.restore_checkpoint(
calib_config.checkpointer.restore_checkpoint_path,
optimizer=calib_config.optimizer.method,
)
if calib_config.monitor.verbose:
logging.info("Start calibration.")
state, status = calibrator.train(
rng=self.rng.get(),
state=state,
loss_fun=self.predictive._batched_negative_log_joint_prob,
training_data_loader=calib_data_loader,
training_dataset_size=calib_size,
n_epochs=calib_config.optimizer.n_epochs,
metrics=calib_config.monitor.metrics,
val_data_loader=val_data_loader,
val_dataset_size=val_size,
verbose=calib_config.monitor.verbose,
)
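# Write the calibrated parameters and mutables back into the posterior state.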
self.posterior.state.update(
variables=dict(calib_params=state.params, calib_mutable=state.mutable)
)
if (
calib_config.checkpointer.dump_state
and calib_config.checkpointer.save_checkpoint_dir is not None
):
if calib_config.monitor.verbose:
logging.info("Dump state to disk.")
self.save_state(
checkpoint_path=calib_config.checkpointer.save_checkpoint_dir
)
if calib_config.monitor.verbose:
logging.info("Calibration completed.")
return status
def load_state(self, checkpoint_path: Path) -> None:
"""
Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the
probabilistic model.
Parameters
----------
checkpoint_path : Path
Path to a checkpoint file or directory to restore.
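Examples
--------
A minimal sketch with a hypothetical checkpoint directory:
>>> prob_model.load_state("./checkpoints/last_run")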
"""
return self.posterior.load_state(checkpoint_path)
def save_state(
self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1
) -> None:
"""
Save the posterior distribution state as a checkpoint.
Parameters
----------
checkpoint_path : Path
Path to the file or directory where the current state will be saved.
keep_top_n_checkpoints : int
Number of past checkpoint files to keep.
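Examples
--------
A minimal sketch with a hypothetical checkpoint directory, keeping the two
most recent checkpoints:
>>> prob_model.save_state("./checkpoints/last_run", keep_top_n_checkpoints=2)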
"""
return self.posterior.save_state(checkpoint_path, keep_top_n_checkpoints)