# Source code for fortuna.model.constant

from typing import Optional

import flax.linen as nn
from flax.linen.initializers import Initializer
import jax.numpy as jnp

from fortuna.typing import Array


class ConstantModel(nn.Module):
    r"""
    A constant model, that is :math:`f(\theta, x) = \theta`.

    The single learnable parameter is registered under the name
    ``"constant"`` with shape ``(output_dim,)``. The input does not affect
    the output values; only its leading dimension is read, to decide how
    many rows the parameter is broadcast to.

    Parameters
    ----------
    output_dim: int
        The output model dimension.
    initializer_fun: Optional[Initializer]
        Function to initialize the model parameters. This must be one of the
        available options in :code:`flax.linen.initializers`. If :code:`None`
        is passed, zero initialization (the default) is used.
    """

    output_dim: int
    initializer_fun: Optional[Initializer] = nn.initializers.zeros

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> jnp.ndarray:
        """
        Return the constant parameter broadcast along the leading axis of `x`.

        Parameters
        ----------
        x: Array
            Input array. Only :code:`x.shape[0]` is used.
        **kwargs
            Accepted but not used by this model.

        Returns
        -------
        jnp.ndarray
            An array of shape :code:`(x.shape[0], output_dim)` in which every
            row equals the ``"constant"`` parameter.
        """
        # The annotation allows None, but `self.param` has to call the
        # initializer to create the parameter; fall back to the declared
        # default instead of failing with "NoneType is not callable".
        initializer = (
            nn.initializers.zeros
            if self.initializer_fun is None
            else self.initializer_fun
        )
        constant = self.param("constant", initializer, (self.output_dim,))
        return jnp.broadcast_to(constant, shape=(x.shape[0], self.output_dim))