Source code for fortuna.model.scalar_hyper

import flax.linen as nn
import jax.numpy as jnp

from fortuna.typing import Array


[docs]class ScalarHyperparameterModel(nn.Module): r""" A scalar hyperparameter model. The scalar value of the hyperparameter will not change during training, and it will be broadcasted to the output dimension. Parameters ---------- output_dim: int The output model dimension. value: float Scalar value of the hyperparameter. """ output_dim: int value: float def setup(self) -> None: if type(self.value) != float: raise ValueError( f"`value` must be a float, but a {type(self.value)} was found instead." ) dummy = self.param("none", nn.initializers.zeros, (0,)) def __call__(self, x: Array, **kwargs) -> jnp.ndarray: return jnp.broadcast_to(self.value, shape=(x.shape[0], self.output_dim))