Source code for fortuna.model.lenet
from typing import Any
import flax.linen as nn
import jax.numpy as jnp
from fortuna.typing import Array
class LeNet5(nn.Module):
    """
    A LeNet-5 network [LeCun et al., 1989](http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf). Please refer to
    :class:`~fortuna.prob_model.model.base.Model` for the internal methods.

    The network is the composition of two sub-networks: a deep feature
    extractor followed by an output sub-network.

    Attributes
    ----------
    output_dim: int
        The output model dimension.
    dtype: Any
        Layers' dtype.
    """

    output_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        """Instantiate the two sub-networks, passing this module's ``dtype`` to both."""
        self.dfe_subnet = LeNet5DeepFeatureExtractorSubNet(dtype=self.dtype)
        self.output_subnet = LeNet5OutputSubNet(
            output_dim=self.output_dim, dtype=self.dtype
        )

    def __call__(self, x: Array, **kwargs) -> jnp.ndarray:
        """
        Forward pass.

        Parameters
        ----------
        x: Array
            Inputs.
        **kwargs
            Extra keyword arguments are accepted but ignored by this forward
            pass.

        Returns
        -------
        jnp.ndarray
            Model outputs.
        """
        # Features first, then the output head applied to those features.
        x = self.dfe_subnet(x)
        x = self.output_subnet(x)
        return x
class LeNet5DeepFeatureExtractorSubNet(nn.Module):
    """
    Deep feature extractor sub-network of a LeNet-5.

    Two convolutional stages (5x5 kernel, stride 1, "valid" padding, ReLU,
    2x2 max-pooling with stride 2) with 6 and 16 features respectively, then a
    flatten and two dense layers of 120 and 84 features. No activation is
    applied after the last dense layer.

    Attributes
    ----------
    dtype: Any
        Layers' dtype.
    """

    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: Array) -> jnp.ndarray:
        """
        Forward pass.

        Parameters
        ----------
        x: Array
            Inputs.

        Returns
        -------
        jnp.ndarray
            Output of the hidden layers.
        """
        # NOTE(review): presumably `x` is a channels-last image batch, as
        # flax convolutions expect -- confirm against callers.

        # First convolutional stage: 6 feature maps.
        x = nn.Conv(
            features=6,
            kernel_size=(5, 5),
            strides=(1, 1),
            padding="valid",
            dtype=self.dtype,
        )(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Second convolutional stage: 16 feature maps.
        x = nn.Conv(
            features=16,
            kernel_size=(5, 5),
            strides=(1, 1),
            padding="valid",
            dtype=self.dtype,
        )(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Keep the leading axis and collapse every remaining axis into one.
        x = x.reshape((x.shape[0], -1))

        # Dense layers; the 84-feature output is returned without activation.
        x = nn.Dense(features=120, dtype=self.dtype)(x)
        x = nn.relu(x)
        x = nn.Dense(features=84, dtype=self.dtype)(x)
        return x
class LeNet5OutputSubNet(nn.Module):
    """
    Output sub-network of a LeNet-5: a ReLU followed by a single dense layer.

    Attributes
    ----------
    output_dim: int
        Number of features produced by the dense layer.
    dtype: Any
        dtype passed to the dense layer.
    """

    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """
        Forward pass.

        Parameters
        ----------
        x: jnp.ndarray
            Outputs of the deep feature extractor sub-network.

        Returns
        -------
        jnp.ndarray
            Model outputs.
        """
        # The activation is applied here, before the final projection.
        activated = nn.relu(x)
        projection = nn.Dense(features=self.output_dim, dtype=self.dtype)
        return projection(activated)