Flax Linen Module

Base class for all neural network modules.

Layers and models should subclass this class.

All Flax Modules are Python 3.7 dataclasses, so a module's configuration is declared as class-level fields. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they were functions:

>>> from flax import linen as nn
>>> from typing import Tuple

>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)
...
...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])
...
...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))

Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.