WideResNet

class fortuna.model.wideresnet.WideResNet(output_dim, depth=28, widen_factor=10, dropout_rate=0.0, dtype=jax.numpy.float32, activation=<jax._src.custom_derivatives.custom_jvp object>, conv=flax.linen.linear.Conv, dropout=flax.linen.stochastic.Dropout, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: flax.linen.Module

Wide residual network class.

output_dim

Output dimension.

Type: int

depth

Depth of the subnetwork.

Type: int

widen_factor

Widening factor.

Type: int

dropout_rate

Dropout rate.

Type: float

dtype

Layers’ dtype.

Type: Any

activation

Activation function.

Type: Callable

conv

Convolution module.

Type: ModuleDef

dropout

Dropout module.

Type: ModuleDef

fortuna.model.wideresnet.WideResNet28_10

alias of functools.partial(fortuna.model.wideresnet.WideResNet, depth=28, widen_factor=10)
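Because `WideResNet28_10` is a `functools.partial`, it still requires `output_dim` and accepts every other constructor keyword; keywords passed at call time, including `depth` or `widen_factor`, override the pre-bound ones. The pure-Python sketch below shows that behaviour with a hypothetical `WideResNetStub` dataclass standing in for the real Flax module. Its `layout` method assumes the standard WRN-d-k rule from the original Wide ResNet paper (Zagoruyko and Komodakis, 2016), depth = 6n + 4 with groups of 16k, 32k and 64k channels; it is an illustration of that convention, not code taken from Fortuna.

```python
import dataclasses
import functools
from typing import List, Tuple


@dataclasses.dataclass
class WideResNetStub:
    """Hypothetical stand-in for WideResNet: same names and defaults, no layers."""

    output_dim: int
    depth: int = 28
    widen_factor: int = 10
    dropout_rate: float = 0.0

    def layout(self) -> Tuple[int, List[int]]:
        """Blocks per group and channel widths, assuming the standard WRN-d-k rule."""
        # Standard WRN convention: depth = 6 * n + 4, with n blocks in each of 3 groups.
        if (self.depth - 4) % 6 != 0:
            raise ValueError(f"depth must equal 6n + 4, got {self.depth}")
        n = (self.depth - 4) // 6
        # A 16-channel stem, then groups of 16k, 32k and 64k channels.
        k = self.widen_factor
        return n, [16, 16 * k, 32 * k, 64 * k]


# Built the same way as the documented alias: depth and widen_factor are pre-bound.
WideResNetStub28_10 = functools.partial(WideResNetStub, depth=28, widen_factor=10)

model = WideResNetStub28_10(output_dim=10, dropout_rate=0.3)
print(model)           # WideResNetStub(output_dim=10, depth=28, widen_factor=10, dropout_rate=0.3)
print(model.layout())  # (4, [16, 160, 320, 640])
```

With the defaults shown in the signature (`depth=28`, `widen_factor=10`), `WideResNet28_10(output_dim=...)` and `WideResNet(output_dim=...)` describe the same configuration; the alias makes the WRN-28-10 choice explicit.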