Source code for fortuna.prob_model.posterior.laplace.laplace_state

from __future__ import annotations

from typing import (
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

from flax.core import FrozenDict

from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.typing import (
    AnyKey,
    Array,
    Params,
)
from fortuna.utils.nested_dicts import nested_pair
from fortuna.utils.strings import (
    convert_string_to_tuple,
    encode_tuple_of_lists_of_strings_to_numpy,
)


[docs] class LaplaceState(PosteriorState): """ Attributes ---------- prior_log_var: float Prior log-variance value. encoded_name: jnp.ndarray Laplace state name encoded as an array. """ prior_log_var: float = 0.0 encoded_name: tuple = convert_string_to_tuple("LaplaceState") _encoded_which_params: Optional[Dict[str, Array]] = None
[docs] @classmethod def convert_from_map_state( cls, map_state: MAPState, hess_lik_diag: Union[Params, Tuple[Params, ...]], prior_log_var: Optional[float], which_params: Tuple[List[AnyKey], ...], ) -> LaplaceState: """ Convert a MAP state into a Laplace state. Parameters ---------- map_state: MAPState A MAP state. hess_lik_diag: Union[Params, Tuple[Params, ...]] Diagonal of the approximated Hessian of the likelihood. prior_log_var: float Prior log-variance value. If None, initialize it to 100. which_params: Tuple[List[AnyKey], ...] Sequences of keys pointing to the parameters over which `std` is defined. If `which_params` is None, `std` must be defined for all parameters. Returns ------- LaplaceState A Laplace state instance. """ params = map_state.params.unfreeze() if which_params is not None: params = nested_pair( d=params, key_paths=which_params, objs=hess_lik_diag, labels=("mean", "hess_lik_diag"), ) else: for k, v in params.items(): params[k] = FrozenDict( { "params": dict( mean=v["params"], hess_lik_diag=hess_lik_diag[k]["params"] ) } ) d = vars( map_state.replace( params=FrozenDict(params), encoded_name=LaplaceState.encoded_name ) ) d["_encoded_which_params"] = encode_tuple_of_lists_of_strings_to_numpy( which_params ) d["prior_log_var"] = prior_log_var if prior_log_var is not None else 0.0 return LaplaceState.init_from_dict(d)