Flax Linen Module
Base class for all neural network modules.
Layers and models should subclass this class.
All Flax Modules are Python 3.7
dataclasses. Since
dataclasses take over __init__, you should instead override setup(),
which is automatically called to initialize the module.
Modules can contain submodules, and in this way can be nested in a tree
structure. Submodules can be assigned as regular attributes inside the
setup() method.
You can define arbitrary “forward pass” methods on your Module subclass.
While no methods are special-cased, __call__ is a popular choice because
it allows you to use module instances as if they were functions:
>>> from flax import linen as nn
>>> from typing import Tuple
>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)
...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])
...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))
Optionally, for more concise module implementations where submodule
definitions are co-located with their usage, you can use the
compact() wrapper.