Source code for fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state
from __future__ import annotations
from typing import (
Dict,
List,
Optional,
Tuple,
)
import jax.numpy as jnp
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.typing import (
AnyKey,
Array,
OptaxOptimizer,
)
from fortuna.utils.strings import (
convert_string_to_jnp_array,
encode_tuple_of_lists_of_strings_to_numpy,
)
[docs]class SGHMCState(PosteriorState):
"""
Attributes
----------
encoded_name: jnp.ndarray
SGHMC state name encoded as an array.
"""
encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
[docs] @classmethod
def convert_from_map_state(
cls,
map_state: MAPState,
optimizer: OptaxOptimizer,
which_params: Tuple[List[AnyKey], ...],
) -> SGHMCState:
"""
Convert a MAP state into an SGHMC state.
Parameters
----------
map_state: MAPState
A MAP posterior state.
optimizer: OptaxOptimizer
An Optax optimizer.
which_params: Tuple[List[AnyKey], ...]
Sequences of keys pointing to the stochastic parameters.
Returns
-------
SGHMCState
An SGHMC state.
"""
_encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy(which_params)
return cls.init(
params=map_state.params,
mutable=map_state.mutable,
optimizer=optimizer,
calib_params=map_state.calib_params,
calib_mutable=map_state.calib_mutable,
_encoded_which_params=_encoded_which_params,
)