WideResNet
- class fortuna.model.wideresnet.WideResNet(output_dim, depth=28, widen_factor=10, dropout_rate=0.0, dtype=jax.numpy.float32, activation=<jax._src.custom_derivatives.custom_jvp object>, conv=flax.linen.Conv, dropout=flax.linen.Dropout, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: flax.linen.Module
Wide residual network class.
- output_dim
Output dimension of the network (e.g. the number of classes in a classification task).
- Type: int
- depth
Depth of the subnetwork.
- Type: int
- widen_factor
Widening factor, i.e. the multiplier applied to the number of channels in the convolutional layers.
- Type: int
- dropout_rate
Dropout rate, i.e. the probability of dropping a unit; 0.0 disables dropout.
- Type: float
- dtype
Data type used by the layers.
- Type: Any
- activation
Activation function.
- Type: Callable
- conv
Convolution module.
- Type: ModuleDef
- dropout
Dropout module.
- Type: ModuleDef
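The `depth` and `widen_factor` attributes follow the naming of wide residual networks (Zagoruyko & Komodakis, 2016), where a WRN-d-k network has (d - 4) / 6 residual blocks in each of its three stages, and the stages have 16k, 32k and 64k channels after a 16-channel stem convolution. The plain-Python sketch below computes that layout. It assumes Fortuna follows the original paper's convention; `wrn_layout` is a hypothetical helper for illustration, not part of Fortuna's API.

```python
from typing import List, Tuple


def wrn_layout(depth: int, widen_factor: int) -> Tuple[int, List[int]]:
    """Return (blocks per stage, channel widths) for a standard WRN-depth-k.

    Hypothetical helper: assumes the original WRN convention, not Fortuna's code.
    """
    # Each stage stacks n basic blocks of two convolutions: depth = 6n + 4.
    if (depth - 4) % 6 != 0:
        raise ValueError(f"depth must satisfy (depth - 4) % 6 == 0, got {depth}")
    blocks_per_stage = (depth - 4) // 6
    # 16-channel stem, then three stages widened by widen_factor.
    widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    return blocks_per_stage, widths


print(wrn_layout(28, 10))  # (4, [16, 160, 320, 640])
```

With the defaults (`depth=28`, `widen_factor=10`) this gives the WRN-28-10 configuration: four blocks per stage, with the widest stage at 640 channels.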
- fortuna.model.wideresnet.WideResNet28_10
alias of functools.partial(<class 'fortuna.model.wideresnet.WideResNet'>, depth=28, widen_factor=10)
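`WideResNet28_10` is a `functools.partial` that pre-binds `depth=28` and `widen_factor=10`, which coincide with the defaults in the `WideResNet` signature. The sketch below illustrates how such an alias behaves, using a hypothetical stand-in dataclass `WideResNetStub` in place of Fortuna's class (which is a Flax module) so that it runs without Fortuna or JAX installed.

```python
import functools
from dataclasses import dataclass


# Hypothetical stand-in mirroring some documented WideResNet fields;
# it is NOT Fortuna's class, only a plain dataclass for illustration.
@dataclass
class WideResNetStub:
    output_dim: int
    depth: int = 28
    widen_factor: int = 10
    dropout_rate: float = 0.0


# Same construction pattern as the documented alias.
WideResNet28_10 = functools.partial(WideResNetStub, depth=28, widen_factor=10)

model = WideResNet28_10(output_dim=10)
print(model)  # WideResNetStub(output_dim=10, depth=28, widen_factor=10, dropout_rate=0.0)

# Keywords bound by partial can still be overridden at call time.
shallow = WideResNet28_10(output_dim=10, depth=16)
print(shallow.depth)  # 16
```

Because `functools.partial` merges call-time keyword arguments over the bound ones, passing `depth` or `widen_factor` explicitly overrides the alias's value rather than raising an error.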