Source code for fortuna.model.hyper
import flax.linen as nn
import jax.numpy as jnp
from fortuna.typing import Array
class HyperparameterModel(nn.Module):
    r"""
    A hyperparameter model. The value of the hyperparameter will not change during training.

    The value is stored as a module attribute rather than as a trainable parameter, so
    optimizers never update it. Calling the model returns the value repeated once per
    row of the input.

    Parameters
    ----------
    value: Array
        Value of the hyperparameter. It must be one-dimensional, with length equal to the
        output dimension of the model. Scalars (e.g. a plain float) are rejected with a
        ``ValueError``.
    """

    value: Array

    def setup(self) -> None:
        """
        Validate ``value`` and register the module's (empty) parameter.

        Raises
        ------
        ValueError
            If ``value`` is not one-dimensional.
        """
        # `jnp.ndim` accepts Python scalars and array-likes as well as arrays, so a scalar
        # `value` triggers the intended ValueError instead of an AttributeError on `.ndim`.
        if jnp.ndim(self.value) != 1:
            raise ValueError(
                "`value` must be a one-dimensional array, with length equal to the output dimension of "
                "the model."
            )
        # Register a zero-size parameter. The call is kept for its registration side effect
        # (presumably so that initialization yields a "params" collection like other
        # models); the hyperparameter itself is intentionally NOT a parameter.
        self.param("none", nn.initializers.zeros, (0,))

    def __call__(self, x: Array, **kwargs) -> jnp.ndarray:
        """
        Return the hyperparameter value broadcast along the leading axis of ``x``.

        Parameters
        ----------
        x: Array
            Input batch. Only its leading dimension ``x.shape[0]`` is used.
        kwargs
            Accepted for interface compatibility and ignored.

        Returns
        -------
        jnp.ndarray
            Array of shape ``(x.shape[0], len(value))`` in which every row equals ``value``.
        """
        # Convert once so array-like values (e.g. lists) broadcast correctly; this is a
        # no-op for values that are already arrays.
        value = jnp.asarray(self.value)
        return jnp.broadcast_to(value, shape=(x.shape[0], value.shape[0]))