# Source code for fortuna.prob_model.posterior.swag.swag_posterior

from __future__ import annotations

import logging
from typing import Optional

from flax.core import FrozenDict
from jax import random
from jax._src.prng import PRNGKeyArray
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

from fortuna.data.loader import (
    DataLoader,
    InputsLoader,
)
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.joint.base import Joint
from fortuna.prob_model.joint.state import JointState
from fortuna.prob_model.posterior.base import Posterior
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.posterior_state_repository import (
    PosteriorStateRepository,
)
from fortuna.prob_model.posterior.run_preliminary_map import run_preliminary_map
from fortuna.prob_model.posterior.swag import SWAG_NAME
from fortuna.prob_model.posterior.swag.swag_approximator import (
    SWAGPosteriorApproximator,
)
from fortuna.prob_model.posterior.swag.swag_state import SWAGState
from fortuna.prob_model.posterior.swag.swag_trainer import (
    JittedSWAGTrainer,
    MultiDeviceSWAGTrainer,
    SWAGTrainer,
)
from fortuna.typing import (
    Array,
    Status,
)
from fortuna.utils.device import select_trainer_given_devices
from fortuna.utils.freeze import get_trainable_paths
from fortuna.utils.nested_dicts import (
    nested_get,
    nested_set,
)
from fortuna.utils.strings import decode_encoded_tuple_of_lists_of_strings_to_array


class SWAGPosterior(Posterior):
    def __init__(
        self, joint: Joint, posterior_approximator: SWAGPosteriorApproximator
    ):
        """
        SWAG approximate posterior class.

        Parameters
        ----------
        joint: Joint
            A joint distribution object.
        posterior_approximator: SWAGPosteriorApproximator
            A SWAG posterior approximator.
        """
        super().__init__(joint=joint, posterior_approximator=posterior_approximator)

    def __str__(self):
        return SWAG_NAME

    def fit(
        self,
        train_data_loader: DataLoader,
        val_data_loader: Optional[DataLoader] = None,
        fit_config: Optional[FitConfig] = None,
        map_fit_config: Optional[FitConfig] = None,
        **kwargs,
    ) -> Status:
        """
        Fit the SWAG approximation.

        SWAG runs on top of a starting state. That state is either restored
        (from a checkpoint or the current state, as configured in
        `fit_config`) or produced by a preliminary MAP run configured via
        `map_fit_config`.

        Parameters
        ----------
        train_data_loader: DataLoader
            Training data loader.
        val_data_loader: Optional[DataLoader]
            Validation data loader.
        fit_config: Optional[FitConfig]
            Configuration of the SWAG run. If not passed, a fresh default
            `FitConfig()` is created for this call.
        map_fit_config: Optional[FitConfig]
            Configuration of the preliminary MAP run, if one should be run.
        **kwargs
            Forwarded to the preliminary MAP run only.

        Returns
        -------
        Status
            A dictionary holding the status of the SWAG run under key "swag".
            It also holds the status of the preliminary MAP run under key
            "map", when that run takes place.

        Raises
        ------
        ValueError
            If `rank < 2`, or if `fit_config.optimizer.n_epochs` does not
            exceed `rank`. Also raised if there is neither a state to start
            from nor a configuration for a preliminary MAP run.
        """
        if fit_config is None:
            # Build the default per call rather than sharing one instance
            # created once at function-definition time.
            fit_config = FitConfig()
        super()._checks_on_fit_start(fit_config, map_fit_config)
        self._check_swag_fit_config(fit_config)

        status = {}
        if super()._is_state_available_somewhere(fit_config):
            state = super()._restore_state_from_somewhere(
                fit_config=fit_config, allowed_states=(MAPState, SWAGState)
            )
        elif super()._should_run_preliminary_map(fit_config, map_fit_config):
            state, status["map"] = run_preliminary_map(
                joint=self.joint,
                train_data_loader=train_data_loader,
                val_data_loader=val_data_loader,
                map_fit_config=map_fit_config,
                rng=self.rng,
                **kwargs,
            )
        else:
            raise ValueError(
                "The SWAG approximation must start from a preliminary run of MAP or an existing "
                "checkpoint or state. Please configure `map_fit_config`, or "
                "`fit_config.checkpointer.restore_checkpoint_path`, "
                "or `fit_config.checkpointer.start_from_current_state`."
            )

        state = SWAGState.convert_from_map_state(
            map_state=state,
            optimizer=fit_config.optimizer.method,
        )
        state = super()._freeze_optimizer_in_state(state, fit_config)

        # Only compute the trainable parameter paths when a freeze function is set.
        which_params = (
            get_trainable_paths(state.params, fit_config.optimizer.freeze_fun)
            if fit_config.optimizer.freeze_fun is not None
            else None
        )

        trainer_cls = select_trainer_given_devices(
            devices=fit_config.processor.devices,
            base_trainer_cls=SWAGTrainer,
            jitted_trainer_cls=JittedSWAGTrainer,
            multi_device_trainer_cls=MultiDeviceSWAGTrainer,
            disable_jit=fit_config.processor.disable_jit,
        )
        trainer = trainer_cls(
            predict_fn=self.joint.likelihood.prob_output_layer.predict,
            save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
            save_every_n_steps=fit_config.checkpointer.save_every_n_steps,
            keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
            disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation,
            eval_every_n_epochs=fit_config.monitor.eval_every_n_epochs,
            early_stopping_verbose=False,
            freeze_fun=fit_config.optimizer.freeze_fun,
            which_params=which_params,
        )

        logging.info("Run SWAG.")
        state, status["swag"] = trainer.train(
            rng=self.rng.get(),
            state=state,
            loss_fun=self.joint._batched_negative_log_joint_prob,
            training_dataloader=train_data_loader,
            training_dataset_size=train_data_loader.size,
            n_epochs=fit_config.optimizer.n_epochs,
            metrics=fit_config.monitor.metrics,
            validation_dataloader=val_data_loader,
            validation_dataset_size=val_data_loader.size
            if val_data_loader is not None
            else None,
            verbose=fit_config.monitor.verbose,
            callbacks=fit_config.callbacks,
            # The caller's `**kwargs` only configure the preliminary MAP run;
            # the SWAG trainer receives just the approximation rank.
            rank=self.posterior_approximator.rank,
        )

        self.state = PosteriorStateRepository(
            fit_config.checkpointer.save_checkpoint_dir
            if fit_config.checkpointer.dump_state is True
            else None
        )
        self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
        logging.info("Fit completed.")
        return status

    def _check_swag_fit_config(self, fit_config: FitConfig) -> None:
        """
        Validate the SWAG-specific settings of `fit_config`.

        Raises
        ------
        ValueError
            If `rank < 2`, or if `fit_config.optimizer.n_epochs <= rank`.
        """
        rank = self.posterior_approximator.rank
        if rank < 2:
            raise ValueError("`rank` must be at least 2.")
        if fit_config.optimizer.n_epochs <= rank:
            raise ValueError(
                f"Not enough SWAG epochs to obtain `rank={rank}`. Please either "
                "increase `fit_config.optimizer.n_epochs` or decrease `rank`."
            )
        patience = fit_config.monitor.early_stopping_patience
        # The truthiness guard also protects against `None > 0` raising.
        if patience and patience > 0:
            logging.warning(
                "It seems you are trying to enable early stopping, since "
                f"`fit_config.monitor.early_stopping_patience={patience}`. "
                "We do not support early stopping in SWAG, since we implement it "
                "as a post-processing step of MAP. If your intention was rather to "
                "enable early stopping in MAP, please configure `map_fit_config` "
                "accordingly."
            )

    def sample(
        self,
        rng: Optional[PRNGKeyArray] = None,
        inputs_loader: Optional[InputsLoader] = None,
        inputs: Optional[Array] = None,
        **kwargs,
    ) -> JointState:
        """
        Sample from the posterior distribution.

        Parameters
        ----------
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.
        inputs_loader: Optional[InputsLoader]
            Input data loader. This or `inputs` is required if the posterior state includes mutable objects.
        inputs: Optional[Array]
            Input variables. This or `inputs_loader` is required if the posterior state includes mutable objects.

        Returns
        -------
        JointState
            A sample from the posterior distribution.

        Raises
        ------
        ValueError
            If the posterior state contains mutable objects but neither
            `inputs_loader` nor `inputs` is passed.
        """
        if rng is None:
            rng = self.rng.get()
        state = self.state.get()

        # Same truthiness test as the mutable-update step below, so an empty
        # mutable collection does not demand inputs it would never use.
        if state.mutable and inputs_loader is None and inputs is None:
            raise ValueError(
                "The posterior state contains mutable objects. Please pass `inputs_loader` or `inputs`."
            )

        n_params = len(state.mean)
        rank = state.dev.shape[-1]
        which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
            state._encoded_which_params
        )
        # `unravel` maps a flat vector back to the pytree structure of either
        # all parameters or only the subset at `which_params`.
        unravel = ravel_pytree(
            state.params
            if which_params is None
            else [nested_get(state.params, path) for path in which_params]
        )[1]

        # NOTE(review): Maddox et al. (2019) scale the low-rank term by
        # 1 / sqrt(2 * (K - 1)). This code uses 1 / sqrt(2 * K) with K = rank;
        # confirm against how the trainer normalizes `dev`.
        coeff1 = 1 / jnp.sqrt(2)
        coeff2 = coeff1 / jnp.sqrt(rank)
        rng, key1, key2 = random.split(rng, 3)
        z1 = random.normal(key1, shape=(n_params,))
        z2 = random.normal(key2, shape=(rank,))
        sampled = self._get_sample(
            mean=state.mean,
            std=state.std,
            dev=state.dev,
            z1=z1,
            z2=z2,
            coeff1=coeff1,
            coeff2=coeff2,
            unravel=unravel,
        )

        if which_params is None:
            params = sampled
        else:
            # Write the sampled leaves back into their paths; the remaining
            # (frozen) parameters keep their current values.
            params = FrozenDict(
                nested_set(
                    d=state.params.unfreeze(),
                    key_paths=which_params,
                    objs=tuple(sampled),
                )
            )
        state = state.replace(params=params)

        if state.mutable:
            # A plain `inputs` array is treated as a single batch.
            batches = inputs_loader if inputs_loader is not None else [inputs]
            # NOTE(review): the same `rng` key is reused for every batch;
            # consider splitting it per batch if stochastic layers matter here.
            for batch_inputs in batches:
                state = state.replace(
                    mutable=self.joint.likelihood.model_manager.apply(
                        state.params,
                        batch_inputs,
                        mutable=state.mutable,
                        train=True,
                        rng=rng,
                    )[1]["mutable"]
                )

        return JointState(
            params=state.params,
            mutable=state.mutable,
            calib_params=state.calib_params,
            calib_mutable=state.calib_mutable,
        )

    def _get_sample(self, mean, std, dev, z1, z2, coeff1, coeff2, unravel):
        """
        Build one flat SWAG sample and unravel it into the parameter structure.

        The flat sample is `mean + coeff1 * std * z1 + coeff2 * (dev @ z2)`:
        a diagonal term plus a low-rank term built from the deviation matrix.
        """
        return unravel(mean + coeff1 * std * z1 + coeff2 * jnp.matmul(dev, z2))

    def _get_mean_std_dev(self, state: SWAGState) -> SWAGState:
        """
        Return `state` updated with the SWAG mean, standard deviation and
        deviation matrix, derived from its stored first and second moments.
        """
        # Var = E[x^2] - E[x]^2. Clip at zero because floating-point
        # cancellation can make the difference slightly negative.
        var = state._mean_squared_rav_params - state._mean_rav_params**2
        var = jnp.maximum(var, 0.0)
        return state.update(
            dict(
                mean=state._mean_rav_params,
                std=jnp.sqrt(var),
                dev=state._deviation_rav_params,
            )
        )