"""Source code for the ``fortuna.prob_model.posterior.map.map_posterior`` module (MAP posterior)."""

import logging
from typing import Optional

from jax._src.prng import PRNGKeyArray

from fortuna.data.loader import DataLoader
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.joint.base import Joint
from fortuna.prob_model.joint.state import JointState
from fortuna.prob_model.posterior.base import Posterior
from fortuna.prob_model.posterior.map import MAP_NAME
from fortuna.prob_model.posterior.map.map_approximator import MAPPosteriorApproximator
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.map.map_trainer import (
    JittedMAPTrainer,
    MAPTrainer,
    MultiDeviceMAPTrainer,
)
from fortuna.prob_model.posterior.posterior_state_repository import (
    PosteriorStateRepository,
)
from fortuna.typing import Status
from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype
from fortuna.utils.device import select_trainer_given_devices

# Per-module logger: named after this module so its verbosity can be tuned
# independently of the root logger.
logger: logging.Logger = logging.getLogger(name=__name__)


class MAPPosterior(Posterior):
    def __init__(
        self,
        joint: Joint,
        posterior_approximator: MAPPosteriorApproximator,
    ):
        """
        Maximum-a-Posteriori (MAP) approximate posterior class.

        Parameters
        ----------
        joint: Joint
            A Joint distribution object.
        posterior_approximator: MAPPosteriorApproximator
            A MAP posterior approximator.
        """
        super().__init__(joint=joint, posterior_approximator=posterior_approximator)

    def __str__(self):
        return MAP_NAME

    def fit(
        self,
        train_data_loader: DataLoader,
        val_data_loader: Optional[DataLoader] = None,
        fit_config: Optional[FitConfig] = None,
        map_fit_config=None,
        **kwargs,
    ) -> Status:
        """
        Fit the MAP posterior by training the joint with its negative log-probability as loss.

        Parameters
        ----------
        train_data_loader: DataLoader
            Training data loader.
        val_data_loader: Optional[DataLoader]
            Optional validation data loader.
        fit_config: Optional[FitConfig]
            Configuration of the fitting procedure (optimizer, checkpointer,
            monitor, processor, hyperparameters, callbacks). When ``None``, a
            fresh default ``FitConfig()`` is built for this call.
        map_fit_config
            Forwarded unchanged to the base-class start-of-fit checks.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        Status
            The status object returned by the trainer.
        """
        # Build the default per call: a ``FitConfig()`` default argument would be
        # evaluated once at definition time and shared across every call.
        if fit_config is None:
            fit_config = FitConfig()
        super()._checks_on_fit_start(fit_config, map_fit_config)

        # Pick the plain / jitted / multi-device trainer from the processor settings.
        trainer_cls = select_trainer_given_devices(
            devices=fit_config.processor.devices,
            base_trainer_cls=MAPTrainer,
            jitted_trainer_cls=JittedMAPTrainer,
            multi_device_trainer_cls=MultiDeviceMAPTrainer,
            disable_jit=fit_config.processor.disable_jit,
        )
        trainer = trainer_cls(
            predict_fn=self.joint.likelihood.prob_output_layer.predict,
            save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
            save_every_n_steps=fit_config.checkpointer.save_every_n_steps,
            keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
            disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation,
            eval_every_n_epochs=fit_config.monitor.eval_every_n_epochs,
            early_stopping_monitor=fit_config.monitor.early_stopping_monitor,
            early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta,
            early_stopping_patience=fit_config.monitor.early_stopping_patience,
            freeze_fun=fit_config.optimizer.freeze_fun,
        )

        # Resume from an existing state when one is available; otherwise start fresh.
        if super()._is_state_available_somewhere(fit_config):
            state = self._restore_state_from_somewhere(
                fit_config=fit_config,
                allowed_states=(MAPState,),
            )
        else:
            state = self._init_state(
                data_loader=train_data_loader, fit_config=fit_config
            )

        state = super()._freeze_optimizer_in_state(state, fit_config)
        self._check_state(state)

        logger.info("Run MAP.")
        state, status = trainer.train(
            rng=self.rng.get(),
            state=state,
            loss_fun=self.joint._batched_negative_log_joint_prob,
            training_dataloader=train_data_loader,
            training_dataset_size=train_data_loader.size,
            n_epochs=fit_config.optimizer.n_epochs,
            metrics=fit_config.monitor.metrics,
            validation_dataloader=val_data_loader,
            validation_dataset_size=(
                val_data_loader.size if val_data_loader is not None else None
            ),
            verbose=fit_config.monitor.verbose,
            callbacks=fit_config.callbacks,
            max_grad_norm=fit_config.hyperparameters.max_grad_norm,
            gradient_accumulation_steps=fit_config.hyperparameters.gradient_accumulation_steps,
        )

        # The repository only receives the checkpoint directory when the user
        # asked to dump the state; otherwise it is created with ``None``
        # (presumably memory-only — confirm in PosteriorStateRepository).
        dump_dir = (
            fit_config.checkpointer.save_checkpoint_dir
            if fit_config.checkpointer.dump_state
            else None
        )
        self.state = PosteriorStateRepository(dump_dir)
        self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
        logger.info("Fit completed.")
        return status

    def sample(self, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState:
        """
        Return the fitted MAP point estimate as a joint state.

        The MAP posterior is a single point, so ``rng`` and ``kwargs`` are
        ignored: the result is built from the currently stored posterior state.

        Parameters
        ----------
        rng: Optional[PRNGKeyArray]
            Unused; accepted for interface compatibility.

        Returns
        -------
        JointState
            Parameters, mutables and calibration parameters/mutables of the
            stored state.
        """
        stored = self.state.get()
        return JointState(
            params=stored.params,
            mutable=stored.mutable,
            calib_params=stored.calib_params,
            calib_mutable=stored.calib_mutable,
        )

    def _init_state(self, data_loader: DataLoader, fit_config: FitConfig) -> MAPState:
        """
        Build a fresh ``MAPState`` from the joint's initial state and the configured optimizer.

        The dynamic-scale instance is derived from the model's ``dtype``
        attribute, or from ``None`` when the model does not declare one.
        """
        joint_state = super()._init_joint_state(data_loader=data_loader)
        model_dtype = getattr(self.joint.likelihood.model_manager.model, "dtype", None)
        return MAPState.init(
            params=joint_state.params,
            mutable=joint_state.mutable,
            optimizer=fit_config.optimizer.method,
            calib_params=joint_state.calib_params,
            calib_mutable=joint_state.calib_mutable,
            dynamic_scale=get_dynamic_scale_instance_from_model_dtype(model_dtype),
        )