Source code for fortuna.prob_model.posterior.swag.swag_state
from __future__ import annotations
from typing import (
Any,
Dict,
List,
Optional,
)
import jax.numpy as jnp
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.typing import (
Array,
OptaxOptimizer,
)
from fortuna.utils.strings import convert_string_to_jnp_array
[docs]class SWAGState(PosteriorState):
"""
Attributes
----------
encoded_name: jnp.ndarray
SWAG state name encoded as an array.
mean: Optional[jnp.ndarray]
Mean of the posterior approximation.
std: Optional[jnp.ndarray]
Diagonal standard deviation of the posterior approximation.
dev: Optional[jnp.ndarray]
Deviation term of the covariance matrix of the posterior approximation.
"""
mean: Optional[jnp.ndarray] = None
std: Optional[jnp.ndarray] = None
dev: Optional[jnp.ndarray] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("SWAGState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
[docs] @classmethod
def convert_from_map_state(
cls, map_state: MAPState, optimizer: OptaxOptimizer
) -> SWAGState:
"""
Convert a MAP state into a SWAG state.
Parameters
----------
map_state: MAPState
A MAP posterior state.
optimizer: OptaxOptimizer
An Optax optimizer.
Returns
-------
SWAGState
A SWAG state.
"""
return SWAGState.init(
params=map_state.params,
mutable=map_state.mutable,
optimizer=optimizer,
calib_params=map_state.calib_params,
calib_mutable=map_state.calib_mutable,
)
[docs] def update(self, variables: Dict[str, Any]) -> SWAGState:
"""
Update the SWAG state.
Parameters
----------
variables: Dict[str, Any]
The attributes to update and their values.
Returns
-------
SWAGState
Updated SWAG state.
"""
unchanged_keys = {k: v for k, v in vars(self).items() if k not in variables}
return self.replace(**unchanged_keys, **variables)