# Source code for fortuna.model.model_manager.base

import abc
from typing import (
    Dict,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from flax import linen as nn
from flax.core import FrozenDict
from flax.training.checkpoints import PyTree
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp

from fortuna.typing import (
    InputData,
    Mutable,
    Params,
)
from fortuna.utils.random import WithRNG


class ModelManager(WithRNG, abc.ABC):
    """
    Abstract base class for model managers.

    A model manager orchestrates the forward pass of the models in the
    probabilistic model. Concrete subclasses must implement :meth:`apply`
    and :meth:`init`.
    """

    def __init__(self, model: nn.Module, model_editor: Optional[nn.Module] = None):
        """
        Store the modules handled by this manager.

        Parameters
        ----------
        model : nn.Module
            The model to manage. Kept as ``self.model``.
        model_editor : Optional[nn.Module]
            An optional second module. This base class only keeps a reference
            to it as ``self.model_editor``.
        """
        self.model, self.model_editor = model, model_editor

    @abc.abstractmethod
    def apply(
        self,
        params: Params,
        inputs: InputData,
        mutable: Optional[Mutable] = None,
        train: bool = False,
        rng: Optional[PRNGKeyArray] = None,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]:
        """
        Run the forward pass of the models.

        Parameters
        ----------
        params : Params
            The random parameters of the probabilistic model.
        inputs : InputData
            The input data points.
        mutable : Optional[Mutable]
            The mutable objects used to evaluate the models.
        train : bool
            Whether the call happens during training.
        rng : Optional[PRNGKeyArray]
            A random number generator. When omitted, the one stored in the
            attributes of this class is used.

        Returns
        -------
        Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]
            The model manager output for each input, optionally paired with
            mutable objects.
        """

    @abc.abstractmethod
    def init(
        self,
        input_shape: Tuple[int, ...],
        rng: Optional[PRNGKeyArray] = None,
        **kwargs
    ) -> Dict[str, Mapping]:
        """
        Initialize the random parameters and the mutable objects.

        Parameters
        ----------
        input_shape : Tuple[int, ...]
            The shape of the input variable.
        rng : Optional[PRNGKeyArray]
            A random number generator. When omitted, the one stored in the
            attributes of this class is used.
        **kwargs
            Additional keyword arguments. This abstract declaration does not
            interpret them; their meaning is defined by the implementation.

        Returns
        -------
        Dict[str, Mapping]
            The initialized random parameters and mutable objects.
        """