from __future__ import annotations
import logging
from typing import (
Dict,
List,
Optional,
Tuple,
Union,
)
from flax.core import FrozenDict
from flax.training.common_utils import (
shard,
shard_prng_key,
)
import jax
from jax import (
devices,
hessian,
jit,
lax,
pmap,
vjp,
)
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.tree_util import tree_map
import tqdm
from fortuna.data.loader import (
DataLoader,
DeviceDimensionAugmentedLoader,
)
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.joint.base import Joint
from fortuna.prob_model.joint.state import JointState
from fortuna.prob_model.posterior.base import Posterior
from fortuna.prob_model.posterior.laplace import LAPLACE_NAME
from fortuna.prob_model.posterior.laplace.laplace_approximator import (
LaplacePosteriorApproximator,
)
from fortuna.prob_model.posterior.laplace.laplace_state import LaplaceState
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.posterior_state_repository import (
PosteriorStateRepository,
)
from fortuna.prob_model.posterior.run_preliminary_map import run_preliminary_map
from fortuna.prob_model.prior.gaussian import (
DiagonalGaussianPrior,
IsotropicGaussianPrior,
)
from fortuna.typing import (
AnyKey,
CalibMutable,
CalibParams,
Mutable,
Params,
Status,
)
from fortuna.utils.freeze import get_trainable_paths
from fortuna.utils.nested_dicts import (
nested_get,
nested_set,
nested_unpair,
)
from fortuna.utils.random import generate_random_normal_like_tree
from fortuna.utils.strings import decode_encoded_tuple_of_lists_of_strings_to_array
[docs]
class LaplacePosterior(Posterior):
def __init__(
    self,
    joint: Joint,
    posterior_approximator: LaplacePosteriorApproximator,
):
    """
    Laplace approximation posterior class.

    Parameters
    ----------
    joint: Joint
        A joint distribution object. Its prior must be an instance of
        `DiagonalGaussianPrior` or `IsotropicGaussianPrior` (subclasses of
        either are accepted).
    posterior_approximator: LaplacePosteriorApproximator
        A Laplace posterior approximator.

    Raises
    ------
    ValueError
        If the prior of `joint` is not one of the supported Gaussian priors.
    """
    super().__init__(joint=joint, posterior_approximator=posterior_approximator)
    supported_priors = (DiagonalGaussianPrior, IsotropicGaussianPrior)
    # `isinstance` instead of an exact `type(...)` match, so that subclasses of
    # the supported Gaussian priors are not rejected.
    if not isinstance(joint.prior, supported_priors):
        raise ValueError(
            """The Laplace posterior_approximation is not supported for this model. The prior distribution must be one of the
following choices: {}.""".format(
                list(supported_priors)
            )
        )
def __str__(self):
    """Return the `LAPLACE_NAME` constant as the string form of this posterior."""
    label = LAPLACE_NAME
    return label
def _gnn_approx(
    self,
    params: Params,
    train_data_loader: DataLoader,
    mutable: Optional[Mutable] = None,
    calib_params: Optional[CalibParams] = None,
    calib_mutable: Optional[CalibMutable] = None,
    which_params: Optional[Tuple[List[AnyKey], ...]] = None,
    factorization: str = "diagonal",
    verbose: bool = True,
) -> Params:
    """
    Accumulate, over the training data, a diagonal Generalized Gauss-Newton (GGN) approximation of the negative
    Hessian of the log-likelihood with respect to the (selected) parameters.

    For every data point, the Hessian of the log-probability with respect to the calibrated model outputs is
    eigen-decomposed, each eigenvector is pulled back through the model with a vector-Jacobian product, and the
    squared pull-backs are weighted by the eigenvalues and summed with a minus sign.

    Parameters
    ----------
    params : Params
        The random parameters of the probabilistic model.
    train_data_loader: DataLoader
        A training data loader.
    mutable: Optional[Mutable]
        Mutable objects.
    calib_params : Optional[CalibParams]
        The calibration parameters of the probabilistic model.
    calib_mutable : Optional[CalibMutable] = None
        The calibration mutable objects used to evaluate the calibrators.
    which_params : Optional[Tuple[List[AnyKey], ...]]
        Sequences of keys indicating which parameters to compute the Hessian upon. If not given (or empty), all
        of `params` is used.
    factorization: str = "diagonal"
        Factorization of the GGN approximation. Currently, only "diagonal" is supported.
    verbose: bool
        Whether to log the progress, one message per batch.

    Returns
    -------
    Params
        The accumulated diagonal estimate, unravelled into the structure that was flattened: the full `params`
        tree, or a tuple with one sub-tree per entry of `which_params`.
        NOTE(review): this is a Hessian diagonal, not a standard deviation; `fit` stores it as `hess_lik_diag`
        and `_compute_std` converts it into standard deviations.

    Raises
    ------
    ValueError
        If `factorization` is not "diagonal". The check lives inside the per-batch function, so it fires when
        that function is first evaluated.
    """
    # Flatten the parameters of interest into a single vector; `unravel` maps a
    # vector of the same length back to the flattened structure.
    rav, unravel = ravel_pytree(
        tuple([nested_get(params, keys) for keys in which_params])
        if which_params
        else params
    )

    def get_params_from_rav(_rav):
        # Rebuild a full parameter tree from a flat vector: when only a subset
        # was flattened, write the unravelled values back at `which_params`
        # inside the unfrozen `params`.
        unrav = unravel(_rav)
        return (
            FrozenDict(
                nested_set(
                    params.unfreeze(),
                    which_params,
                    unrav,
                )
            )
            if which_params
            else unrav
        )

    def apply_calib_model_manager(_params, _batch_inputs):
        # Forward pass in evaluation mode, followed by the output calibrator.
        outputs = self.joint.likelihood.model_manager.apply(
            _params, _batch_inputs, mutable=mutable, train=False
        )
        outputs = self.joint.likelihood.output_calib_manager.apply(
            params=(
                calib_params["output_calibrator"]
                if calib_params is not None
                else None
            ),
            mutable=(
                calib_mutable["output_calibrator"]
                if calib_mutable is not None
                else None
            ),
            outputs=outputs,
        )
        return outputs

    # Model outputs as a function of the flat parameter vector. The leading axis
    # is squeezed away; presumably it is the size-one batch axis added through
    # `x[:, None]` in `compute_hess_batch` below (TODO confirm).
    apply_fn = lambda _rav, x: apply_calib_model_manager(
        get_params_from_rav(_rav), x
    ).squeeze(0)
    # `vjp(...)[1]` is the pull-back function of the outputs w.r.t. `rav` at `x`.
    vjp_fn = lambda x: vjp(lambda _rav: apply_fn(_rav, x), rav)[1]

    def eig_hess_fn(vars):
        # `vars` is an (outputs, targets) pair: eigen-decompose the Hessian of
        # the log-probability with respect to the outputs.
        hess = hessian(
            lambda __o: self.joint.likelihood.prob_output_layer.log_prob(
                __o, vars[1]
            )
        )(vars[0])
        return jnp.linalg.eigh(hess)

    def compute_hess_batch(_batch_inputs, _batch_targets):
        # Eigenvalues `lam` and eigenvectors `z`, mapped over the leading axis
        # of the batch.
        lam, z = lax.map(
            eig_hess_fn,
            (apply_calib_model_manager(params, _batch_inputs), _batch_targets),
        )
        # For each data point, pull every eigenvector (the rows of `z.T`) back
        # through the model. The pull-back returns a 1-tuple (one cotangent for
        # `rav`), hence the trailing `[0]`.
        ztj = lax.map(
            lambda v: lax.map(vjp_fn(v[0]), v[1].T),
            (tree_map(lambda x: x[:, None], _batch_inputs), z),
        )[0]
        if factorization == "diagonal":
            # Eigenvalue-weighted squared pull-backs, summed over data points
            # and eigen-directions, with the sign flipped.
            return -jnp.sum(lam[:, :, None] * ztj**2, (0, 1))
        raise ValueError(
            f"`factorization={factorization}` not recognized. Currently, only "
            f"`factorization='diagonal'` is supported."
        )

    # With at least one GPU device, map the per-batch function over devices and
    # wrap the loader so that batches carry a device dimension; otherwise just
    # JIT-compile it.
    n_gpu_devices = len([d for d in devices() if d.platform == "gpu"])
    if n_gpu_devices > 0:
        train_data_loader = DeviceDimensionAugmentedLoader(train_data_loader)
        compute_hess_batch = pmap(compute_hess_batch, axis_name="batch")
    else:
        compute_hess_batch = jit(compute_hess_batch)
    # Running sum over batches; starts as a Python float and becomes an array
    # after the first batch.
    h = 0.0
    for i, (batch_inputs, batch_targets) in enumerate(train_data_loader):
        if verbose:
            logging.info(f"Hessian approximation for batch {i + 1}.")
        h += compute_hess_batch(batch_inputs, batch_targets)
    if n_gpu_devices > 0:
        # Collapse the leading per-device axis produced by `pmap`.
        h = jnp.sum(h, 0)
    return unravel(h)
def _compute_std(self, prior_log_var: float, hess_lik_diag: Params) -> Params:
    """
    Convert a diagonal likelihood Hessian into per-parameter standard deviations.

    Every entry of the result is ``1 / sqrt(exp(-prior_log_var) + h)``, where ``h`` is the matching entry of
    `hess_lik_diag`.

    Parameters
    ----------
    prior_log_var: float
        The prior log-variance; ``exp(-prior_log_var)`` is added to every Hessian entry.
    hess_lik_diag: Params
        A tree holding the diagonal Hessian of the likelihood.

    Returns
    -------
    Params
        A tree with the same structure as `hess_lik_diag` holding the standard deviations.
    """
    flat_hess_lik, to_tree = ravel_pytree(hess_lik_diag)
    prior_precision = jnp.exp(-prior_log_var)
    posterior_precision = prior_precision + flat_hess_lik
    return to_tree(1 / jnp.sqrt(posterior_precision))
[docs]
def fit(
    self,
    train_data_loader: DataLoader,
    val_data_loader: Optional[DataLoader] = None,
    fit_config: FitConfig = FitConfig(),
    map_fit_config: Optional[FitConfig] = None,
    **kwargs,
) -> Dict[str, Status]:
    """
    Fit the Laplace approximation around a MAP estimate.

    The starting state is either restored (from a checkpoint or the current state) or obtained from a
    preliminary MAP run. The diagonal likelihood Hessian is then estimated on the training data, the resulting
    Laplace state is stored (and checkpointed if a checkpoint directory is configured), and the prior
    log-variance is optionally tuned on the validation data.

    Parameters
    ----------
    train_data_loader: DataLoader
        A training data loader.
    val_data_loader: Optional[DataLoader]
        A validation data loader. Forwarded to the preliminary MAP run and, when the posterior approximator
        has `tune_prior_log_variance` set, used to tune the prior log-variance.
    fit_config: FitConfig
        Configuration of the fit (checkpointer, optimizer, monitor and processor settings are read).
    map_fit_config: Optional[FitConfig]
        Configuration of the preliminary MAP run.
    **kwargs
        Forwarded to `run_preliminary_map`.

    Returns
    -------
    Dict[str, Status]
        A status dictionary. It has a "map" entry when the preliminary MAP run is executed.

    Raises
    ------
    ValueError
        If no state can be restored and no preliminary MAP run is configured.
    """
    super()._checks_on_fit_start(fit_config, map_fit_config)
    status = {}

    # Decide where the starting point comes from; bail out early if nowhere.
    restorable = super()._is_state_available_somewhere(fit_config)
    if not restorable and not super()._should_run_preliminary_map(
        fit_config, map_fit_config
    ):
        raise ValueError(
            "The Laplace approximation must start from a preliminary run of MAP or an existing "
            "checkpoint or state. Please configure `map_fit_config`, or "
            "`fit_config.checkpointer.restore_checkpoint_path`, "
            "or `fit_config.checkpointer.start_from_current_state`."
        )
    if restorable:
        state = super()._restore_state_from_somewhere(
            fit_config=fit_config, allowed_states=(MAPState, LaplaceState)
        )
    else:
        state, status["map"] = run_preliminary_map(
            joint=self.joint,
            train_data_loader=train_data_loader,
            val_data_loader=val_data_loader,
            map_fit_config=map_fit_config,
            rng=self.rng,
            **kwargs,
        )
    state = self._init_map_state(state=state, fit_config=fit_config)

    # Restrict the Hessian to the trainable parameters when some are frozen.
    which_params = (
        get_trainable_paths(
            params=state.params, freeze_fun=fit_config.optimizer.freeze_fun
        )
        if fit_config.optimizer.freeze_fun is not None
        else None
    )

    logging.info("Run the Laplace approximation.")
    hess_lik_diag = self._gnn_approx(
        state.params,
        train_data_loader,
        mutable=state.mutable,
        calib_params=state.calib_params,
        calib_mutable=state.calib_mutable,
        which_params=which_params,
        verbose=fit_config.monitor.verbose,
    )
    state = LaplaceState.convert_from_map_state(
        map_state=state,
        hess_lik_diag=hess_lik_diag,
        prior_log_var=self.joint.prior.log_var,
        which_params=which_params,
    )

    checkpoint_dir = fit_config.checkpointer.save_checkpoint_dir
    if checkpoint_dir:
        self.save_checkpoint(
            state,
            save_checkpoint_dir=checkpoint_dir,
            keep=fit_config.checkpointer.keep_top_n_checkpoints,
            force_save=True,
        )
    # The state repository is created whether or not a checkpoint directory is set.
    self.state = PosteriorStateRepository(checkpoint_dir)
    self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
    logging.info("Fit completed.")

    wants_tuning = (
        val_data_loader is not None
        and self.posterior_approximator.tune_prior_log_variance
    )
    if not wants_tuning:
        return status

    logging.info("Tuning the prior log-variance now")
    opt_prior_log_var = self.prior_log_variance_tuning(
        val_data_loader=val_data_loader,
        n_posterior_samples=5,
        distribute=fit_config.processor.devices == -1,
    )
    state = state.replace(prior_log_var=opt_prior_log_var)
    self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
    logging.info(f"Best prior log-variance found: {opt_prior_log_var}")
    return status
[docs]
def sample(
    self,
    rng: Optional[jax.Array] = None,
    **kwargs,
) -> JointState:
    """
    Draw one sample of the parameters from the Laplace posterior.

    The sample is ``mean + std * noise``, where ``std`` comes from `_compute_std` and ``noise`` is standard
    normal. When the state records a subset of parameters, only those are perturbed.

    Parameters
    ----------
    rng: Optional[jax.Array]
        A random number generator key. If not given, one is taken from `self.rng`.
    **kwargs
        Only `prior_log_var` is read: when it is not None it replaces the prior log-variance of the state for
        this draw.

    Returns
    -------
    JointState
        The sampled parameters together with the mutable and calibration objects of the stored state.
    """
    if rng is None:
        rng = self.rng.get()
    state: LaplaceState = self.state.get()

    prior_log_var_override = kwargs.get("prior_log_var")
    if prior_log_var_override is not None:
        state = state.replace(prior_log_var=prior_log_var_override)

    def perturb(m, s, e):
        return m + s * e

    if state._encoded_which_params is None:
        # Every parameter has a mean and a Hessian entry: split them per top-level key.
        mean = {
            k: FrozenDict({"params": v["params"]["mean"]})
            for k, v in state.params.items()
        }
        hess_lik_diag = {
            k: FrozenDict({"params": v["params"]["hess_lik_diag"]})
            for k, v in state.params.items()
        }
        std = self._compute_std(
            prior_log_var=state.prior_log_var, hess_lik_diag=hess_lik_diag
        )
        noise = generate_random_normal_like_tree(rng, std)
        sampled_params = FrozenDict(tree_map(perturb, mean, std, noise))
    else:
        which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
            state._encoded_which_params
        )
        mean, hess_lik_diag = nested_unpair(
            state.params.unfreeze(),
            which_params,
            ("mean", "hess_lik_diag"),
        )
        std = self._compute_std(
            prior_log_var=state.prior_log_var, hess_lik_diag=hess_lik_diag
        )
        noise = generate_random_normal_like_tree(rng, std)
        # Perturb only the selected sub-trees and write them back over the means.
        perturbed = tuple(
            tree_map(
                perturb,
                nested_get(mean, keys),
                nested_get(std, keys),
                nested_get(noise, keys),
            )
            for keys in which_params
        )
        params = nested_set(d=mean, key_paths=which_params, objs=perturbed)
        sampled_params = FrozenDict({k: FrozenDict(v) for k, v in params.items()})

    state = state.replace(params=sampled_params)
    return JointState(
        params=state.params,
        mutable=state.mutable,
        calib_params=state.calib_params,
        calib_mutable=state.calib_mutable,
    )
def _init_map_state(
    self, state: Union[MAPState, LaplaceState], fit_config: FitConfig
) -> MAPState:
    """
    Turn the starting state of `fit` into a MAP state.

    A `LaplaceState` stores a mean and a Hessian diagonal for the parameters; here only the mean is kept and a
    `MAPState` is initialised from it with the optimizer of `fit_config`.

    Parameters
    ----------
    state: Union[MAPState, LaplaceState]
        The state to start from.
    fit_config: FitConfig
        The fit configuration; `fit_config.optimizer.method` is passed to `MAPState.init`.

    Returns
    -------
    MAPState
        The MAP state.
    """
    if isinstance(state, LaplaceState):
        if state._encoded_which_params is not None:
            # Only a subset of the parameters carries (mean, hess_lik_diag)
            # pairs: un-pair them and keep the first element, the means.
            which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
                state._encoded_which_params
            )
            state = state.replace(
                params=FrozenDict(
                    nested_unpair(
                        d=state.params.unfreeze(),
                        key_paths=which_params,
                        labels=("mean", "hess_lik_diag"),
                    )[0]
                )
            )
        else:
            # All parameters carry a mean: keep it under the "params" key of
            # every top-level entry.
            state = state.replace(
                params=FrozenDict(
                    {
                        k: dict(params=v["params"]["mean"])
                        for k, v in state.params.items()
                    }
                )
            )
        # NOTE(review): the original indentation was lost in extraction. This
        # call is placed inside the `LaplaceState` branch by inference; confirm
        # against the upstream file whether it should instead run for every
        # incoming state.
        state = MAPState.init(
            params=state.params,
            mutable=state.mutable,
            optimizer=fit_config.optimizer.method,
            calib_params=state.calib_params,
            calib_mutable=state.calib_mutable,
        )
    return state
def _batched_log_prob(
    self,
    batch,
    prior_log_var: float,
    n_posterior_samples: int = 30,
    rng: Optional[jax.Array] = None,
    **kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]:
    """
    Estimate the log-probability of a batch by averaging the likelihood over posterior samples.

    For each of `n_posterior_samples` keys split from `rng`, a posterior sample is drawn with the given
    `prior_log_var` and the batched log-likelihood is evaluated; the values are combined as
    ``logsumexp(values, axis=0) - log(n_posterior_samples)``.

    Parameters
    ----------
    batch
        A batch of data; `batch[0]` is passed to `sample` as `inputs`, the whole batch to the likelihood.
    prior_log_var: float
        The prior log-variance used when sampling.
    n_posterior_samples: int
        Number of posterior samples.
    rng: Optional[jax.Array]
        A random number generator key. If not given, one is taken from `self.rng`.
    **kwargs
        Forwarded to the likelihood's `_batched_log_prob`.

    Returns
    -------
    Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]
        The Monte Carlo estimate of the batched log-probability.
    """
    from jax import random as jrandom
    from jax.scipy.special import logsumexp

    if rng is None:
        rng = self.rng.get()

    def log_lik_for_key(sample_key):
        drawn = self.sample(
            inputs=batch[0], rng=sample_key, prior_log_var=prior_log_var
        )
        return self.joint.likelihood._batched_log_prob(
            drawn.params,
            batch,
            mutable=drawn.mutable,
            calib_params=drawn.calib_params,
            calib_mutable=drawn.calib_mutable,
            **kwargs,
        )

    sample_keys = jrandom.split(rng, n_posterior_samples)
    per_sample_log_probs = lax.map(log_lik_for_key, sample_keys)
    return logsumexp(per_sample_log_probs, axis=0) - jnp.log(n_posterior_samples)
[docs]
def prior_log_variance_tuning(
    self,
    val_data_loader: DataLoader,
    n_posterior_samples: int = 10,
    mode: str = "cv",
    min_prior_log_var: float = -3,
    max_prior_log_var: float = 3,
    grid_size: int = 20,
    distribute: bool = False,
) -> jnp.ndarray:
    """
    Tune the prior log-variance on validation data.

    Parameters
    ----------
    val_data_loader: DataLoader
        A validation data loader.
    n_posterior_samples: int
        Number of posterior samples used for each evaluation.
    mode: str
        Tuning strategy. "cv" delegates to `_prior_log_variance_tuning_cv`; "marginal_lik" is recognised but
        not implemented.
    min_prior_log_var: float
        Lower end of the grid of candidate log-variances.
    max_prior_log_var: float
        Upper end of the grid of candidate log-variances.
    grid_size: int
        Number of grid points.
    distribute: bool
        Whether to distribute the computation over devices.

    Returns
    -------
    jnp.ndarray
        The selected prior log-variance.

    Raises
    ------
    NotImplementedError
        If `mode` is "marginal_lik".
    ValueError
        If `mode` is anything other than "cv" or "marginal_lik".
    """
    if mode == "cv":
        cv_args = (
            val_data_loader,
            n_posterior_samples,
            min_prior_log_var,
            max_prior_log_var,
            grid_size,
            distribute,
        )
        return self._prior_log_variance_tuning_cv(*cv_args)
    if mode == "marginal_lik":
        raise NotImplementedError(
            "Optimizing the prior log variance via marginal likelihood maximization is not yet available."
        )
    raise ValueError(f"Unrecognized mode={mode} for prior log variance tuning.")
def _prior_log_variance_tuning_cv(
    self,
    val_data_loader: DataLoader,
    n_posterior_samples: int,
    min_prior_log_var: float,
    max_prior_log_var: float,
    grid_size: int,
    distribute: bool,
) -> jnp.ndarray:
    """
    Select the prior log-variance by grid search on the validation data.

    The candidates are `grid_size` evenly spaced values between `min_prior_log_var` and `max_prior_log_var`,
    plus the current prior log-variance of the joint distribution. For every candidate the negative
    log-probability, summed over the whole validation loader, is computed with `_batched_log_prob`; the
    candidate with the smallest value is returned.

    Parameters
    ----------
    val_data_loader: DataLoader
        A validation data loader.
    n_posterior_samples: int
        Number of posterior samples used for each evaluation.
    min_prior_log_var: float
        Lower end of the grid.
    max_prior_log_var: float
        Upper end of the grid.
    grid_size: int
        Number of grid points.
    distribute: bool
        Whether to map the evaluation over devices with `pmap` instead of using `jit`.

    Returns
    -------
    jnp.ndarray
        The best candidate, reshaped to a scalar array.
    """
    candidates = list(
        jnp.linspace(min_prior_log_var, max_prior_log_var, grid_size)
    ) + [jnp.array(self.joint.prior.log_var)]

    # The key must exist in both execution modes: it was previously created only
    # when `distribute` was true, so the non-distributed path failed on an
    # unbound `rng`. The same key is reused for every candidate and batch.
    rng = jax.random.PRNGKey(0)
    if distribute:
        rng = shard_prng_key(rng)
        val_data_loader = DeviceDimensionAugmentedLoader(val_data_loader)
        candidates = [shard(c) for c in candidates]
        # Positional argument 2 is `n_posterior_samples`, which must be static.
        fn = pmap(self._batched_log_prob, static_broadcasted_argnums=(2,))
    else:
        fn = jit(self._batched_log_prob, static_argnums=(2,))

    best = None
    for lpv in tqdm.tqdm(candidates, desc="Tuning prior log-var"):
        neg_log_prob = -jnp.sum(
            jnp.concatenate(
                [
                    self.joint.likelihood._unshard_array(
                        fn(batch, lpv, n_posterior_samples, rng)
                    )
                    for batch in val_data_loader
                ],
                0,
            )
        )
        # Keep the (candidate, score) pair with the lowest negative log-probability.
        if best is None or neg_log_prob < best[-1]:
            best = (lpv, neg_log_prob)
    opt_prior_log_var = best[0].reshape()
    return opt_prior_log_var