Source code for fortuna.model.model_manager.state

from __future__ import annotations

from typing import (
    Dict,
    Optional,
    Union,
)

from flax.core import FrozenDict

from fortuna.typing import (
    Mutable,
    Params,
)


[docs]class ModelManagerState: params: Params mutable: Optional[Mutable] = None def __init__(self, params: Params, mutable: Optional[Mutable] = None): """ A model manager state class. Parameters ---------- params : Params The random parameters of the probabilistic model. mutable : Optional[Mutable] The mutable objects used to evaluate the models. """ self.params = params self.mutable = mutable
[docs] @classmethod def init_from_dict(cls, d: Union[Dict, FrozenDict]) -> ModelManagerState: """ Initialize the model manager state from a dictionary. This dictionary should be like the output of :func:`~fortuna.model.model_manager.base.ModelManager.init`. Parameters ---------- d : Union[Dict, FrozenDict] A dictionary like the output of :func:`~fortuna.model.model_manager.base.ModelManager.init`. Returns ------- ModelManagerState An model manager state. """ params = FrozenDict( {k: FrozenDict({"params": v["params"]}) for k, v in d.items()} ) mutable = FrozenDict( { k: FrozenDict({_k: _v for _k, _v in v.items() if _k != "params"}) for k, v in d.items() } ) flag = 0 for k, v in mutable.items(): if len(v) > 0: flag += 1 if flag == 0: mutable = None return cls(params=params, mutable=mutable)