Flax Linen Module
Base class for all neural network modules. Layers and models should subclass this class.
All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.
Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.
You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they were functions:
```python
from typing import Tuple

from flax import linen as nn

class Module(nn.Module):
  features: Tuple[int, ...] = (16, 4)

  def setup(self):
    # Submodules assigned as attributes; their names are 'dense1' and 'dense2'.
    self.dense1 = nn.Dense(self.features[0])
    self.dense2 = nn.Dense(self.features[1])

  def __call__(self, x):
    return self.dense2(nn.relu(self.dense1(x)))
```
Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.