# Source code for fortuna.output_calib_model.state
from __future__ import annotations
from typing import (
Any,
Dict,
Optional,
Union,
)
from flax.core import FrozenDict
import jax.numpy as jnp
from fortuna.training.train_state import TrainState
from fortuna.typing import (
CalibMutable,
CalibParams,
OptaxOptimizer,
)
from fortuna.utils.strings import convert_string_to_tuple
class OutputCalibState(TrainState):
    """
    Train state of an output calibration model.

    Holds the calibration parameters, optional calibration mutable objects,
    and the optimizer-related fields inherited from `TrainState`.
    """

    params: CalibParams
    mutable: Optional[CalibMutable] = None
    encoded_name: tuple = convert_string_to_tuple("OutputCalibState")

    @classmethod
    def init(
        cls,
        params: CalibParams,
        mutable: Optional[CalibMutable] = None,
        optimizer: Optional[OptaxOptimizer] = None,
        **kwargs,
    ) -> Any:
        """
        Initialize an output calibration state.

        Parameters
        ----------
        params : CalibParams
            The calibration parameters.
        mutable : Optional[CalibMutable]
            The calibration mutable objects.
        optimizer : Optional[OptaxOptimizer]
            An Optax optimizer associated with the calibration state. When
            given, the optimizer state is initialized from `params`.
        **kwargs
            Extra fields forwarded to the state constructor. `opt_state` is
            used only when `optimizer` is None; `step` sets the initial step
            (default 0); `apply_fn` and `tx` are ignored, since they are
            always set to None and `optimizer` respectively.

        Returns
        -------
        Any
            A calibration state.
        """
        if optimizer is not None:
            opt_state = optimizer.init(params)
        else:
            # Without an optimizer, keep a previously provided optimizer state
            # (e.g. one carried over by `init_from_dict`), if any.
            opt_state = kwargs.get("opt_state")

        # These keys are set explicitly below; drop them so they are not
        # passed twice to the constructor.
        reserved = ("opt_state", "apply_fn", "tx", "step")
        extra_fields = {k: v for k, v in kwargs.items() if k not in reserved}

        return cls(
            apply_fn=None,
            params=params,
            opt_state=opt_state,
            mutable=mutable,
            step=kwargs.get("step", 0),
            tx=optimizer,
            **extra_fields,
        )

    @classmethod
    def init_from_dict(
        cls,
        d: Union[Dict, FrozenDict],
        optimizer: Optional[OptaxOptimizer] = None,
        **kwargs,
    ) -> OutputCalibState:
        """
        Initialize a calibration state from a dictionary.

        Parameters
        ----------
        d : Union[Dict, FrozenDict]
            A dictionary describing the calibration state. It must contain a
            `params` entry. A `mutable` entry is optional and may be None.
            Every other entry except `optimizer` is forwarded to `init` as a
            keyword argument, taking precedence over an entry of the same
            name in `kwargs`.
        optimizer : Optional[OptaxOptimizer]
            An Optax optimizer to assign to the calibration state.
        **kwargs
            Extra keyword arguments forwarded to `init`.

        Returns
        -------
        OutputCalibState
            A calibration state.
        """
        excluded = ("params", "mutable", "optimizer")
        init_kwargs = {
            **kwargs,
            **{k: v for k, v in d.items() if k not in excluded},
        }
        # `mutable` is optional: wrapping None in a FrozenDict would fail
        # (dict(None) raises TypeError), so keep it as None in that case.
        mutable = d.get("mutable")
        return cls.init(
            FrozenDict(d["params"]),
            FrozenDict(mutable) if mutable is not None else None,
            optimizer,
            **init_kwargs,
        )