Source code for fortuna.prob_model.fit_config.optimizer

from typing import (
    Callable,
    Optional,
    Tuple,
)

import optax

from fortuna.typing import (
    AnyKey,
    Array,
    OptaxOptimizer,
)


[docs]class FitOptimizer: def __init__( self, method: Optional[OptaxOptimizer] = optax.adam(1e-3), n_epochs: int = 100, freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, ): """ An object to configure the optimization in the posterior fitting. Parameters ---------- method: OptaxOptimizer An Optax optimizer. n_epochs: int Maximum number of epochs to run the training for. freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] A callable taking in input a path in the nested dictionary of parameters, as well as the corresponding array of parameters, and returns "trainable" or "freeze", according to whether the corresponding parameter should be optimized or not. """ self.method = method self.n_epochs = n_epochs self.freeze_fun = freeze_fun