# Source code for fortuna.model.wideresnet

"""
Wide ResNet model
(adapted from https://github.com/google/flax/blob/v0.2/examples/cifar10/models/wideresnet.py)
"""
from functools import partial
from typing import (
    Any,
    Callable,
    Tuple,
)

import flax.linen as nn
import jax.numpy as jnp

from fortuna.model.utils.spectral_norm import WithSpectralConv2DNorm
from fortuna.typing import Array

# Loose alias used to annotate attributes that hold a module constructor
# (in this file: convolution, normalization and dropout modules, possibly
# wrapped in ``functools.partial``). No stricter type is enforced.
ModuleDef = Any


class WideResnetBlock(nn.Module):
    """
    A wide residual network block.

    The block runs two stages of normalization -> activation -> (optional)
    dropout -> 3x3 convolution and adds the result to a shortcut branch.

    Attributes
    ----------
    conv: ModuleDef
        Convolution module.
    norm: ModuleDef
        Normalization module.
    activation: Callable
        Activation function applied after each normalization layer.
    filters: int
        Number of filters.
    strides: Tuple[int, int]
        Strides of the first convolution (and of the shortcut projection).
    dropout: ModuleDef
        Dropout module.
    dropout_rate: float
        Dropout rate. Dropout is only applied when this is strictly positive.
    """

    conv: ModuleDef
    norm: ModuleDef
    activation: Callable
    filters: int
    strides: Tuple[int, int] = (1, 1)
    dropout: ModuleDef = nn.Dropout
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        """
        Block forward pass.

        Parameters
        ----------
        x: jnp.ndarray
            Block inputs.
        train: bool
            Whether the call is performed during training.

        Returns
        -------
        jnp.ndarray
            Block outputs.
        """
        # A single dropout module is shared by both stages; its mask is
        # broadcast over axes 1 and 2 (presumably the spatial axes of an NHWC
        # input -- confirm against the caller).
        dropout = self.dropout(rate=self.dropout_rate, broadcast_dims=(1, 2))

        y = self.norm(name="bn1")(x)
        # Use the configured activation rather than a hard-coded ReLU, so the
        # `activation` attribute actually takes effect.
        y = self.activation(y)
        if self.dropout_rate > 0.0:
            y = dropout(y, deterministic=not train)
        y = self.conv(self.filters, (3, 3), self.strides, name="conv1")(y)
        y = self.norm(name="bn2")(y)
        y = self.activation(y)
        if self.dropout_rate > 0.0:
            y = dropout(y, deterministic=not train)
        y = self.conv(self.filters, (3, 3), name="conv2")(y)

        # Project the shortcut whenever its channel count or the stride would
        # otherwise make it incompatible with the residual branch.
        if (x.shape[-1] != self.filters) or self.strides != (1, 1):
            x = self.conv(self.filters, (3, 3), self.strides)(x)
        return x + y


class WideResnetGroup(nn.Module):
    """
    A wide residual network group, i.e. a sequence of residual blocks.

    Attributes
    ----------
    conv: ModuleDef
        Convolution module.
    norm: ModuleDef
        Normalization module.
    activation: Callable
        Activation function.
    blocks_per_group: int
        Number of blocks per group.
    filters: int
        Number of filters of every block in the group.
    strides: Tuple[int, int]
        Strides of the first block; all following blocks use strides (1, 1).
    dropout: ModuleDef
        Dropout module.
    dropout_rate: float
        Dropout rate.
    """

    conv: ModuleDef
    norm: ModuleDef
    activation: Callable
    blocks_per_group: int
    filters: int
    strides: Tuple[int, int] = (1, 1)
    dropout: ModuleDef = nn.Dropout
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        """
        Group forward pass.

        Parameters
        ----------
        x: jnp.ndarray
            Group inputs.
        train: bool
            Whether the call is performed during training.

        Returns
        -------
        jnp.ndarray
            Group outputs.
        """
        # Settings shared by every block of the group.
        shared_kwargs = dict(
            conv=self.conv,
            norm=self.norm,
            activation=self.activation,
            filters=self.filters,
            dropout=self.dropout,
            dropout_rate=self.dropout_rate,
        )
        out = x
        for block_index in range(self.blocks_per_group):
            # Only the first block of the group uses the configured strides.
            block_strides = (1, 1) if block_index > 0 else self.strides
            block = WideResnetBlock(strides=block_strides, **shared_kwargs)
            out = block(out, train=train)
        return out


class DeepFeatureExtractorSubNet(nn.Module):
    """
    Deep feature extractor subnetwork.

    An initial 3x3 convolution is followed by three wide residual groups with
    16, 32 and 64 times `widen_factor` filters, a final normalization and
    activation, average pooling and flattening.

    Attributes
    ----------
    depth: int
        Depth of the subnetwork. Each group contains `(depth - 4) // 6` blocks.
    widen_factor: int
        Widening factor.
    dropout_rate: float
        Dropout rate. Dropout is only applied when this is strictly positive.
    dtype: Any
        Layers' dtype.
    activation: Callable
        Activation function.
    conv: ModuleDef
        Convolution module.
    dropout: ModuleDef
        Dropout module.
    """

    depth: int = 28
    widen_factor: int = 10
    dropout_rate: float = 0.0
    dtype: Any = jnp.float32
    activation: Callable = nn.relu
    conv: ModuleDef = nn.Conv
    dropout: ModuleDef = nn.Dropout

    @nn.compact
    def __call__(self, x: Array, train: bool = True) -> jnp.ndarray:
        """
        Deep feature extractor subnetwork forward pass.

        Parameters
        ----------
        x: Array
            Input data.
        train: bool
            Whether the call is performed during training.

        Returns
        -------
        jnp.ndarray
            Deep feature extractor representation.
        """
        # If the instance exposes a `spectral_norm` attribute (presumably
        # supplied by a mixin class), the convolution module is wrapped by it.
        if hasattr(self, "spectral_norm"):
            conv = self.spectral_norm(self.conv, train=train)
        else:
            conv = self.conv

        blocks_per_group = (self.depth - 4) // 6

        dropout = self.dropout(rate=self.dropout_rate, broadcast_dims=(1, 2))
        conv = partial(conv, use_bias=False, dtype=self.dtype)
        norm = partial(
            nn.BatchNorm,
            use_running_average=not train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=self.dtype,
        )

        x = conv(16, (3, 3), name="init_conv")(x)
        if self.dropout_rate > 0.0:
            x = dropout(x, deterministic=not train)

        # (base number of filters, strides) of the three consecutive groups.
        group_configs = ((16, (1, 1)), (32, (2, 2)), (64, (2, 2)))
        for base_filters, group_strides in group_configs:
            x = WideResnetGroup(
                conv=conv,
                norm=norm,
                activation=self.activation,
                blocks_per_group=blocks_per_group,
                filters=base_filters * self.widen_factor,
                strides=group_strides,
                dropout=self.dropout,
                dropout_rate=self.dropout_rate,
            )(x, train=train)

        x = norm()(x)
        # Use the configured activation rather than a hard-coded ReLU, so the
        # `activation` attribute actually takes effect.
        x = self.activation(x)
        # Fixed 8x8 pooling window (presumably sized for 32x32 inputs, which
        # the two stride-2 groups would reduce to 8x8 -- confirm).
        x = nn.avg_pool(x, (8, 8))
        # Flatten everything except the leading (batch) axis.
        x = x.reshape((x.shape[0], -1))
        return x


class OutputSubNet(nn.Module):
    """
    Output subnetwork, made of a single dense layer.

    Attributes
    ----------
    output_dim: int
        Output dimension.
    dtype: Any
        Layers' dtype.
    """

    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        """
        Output subnetwork forward pass.

        Parameters
        ----------
        x: jnp.ndarray
            Subnetwork inputs.
        train: bool
            Whether the call is performed during training. It is not used by
            this subnetwork.

        Returns
        -------
        jnp.ndarray
            Outputs.
        """
        dense_layer = nn.Dense(self.output_dim, dtype=self.dtype)
        return dense_layer(x)


class WideResNet(nn.Module):
    """
    Wide residual network class.

    The network chains a deep feature extractor subnetwork and an output
    subnetwork.

    Attributes
    ----------
    output_dim: int
        Output dimension.
    depth: int
        Depth of the subnetwork.
    widen_factor: int
        Widening factor.
    dropout_rate: float
        Dropout rate.
    dtype: Any
        Layers' dtype.
    activation: Callable
        Activation function.
    conv: ModuleDef
        Convolution module.
    dropout: ModuleDef
        Dropout module.
    """

    output_dim: int
    depth: int = 28
    widen_factor: int = 10
    dropout_rate: float = 0.0
    dtype: Any = jnp.float32
    activation: Callable = nn.relu
    conv: ModuleDef = nn.Conv
    dropout: ModuleDef = nn.Dropout

    def setup(self):
        """Instantiate the deep feature extractor and the output subnetworks."""
        self.dfe_subnet = DeepFeatureExtractorSubNet(
            depth=self.depth,
            widen_factor=self.widen_factor,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            activation=self.activation,
            conv=self.conv,
            dropout=self.dropout,
        )
        self.output_subnet = OutputSubNet(output_dim=self.output_dim, dtype=self.dtype)

    def __call__(self, x: Array, train: bool = True) -> jnp.ndarray:
        """
        Forward pass.

        Parameters
        ----------
        x: Array
            Input data.
        train: bool
            Whether the call is performed during training.

        Returns
        -------
        jnp.ndarray
            Outputs.
        """
        x = self.dfe_subnet(x, train)
        x = self.output_subnet(x, train)
        return x
# Wide ResNet constructor with depth 28 and widening factor 10 pre-filled.
WideResNet28_10 = partial(WideResNet, depth=28, widen_factor=10)


class WideResNetDeepFeatureExtractorSubNetWithSN(
    WithSpectralConv2DNorm, DeepFeatureExtractorSubNet
):
    """
    Deep feature extractor subnetwork combined with `WithSpectralConv2DNorm`.

    `DeepFeatureExtractorSubNet` wraps its convolution module with
    `self.spectral_norm` whenever that attribute exists; the
    `WithSpectralConv2DNorm` base class is expected to provide it.
    """

    pass


# define the feature extractors with spectral norm
WideResNetD28W10DeepFeatureExtractorSubNetWithSN = partial(
    WideResNetDeepFeatureExtractorSubNetWithSN, depth=28, widen_factor=10
)