Source code for fortuna.model_editor.base
from typing import (
Any,
Callable,
Dict,
Optional,
Tuple,
Union,
)
import flax.linen as nn
import jax.numpy as jnp
from fortuna.typing import (
InputData,
Mutable,
)
[docs]class ModelEditor(nn.Module):
@nn.compact
def __call__(
self,
apply_fn: Callable[
[Dict, InputData], Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]
],
model_params: Dict,
x: Any,
has_aux: bool,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]:
"""
Apply a transformation to the forward pass.
Parameters
----------
apply_fn: Callable[[Dict, InputData], Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]]
The model forward pass.
model_params: Dict
The model parameters.
x: Array
Batch of inputs.
has_aux: bool
Whether the forward pass returns auxiliary objects.
Returns
-------
Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]
Return the transformed outputs, and auxiliary objects if available.
"""
pass