Source code for fortuna.prob_model.posterior.state

from __future__ import annotations

from typing import (
    Any,
    Dict,
    Optional,
)

from flax.core import FrozenDict
from flax.training import dynamic_scale
import jax.numpy as jnp

from fortuna.training.train_state import TrainState
from fortuna.typing import (
    CalibMutable,
    CalibParams,
    Mutable,
    OptaxOptimizer,
    Params,
)
from fortuna.utils.strings import convert_string_to_tuple


class PosteriorState(TrainState):
    """
    A posterior distribution state. This includes all the parameters and mutable
    objects that characterize an approximation of the posterior distribution.
    """

    params: Params
    mutable: Optional[Mutable] = None
    calib_params: Optional[CalibParams] = None
    calib_mutable: Optional[CalibMutable] = None
    grad_accumulated: Optional[jnp.ndarray] = None
    dynamic_scale: Optional[dynamic_scale.DynamicScale] = None
    encoded_name: tuple = convert_string_to_tuple("PosteriorState")

    @classmethod
    def init(
        cls,
        params: Params,
        mutable: Optional[Mutable] = None,
        optimizer: Optional[OptaxOptimizer] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        grad_accumulated: Optional[jnp.ndarray] = None,
        dynamic_scale: Optional[dynamic_scale.DynamicScale] = None,
        **kwargs,
    ) -> Any:
        """
        Initialize a posterior distribution state.

        Parameters
        ----------
        params : Params
            The parameters characterizing an approximation of the posterior
            distribution.
        mutable : Optional[Mutable]
            The mutable objects characterizing an approximation of the posterior
            distribution.
        optimizer : Optional[OptaxOptimizer]
            An Optax optimizer associated with the posterior state.
        calib_params : Optional[CalibParams]
            The calibration parameters characterizing an approximation of the
            posterior distribution.
        calib_mutable : Optional[CalibMutable]
            The calibration mutable objects characterizing an approximation of the
            posterior distribution.
        grad_accumulated : Optional[jnp.ndarray]
            The gradients accumulated over consecutive training steps (used only
            when `gradient_accumulation_steps > 1`).
        dynamic_scale : Optional[dynamic_scale.DynamicScale]
            Dynamic loss scaling for mixed-precision gradients.

        Returns
        -------
        Any
            A posterior distribution state.
        """
        return cls(
            apply_fn=None,
            params=params,
            opt_state=(
                kwargs["opt_state"]
                if optimizer is None and "opt_state" in kwargs
                else optimizer.init(params)
            ),
            mutable=mutable,
            step=kwargs.get("step", 0),
            tx=optimizer,
            calib_params=calib_params,
            calib_mutable=calib_mutable,
            dynamic_scale=dynamic_scale,
            grad_accumulated=grad_accumulated,
            **{
                k: v
                for k, v in kwargs.items()
                if k not in ["opt_state", "apply_fn", "tx", "step"]
            },
        )

    @classmethod
    def init_from_dict(
        cls,
        d: Dict,
        optimizer: Optional[OptaxOptimizer] = None,
        **kwargs,
    ) -> PosteriorState:
        """
        Initialize a posterior distribution state from a dictionary.

        Parameters
        ----------
        d : Dict
            A dictionary including attributes of the posterior state.
        optimizer : Optional[OptaxOptimizer]
            An Optax optimizer to assign to the posterior state.

        Returns
        -------
        PosteriorState
            A posterior state.
        """
        kwargs = {
            **kwargs,
            **{
                k: v
                for k, v in d.items()
                if k
                not in [
                    "params",
                    "mutable",
                    "optimizer",
                    "calib_params",
                    "calib_mutable",
                ]
            },
        }
        return cls.init(
            FrozenDict(d["params"]),
            FrozenDict(d["mutable"]) if d["mutable"] is not None else None,
            optimizer,
            FrozenDict(d["calib_params"]) if d["calib_params"] is not None else None,
            FrozenDict(d["calib_mutable"]) if d["calib_mutable"] is not None else None,
            **kwargs,
        )
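

Usage note: `PosteriorState` inherits the functional Flax dataclass semantics of `TrainState`, so fields are read as attributes and updated via `replace` rather than mutated in place. The snippet below is a minimal sketch, not part of the module; the nested parameter layout and values are purely illustrative.

import jax.numpy as jnp
import optax
from flax.core import FrozenDict

from fortuna.prob_model.posterior.state import PosteriorState

# Toy parameters; this nested layout is illustrative only.
params = FrozenDict({"model": {"params": {"dense": {"kernel": jnp.ones((3, 2))}}}})

state = PosteriorState.init(params=params, optimizer=optax.adam(1e-3))
print(state.step)          # 0
print(state.encoded_name)  # the tuple encoding of "PosteriorState"

# Flax dataclass states are updated functionally, never in place.
state = state.replace(step=state.step + 1)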
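
One detail of `init` worth calling out: when `optimizer` is `None` but an `opt_state` is supplied through `**kwargs`, the given optimizer state is reused verbatim instead of being freshly initialized, which is what lets a saved state be restored without rebuilding the optimizer. A sketch of both paths, again with toy parameters:

import jax.numpy as jnp
import optax
from flax.core import FrozenDict

from fortuna.prob_model.posterior.state import PosteriorState

params = FrozenDict({"w": jnp.zeros(4)})  # illustrative parameters
optimizer = optax.sgd(1e-2)

# Fresh path: `opt_state` is created by the optimizer itself.
state = PosteriorState.init(params=params, optimizer=optimizer)

# Restore path: no optimizer, so the supplied `opt_state` is reused as-is
# and `tx` on the resulting state is None.
restored = PosteriorState.init(
    params=params,
    optimizer=None,
    opt_state=state.opt_state,
    step=int(state.step),
)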
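
Finally, `init_from_dict` rebuilds a state from a plain dictionary of the kind obtained when a checkpoint is deserialized. The dictionary below is hypothetical; its keys simply mirror the attribute names the method reads, and any extra entries (such as `step`) are forwarded to `init` through `**kwargs`.

import jax.numpy as jnp
import optax

from fortuna.prob_model.posterior.state import PosteriorState

# Hypothetical checkpoint-style dictionary; keys mirror the state's attributes.
d = {
    "params": {"w": jnp.ones(2)},
    "mutable": None,
    "calib_params": None,
    "calib_mutable": None,
    "grad_accumulated": None,
    "dynamic_scale": None,
    "step": 10,
}

state = PosteriorState.init_from_dict(d, optimizer=optax.adam(1e-3))
print(state.step)  # 10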