Multi-Layer Perceptron (MLP)

class fortuna.model.mlp.MLP(output_dim, widths=(30, 30), activations=(flax.linen.relu, flax.linen.relu), dropout=flax.linen.Dropout, dropout_rate=0.0, dense=flax.linen.Dense, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A multi-layer perceptron (MLP).

output_dim

The output dimension of the model.

Type:

int

widths

The number of units in each hidden layer. Default: (30, 30).

Type:

Tuple[int]

activations

The activation function applied after each hidden layer. Default: (flax.linen.relu, flax.linen.relu).

Type:

Tuple[Callable[[Array], Array]]

dropout

Dropout module. Default: flax.linen.Dropout.

Type:

ModuleDef

dropout_rate

Dropout rate. Default: 0.0.

Type:

float

dense

Dense module. Default: flax.linen.Dense.

Type:

ModuleDef