ResNet
- class fortuna.model.resnet.ResNet(stage_sizes, block_cls, output_dim, num_filters=64, dtype=<class 'jax.numpy.float32'>, activation=<jax._src.custom_derivatives.custom_jvp object>, conv=<class 'flax.linen.linear.Conv'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Deep feature extractor subnetwork.
- stage_sizes
Number of blocks in each stage.
- Type: Sequence[int]
- block_cls
Block class used to build each stage.
- Type: ModuleDef
- output_dim
Output dimension.
- Type: int
- num_filters
Number of filters.
- Type: int
- dtype
Data type of the layers.
- Type: Any
- activation
Activation function.
- Type: Callable
- conv
Convolution module.
- Type: ModuleDef
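The attribute list above maps directly onto constructor keyword arguments, since Flax modules are dataclasses. The following is a minimal sketch of that pattern using only the standard library. `ResNetSpec` is a hypothetical stand-in, not part of Fortuna; it mirrors the documented field names and the documented `num_filters=64` default, and leaves out `dtype`, `activation`, and `conv` because their defaults are JAX/Flax objects.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class ResNetSpec:
    """Hypothetical stand-in mirroring the documented ResNet fields."""

    stage_sizes: Sequence[int]  # required: one entry per stage
    block_cls: Any              # required: block class (a ModuleDef in the real module)
    output_dim: int             # required: output dimension
    num_filters: int = 64       # optional: documented default is 64


# The three fields without defaults must always be supplied.
spec = ResNetSpec(stage_sizes=[2, 2, 2, 2], block_cls="ResNetBlock", output_dim=10)

# Total number of blocks across all stages.
total_blocks = sum(spec.stage_sizes)
print(spec.num_filters, total_blocks)
```

Omitting any of `stage_sizes`, `block_cls`, or `output_dim` raises a `TypeError` in this sketch, which matches the signature above, where only those three parameters lack defaults.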
- fortuna.model.resnet.ResNet18
alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[2, 2, 2, 2], block_cls=<class 'fortuna.model.resnet.ResNetBlock'>)
- fortuna.model.resnet.ResNet34
alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 4, 6, 3], block_cls=<class 'fortuna.model.resnet.ResNetBlock'>)
- fortuna.model.resnet.ResNet50
alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 4, 6, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)
- fortuna.model.resnet.ResNet101
alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 4, 23, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)
- fortuna.model.resnet.ResNet152
alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 8, 36, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)
- fortuna.model.resnet.ResNet200
alias of functools.partial(<class 'fortuna.model.resnet.ResNet'>, stage_sizes=[3, 24, 36, 3], block_cls=<class 'fortuna.model.resnet.BottleneckResNetBlock'>)
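Each alias above is a `functools.partial` that pre-fills `stage_sizes` and `block_cls`, leaving `output_dim` for the caller. The sketch below reproduces that mechanism with a hypothetical stand-in class (`StandInResNet`, not Fortuna's module) and shows how the numeric suffix in each name follows from the stage sizes. It assumes the conventional ResNet layer count: two convolutions per basic block, three per bottleneck block, plus one stem convolution and one final dense layer. That Fortuna's block classes follow this convention is an assumption, not something stated on this page.

```python
import functools
from dataclasses import dataclass
from typing import Sequence


@dataclass
class StandInResNet:
    """Hypothetical stand-in; block_cls is a plain string here."""

    stage_sizes: Sequence[int]
    block_cls: str
    output_dim: int


# Assumed convolutions per block under the conventional layer count.
CONVS_PER_BLOCK = {"ResNetBlock": 2, "BottleneckResNetBlock": 3}

# Same stage sizes and block types as the documented aliases.
ALIASES = {
    "ResNet18": functools.partial(
        StandInResNet, stage_sizes=[2, 2, 2, 2], block_cls="ResNetBlock"),
    "ResNet34": functools.partial(
        StandInResNet, stage_sizes=[3, 4, 6, 3], block_cls="ResNetBlock"),
    "ResNet50": functools.partial(
        StandInResNet, stage_sizes=[3, 4, 6, 3], block_cls="BottleneckResNetBlock"),
    "ResNet101": functools.partial(
        StandInResNet, stage_sizes=[3, 4, 23, 3], block_cls="BottleneckResNetBlock"),
    "ResNet152": functools.partial(
        StandInResNet, stage_sizes=[3, 8, 36, 3], block_cls="BottleneckResNetBlock"),
    "ResNet200": functools.partial(
        StandInResNet, stage_sizes=[3, 24, 36, 3], block_cls="BottleneckResNetBlock"),
}


def nominal_depth(model: StandInResNet) -> int:
    """Convolutions in all blocks, plus the stem conv and the final dense layer."""
    return CONVS_PER_BLOCK[model.block_cls] * sum(model.stage_sizes) + 2


# Only output_dim remains to be supplied when calling an alias.
depths = {name: nominal_depth(make(output_dim=10)) for name, make in ALIASES.items()}
for name, depth in depths.items():
    print(name, depth)

# Keyword arguments fixed by a partial can still be overridden at call time.
custom = ALIASES["ResNet18"](output_dim=10, stage_sizes=[1, 1, 1, 1])
```

Under these assumptions the computed depth matches the suffix of every alias, for example 2 × (2 + 2 + 2 + 2) + 2 = 18 for `ResNet18` and 3 × (3 + 24 + 36 + 3) + 2 = 200 for `ResNet200`.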