ResNet

class fortuna.model.resnet.ResNet(stage_sizes, block_cls, output_dim, num_filters=64, dtype=<class 'jax.numpy.float32'>, activation=<jax._src.custom_derivatives.custom_jvp object>, conv=<class 'flax.linen.linear.Conv'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

Deep feature extractor subnetwork.

stage_sizes

Number of blocks in each stage, e.g. [2, 2, 2, 2] for ResNet18.

Type:

Sequence[int]

block_cls

Block module class used to build each stage, e.g. ResNetBlock or BottleneckResNetBlock.

Type:

ModuleDef

output_dim

Dimension of the model output.

Type:

int

num_filters

Number of convolution filters. Defaults to 64.

Type:

int

dtype

Data type used by the layers. Defaults to jax.numpy.float32.

Type:

Any

activation

Activation function used by the network.

Type:

Callable

conv

Convolution module class. Defaults to flax.linen.linear.Conv.

Type:

ModuleDef

fortuna.model.resnet.ResNet18

alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[2, 2, 2, 2], block_cls=<class 'fortuna.model.resnet.ResNetBlock'>)

fortuna.model.resnet.ResNet34

alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 4, 6, 3], block_cls=<class 'fortuna.model.resnet.ResNetBlock'>)

fortuna.model.resnet.ResNet50

alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 4, 6, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)

fortuna.model.resnet.ResNet101

alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 4, 23, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)

fortuna.model.resnet.ResNet152

alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 8, 36, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)

fortuna.model.resnet.ResNet200

alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 24, 36, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)