# Source code for fortuna.model.utils.spectral_norm

"""
The code has been taken from https://github.com/google/edward2/blob/main/edward2/jax/nn/normalization.py
"""
import dataclasses
from typing import (
    Any,
    Callable,
    Mapping,
    Optional,
    Tuple,
    Type,
)

import flax.core
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKeyArray
import numpy as np

from fortuna.typing import (
    Array,
    Shape,
)


def _l2_normalize(x: Array, eps: float = 1e-12) -> Array:
    return x * jax.lax.rsqrt(jnp.maximum(jnp.square(x).sum(), eps))


class SpectralNormalization(nn.Module):
    """
    Implements spectral normalization for linear layers.
    See `Spectral Normalization for Generative Adversarial Networks <https://arxiv.org/abs/1802.05957>`_ .

    Attributes
    ----------
    layer: nn.Module
        A Flax layer to apply normalization to.
    iteration: int
        The number of power iterations to estimate weight matrix's singular value.
    norm_multiplier: float
        Multiplicative constant to threshold the normalization. Usually under normalization, the singular value
        will converge to this value.
    u_init: Callable[[PRNGKeyArray, Shape, Type], Array]
        Initializer function for the first left singular vectors of the kernel.
    v_init: Callable[[PRNGKeyArray, Shape, Type], Array]
        Initializer function for the first right singular vectors of the kernel.
    kernel_apply_kwargs: Optional[Mapping[str, Any]]
        Updated keyword arguments to clone the input layer. The cloned layer represents the linear operator
        performed by the weight matrix. If not specified, that operator follows SN-GAN implementation
        (`Takeru M. et al <https://arxiv.org/abs/1802.05957>`_). In particular, for Dense layers the default
        behavior is equivalent to using a cloned layer with no bias (by specifying
        `kernel_apply_kwargs=dict(use_bias=False)`). With this customization, we can have the same implementation
        (inspired by `Stephan H., 2020 <https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc>`_)
        for different interpretations of Conv layers. Also see `SpectralNormalizationConv2D` for an example of
        using this attribute.
    kernel_name: str
        Name of the kernel parameter of the input layer.
    layer_name: Optional[str]
        Name of the input layer.
    update_singular_value_estimate: Optional[bool]
        Whether to perform power iterations to update the singular value estimate.
    """

    layer: nn.Module
    iteration: int = 1
    norm_multiplier: float = 0.95
    u_init: Callable[[PRNGKeyArray, Shape, Type], Array] = nn.initializers.normal(
        stddev=0.05
    )
    v_init: Callable[[PRNGKeyArray, Shape, Type], Array] = nn.initializers.normal(
        stddev=0.05
    )
    kernel_apply_kwargs: Optional[Mapping[str, Any]] = None
    kernel_name: str = "kernel"
    layer_name: Optional[str] = None
    update_singular_value_estimate: Optional[bool] = None

    def _get_singular_vectors(
        self, initializing: bool, kernel_apply: Callable, in_shape: Shape, dtype: Type
    ) -> Tuple[nn.Variable, nn.Variable]:
        """
        Declare the `u` and `v` variables of the `spectral_stats` collection.

        Parameters
        ----------
        initializing: bool
            Whether the module is currently being initialized. Only then are
            random keys drawn and the output shape of `kernel_apply` computed.
        kernel_apply: Callable
            The linear operator represented by the kernel.
        in_shape: Shape
            Shape of the input accepted by `kernel_apply`; also the shape of `v`.
        dtype: Type
            Data type of the singular vectors.

        Returns
        -------
        Tuple[nn.Variable, nn.Variable]
            The variables holding the left (`u`) and right (`v`) singular
            vector estimates.
        """
        if initializing:
            rng_u = self.make_rng("params")
            rng_v = self.make_rng("params")
            # Interpret output shape (note that this does not cost any FLOPs).
            out_shape = jax.eval_shape(
                kernel_apply, jax.ShapeDtypeStruct(in_shape, dtype)
            ).shape
        else:
            # Placeholders only: presumably the variables already exist at this
            # point and the initializers are not invoked again.
            rng_u = rng_v = out_shape = None
        u = self.variable("spectral_stats", "u", self.u_init, rng_u, out_shape, dtype)
        v = self.variable("spectral_stats", "v", self.v_init, rng_v, in_shape, dtype)
        return u, v

    @nn.compact
    def __call__(
        self, inputs: Array, update_singular_value_estimate: Optional[bool] = None
    ) -> Array:
        """
        Applies a linear transformation with spectral normalization to the inputs.

        Parameters
        ----------
        inputs: Array
            The nd-array to be transformed.
        update_singular_value_estimate: Optional[bool]
            Whether to perform power iterations to update the singular value estimate.

        Returns
        -------
        Array
            The transformed input.
        """
        # Resolve the flag from the module attribute and the call argument.
        update_singular_value_estimate = nn.merge_param(
            "update_singular_value_estimate",
            self.update_singular_value_estimate,
            update_singular_value_estimate,
        )
        layer_name = (
            type(self.layer).__name__ if self.layer_name is None else self.layer_name
        )
        # The wrapped layer's parameters are registered as a single parameter
        # entry of this module, initialized through the layer's own `init`.
        params = self.param(
            layer_name, lambda *args: self.layer.init(*args)["params"], inputs
        )
        w = params[self.kernel_name]

        if self.kernel_apply_kwargs is None:
            # By default, we use the implementation in SN-GAN: the kernel is
            # flattened to a matrix whose columns are indexed by its last axis.
            kernel_apply = lambda x: x @ w.reshape(-1, w.shape[-1])
            in_shape = (np.prod(w.shape[:-1]),)
        else:
            # Otherwise, we extract the actual kernel transformation in the input
            # layer. This is useful for Conv2D spectral normalization in
            # [Farzan F. et al., 2019](https://arxiv.org/abs/1811.07457).
            kernel_apply = self.layer.clone(
                **self.kernel_apply_kwargs
            ).bind(  # pylint: disable=not-a-mapping
                {"params": {self.kernel_name: w}}
            )
            # Compute input shape of the kernel operator. This is correct for all
            # linear layers on Flax: Dense, Conv, Embed.
            in_shape = inputs.shape[-w.ndim + 1 : -1] + w.shape[-2:-1]

        initializing = self.is_mutable_collection("params")
        u, v = self._get_singular_vectors(initializing, kernel_apply, in_shape, w.dtype)
        u_hat, v_hat = u.value, v.value
        # `u_` is the operator applied to the current `v_hat`; `kernel_transpose`
        # applies the transposed operator through the VJP.
        u_, kernel_transpose = jax.vjp(kernel_apply, v_hat)

        if update_singular_value_estimate and not initializing:
            # Run power iterations using autodiff approach inspired by
            # [Stephan H., 2020](https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc).
            def scan_body(carry, _):
                u_hat, v_hat, u_ = carry
                (v_,) = kernel_transpose(u_hat)
                v_hat = _l2_normalize(v_)
                u_ = kernel_apply(v_hat)
                u_hat = _l2_normalize(u_)
                return (u_hat, v_hat, u_), None

            (u_hat, v_hat, u_), _ = jax.lax.scan(
                scan_body, (u_hat, v_hat, u_), None, length=self.iteration
            )
            # Persist the refined singular vector estimates.
            u.value, v.value = u_hat, v_hat

        # Singular value estimate: inner product of `u_hat` with the operator
        # applied to `v_hat`.
        sigma = jnp.vdot(u_hat, u_)
        # Bound spectral norm by the `norm_multiplier`: the kernel is only scaled
        # down when the estimate exceeds `norm_multiplier`.
        sigma = jnp.maximum(sigma / self.norm_multiplier, 1.0)
        # No gradient flows through the normalization constant.
        w_hat = w / jax.lax.stop_gradient(sigma)
        self.sow("intermediates", "w", w_hat)

        # Update params.
        params = flax.core.unfreeze(params)
        params[self.kernel_name] = w_hat
        layer_params = flax.core.freeze({"params": params})
        return self.layer.apply(layer_params, inputs)
class SpectralNormalizationConv2D(SpectralNormalization):
    # The docstring is assembled at class-creation time: a Conv-specific summary
    # followed by the parent's docstring with its first line dropped.
    __doc__ = (
        "Implements spectral normalization for Convolutional layers. "
        "See `Generalizable Adversarial Training via Spectral Normalization <https://arxiv.org/abs/1811.07457>`_.\n"
        + "\n".join(SpectralNormalization.__doc__.split("\n")[1:])
    )
    # Keyword arguments used to clone the wrapped layer into the operator whose
    # spectral norm is estimated.
    kernel_apply_kwargs: Mapping[str, Any] = flax.core.FrozenDict(
        feature_group_count=1, padding="SAME", use_bias=False
    )
@dataclasses.dataclass
class WithSpectralConv2DNorm:
    """
    Holds spectral normalization settings and wraps layer constructors with
    :class:`SpectralNormalizationConv2D`.

    Attributes
    ----------
    spectral_norm_iteration: int
        The number of power iterations to estimate weight matrix's singular value.
    spectral_norm_bound: float
        Multiplicative constant to threshold the normalization. Usually under normalization, the singular value
        will converge to this value.
    """

    # An integer default: this value is forwarded as the power-iteration count.
    spectral_norm_iteration: int = 1
    spectral_norm_bound: float = 0.95

    def spectral_norm(self, layer: Type[nn.Module], train: bool = False) -> Callable:
        """
        Build a constructor for `layer` wrapped in spectral normalization.

        Parameters
        ----------
        layer: Type[nn.Module]
            The layer class to wrap.
        train: bool
            Passed on as `update_singular_value_estimate`.

        Returns
        -------
        Callable
            A callable that forwards its arguments to `layer` and returns the
            resulting instance wrapped in :class:`SpectralNormalizationConv2D`.
        """
        return lambda *args, **kwargs: SpectralNormalizationConv2D(
            layer=layer(*args, **kwargs),
            iteration=self.spectral_norm_iteration,
            norm_multiplier=self.spectral_norm_bound,
            update_singular_value_estimate=train,
        )


@dataclasses.dataclass
class WithSpectralNorm:
    """
    Holds spectral normalization settings and wraps layer constructors with
    :class:`SpectralNormalization`.

    Attributes
    ----------
    spectral_norm_iteration: int
        The number of power iterations to estimate weight matrix's singular value.
    spectral_norm_bound: float
        Multiplicative constant to threshold the normalization. Usually under normalization, the singular value
        will converge to this value.
    """

    # An integer default: this value is forwarded as the power-iteration count.
    spectral_norm_iteration: int = 1
    spectral_norm_bound: float = 0.95

    def spectral_norm(self, layer: Type[nn.Module], train: bool = False) -> Callable:
        """
        Build a constructor for `layer` wrapped in spectral normalization.

        Parameters
        ----------
        layer: Type[nn.Module]
            The layer class to wrap.
        train: bool
            Passed on as `update_singular_value_estimate`.

        Returns
        -------
        Callable
            A callable that forwards its arguments to `layer` and returns the
            resulting instance wrapped in :class:`SpectralNormalization`.
        """
        return lambda *args, **kwargs: SpectralNormalization(
            layer=layer(*args, **kwargs),
            iteration=self.spectral_norm_iteration,
            norm_multiplier=self.spectral_norm_bound,
            update_singular_value_estimate=train,
        )