Model manager

The model manager orchestrates the forward pass of the models in the probabilistic model. We provide two model managers: one for classification tasks and one for regression tasks.

class fortuna.model.model_manager.classification.ClassificationModelManager(model, model_editor=None)

Classification model manager class. It orchestrates the forward pass of the model in the probabilistic model.

Parameters:

model (nn.Module) – A model describing the deterministic relation between inputs and outputs. The outputs must correspond to the logits of a softmax probability vector. The output dimension must be the same as the number of classes. Let \(x\) be input variables and \(w\) the random model parameters. Then the model is described by a function \(f(w, x)\), where each component of \(f\) corresponds to one of the classes.
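The requirement that the model output logits can be made concrete with a small sketch. For a single input \(x\), the vector \(f(w, x)\) has one entry per class, and a softmax maps it to a probability vector over the classes. The sketch below uses plain Python lists in place of `jnp` arrays, and the helper name `softmax` is illustrative only, not part of this API.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Map a logit vector f(w, x) to a probability vector over the classes."""
    # Subtracting the maximum logit avoids overflow and leaves the result unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one input and 3 classes:
# the model's output dimension equals the number of classes.
logits = [2.0, 1.0, 0.1]
probs = softmax(logits)
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```

Because the softmax is applied downstream, the model itself can output unconstrained real numbers.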

apply(params, inputs, mutable=None, train=False, rng=None)

Apply the model’s forward pass.

Parameters:
  • params (Params) – The random parameters of the probabilistic model.

  • inputs (InputData) – Input data points.

  • mutable (Optional[Mutable]) – The mutable objects used to evaluate the models.

  • train (bool) – Whether the method is called during training.

  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is obtained from the rng property of this class.

Returns:

The output of the model manager for each input. Mutable objects may also be returned.

Return type:

Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]
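Since apply may return either the outputs alone or the outputs together with mutable objects, calling code has to handle both shapes. A minimal sketch, assuming the tuple form is (outputs, mutable) as the return type suggests; `split_apply_output` is a hypothetical helper, not part of the library, and plain lists stand in for `jnp.ndarray`.

```python
from typing import Any, Optional, Tuple, Union

Array = Any   # stands in for jnp.ndarray in this sketch
PyTree = Any  # stands in for the mutable objects


def split_apply_output(
    out: Union[Array, Tuple[Array, PyTree]]
) -> Tuple[Array, Optional[PyTree]]:
    """Normalize both possible return shapes of apply() to (outputs, mutable)."""
    # Assumption: a tuple means (outputs, mutable); an array is never a tuple.
    if isinstance(out, tuple):
        outputs, mutable = out
        return outputs, mutable
    return out, None


# Outputs only: no mutable objects were returned.
outputs, mutable = split_apply_output([[0.1, 0.9]])
print(outputs, mutable)  # [[0.1, 0.9]] None
```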

init(input_shape, rng=None, **kwargs)

Initialize random parameters and mutable objects.

Parameters:
  • input_shape (Tuple) – The shape of the input variable.

  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is obtained from the rng property of this class.

Returns:

Initialized random parameters and mutable objects.

Return type:

Dict[str, FrozenDict]

class fortuna.model.model_manager.regression.RegressionModelManager(model, likelihood_log_variance_model, model_editor=None)

Regression model manager class. It orchestrates the forward pass of the model in the probabilistic model.

Parameters:
  • model (nn.Module) – A model describing the deterministic relation between inputs and outputs. It characterizes the mean model of the likelihood function. The outputs must belong to the same space as the target variables. Let \(x\) be input variables and \(w\) the random model parameters. Then the model is described by a function \(\mu(w, x)\).

  • likelihood_log_variance_model (nn.Module) – A model characterizing the log-variance of a Gaussian likelihood function. The outputs must belong to the same space as the target variables. Let \(x\) be input variables and \(w\) the random model parameters. Then the model is described by a function \(\log\sigma^2(w, x)\).
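Modelling the log-variance rather than the variance lets the second model output any real number while \(\sigma^2(w, x) = \exp(\log\sigma^2(w, x))\) stays positive. For a scalar target \(y\), the Gaussian log-likelihood is \(\log p(y \mid w, x) = -\tfrac{1}{2}\left[\log 2\pi + \log\sigma^2(w, x) + (y - \mu(w, x))^2 / \sigma^2(w, x)\right]\). A minimal plain-Python sketch of that formula; the function name is illustrative, and the library itself works with `jnp` arrays.

```python
import math


def gaussian_log_likelihood(y: float, mean: float, log_var: float) -> float:
    """Log-density of y under N(mean, exp(log_var)).

    mean plays the role of mu(w, x); log_var plays the role of log sigma^2(w, x).
    """
    var = math.exp(log_var)  # exponentiating guarantees a positive variance
    return -0.5 * (math.log(2.0 * math.pi) + log_var + (y - mean) ** 2 / var)


# Hypothetical outputs of the mean model and the log-variance model for one input.
mu, log_sigma2 = 1.5, math.log(0.25)
print(gaussian_log_likelihood(2.0, mu, log_sigma2))
```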

apply(params, inputs, mutable=None, train=False, rng=None)

Apply the models’ forward pass.

Parameters:
  • params (Params) – The random parameters of the probabilistic model.

  • inputs (InputData) – Input data points.

  • mutable (Optional[Mutable]) – The mutable objects used to evaluate the models.

  • train (bool) – Whether the method is called during training.

  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is obtained from the rng property of this class.

Returns:

The output of the model manager for each input. Mutable objects may also be returned.

Return type:

Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]
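The regression manager runs two models but returns a single array of outputs per input. One plausible layout, assumed here and not confirmed by this page, is the mean outputs followed by the log-variance outputs along the last axis; both halves then have the dimension of the target. A minimal sketch of recovering the two parts under that assumption (`split_mean_log_var` is hypothetical):

```python
from typing import List, Tuple


def split_mean_log_var(row: List[float]) -> Tuple[List[float], List[float]]:
    """Split one output row into (mean, log-variance) halves.

    Assumption: row == [mu_1, ..., mu_d, logvar_1, ..., logvar_d], i.e. the two
    models' outputs concatenated along the last axis.
    """
    if len(row) % 2 != 0:
        raise ValueError("expected an even number of output components")
    d = len(row) // 2
    return row[:d], row[d:]


# Hypothetical output for one input with a 2-dimensional target.
mean, log_var = split_mean_log_var([0.3, -1.2, -0.5, 0.1])
print(mean, log_var)  # [0.3, -1.2] [-0.5, 0.1]
```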

init(input_shape, rng=None, **kwargs)

Initialize random parameters and mutable objects.

Parameters:
  • input_shape (Tuple) – The shape of the input variable.

  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is obtained from the rng property of this class.

Returns:

Initialized random parameters and mutable objects.

Return type:

Dict[str, FrozenDict]

property rng: RandomNumberGenerator

Return the random number generator object.

Returns:

The random number generator object.

Return type:

RandomNumberGenerator
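The rng parameter of apply and init falls back to this property when no key is passed. A minimal sketch of that fallback pattern; both classes are hypothetical stand-ins, integer seeds replace JAX PRNG keys, and the generator method name get is an assumption made for illustration.

```python
import random
from typing import Optional


class HypotheticalRandomNumberGenerator:
    """Stand-in for RandomNumberGenerator: hands out fresh integer seeds."""

    def __init__(self, seed: int = 0):
        self._random = random.Random(seed)

    def get(self) -> int:
        return self._random.randrange(2**32)


class HypotheticalManager:
    def __init__(self, seed: int = 0):
        self._rng = HypotheticalRandomNumberGenerator(seed)

    @property
    def rng(self) -> HypotheticalRandomNumberGenerator:
        # Accessing the property returns the generator object itself.
        return self._rng

    def init(self, input_shape: tuple, rng: Optional[int] = None) -> int:
        # Use an explicitly passed key; otherwise draw one from the class's generator.
        if rng is None:
            rng = self.rng.get()
        return rng  # a real manager would use the key to initialize parameters


manager = HypotheticalManager(seed=42)
print(manager.init((3,), rng=7))  # 7: an explicit key is used as-is
print(manager.init((3,)))         # a seed drawn from manager.rng
```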

class fortuna.model.model_manager.base.ModelManager(model, model_editor=None)

Abstract model manager class. It orchestrates the forward pass of the models in the probabilistic model.

abstract apply(params, inputs, mutable=None, train=False, rng=None)

Apply the models’ forward pass.

Parameters:
  • params (Params) – The random parameters of the probabilistic model.

  • inputs (InputData) – Input data points.

  • mutable (Optional[Mutable]) – The mutable objects used to evaluate the models.

  • train (bool) – Whether the method is called during training.

  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is obtained from the rng property of this class.

Returns:

The output of the model manager for each input. Mutable objects may also be returned.

Return type:

Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]

abstract init(input_shape, rng=None, **kwargs)

Initialize random parameters and mutable objects.

Parameters:
  • input_shape (Tuple) – The shape of the input variable.

  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is obtained from the rng property of this class.

Returns:

Initialized random parameters and mutable objects.

Return type:

Dict[str, FrozenDict]
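Concrete managers such as the classification and regression managers subclass this abstract class and implement both abstract methods. A minimal sketch of that pattern with Python's abc module; the class names, the toy elementwise model \(f(w, x) = w \odot x\), and the parameter layout are hypothetical and do not reproduce the library's classes.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple


class SketchModelManager(ABC):
    """Hypothetical mirror of the abstract interface documented above."""

    def __init__(self, model: Any):
        self.model = model

    @abstractmethod
    def apply(self, params, inputs, mutable=None, train=False, rng=None):
        ...

    @abstractmethod
    def init(self, input_shape: Tuple, rng: Optional[Any] = None, **kwargs) -> Dict[str, Any]:
        ...


class LinearManager(SketchModelManager):
    """Toy concrete manager: elementwise f(w, x) = w * x, no mutable objects."""

    def apply(self, params, inputs, mutable=None, train=False, rng=None):
        w = params["w"]
        return [[wi * xi for wi, xi in zip(w, x)] for x in inputs]

    def init(self, input_shape, rng=None, **kwargs):
        # Deterministic toy initialization; a real manager would use rng here.
        return {"params": {"w": [0.5] * input_shape[0]}}


manager = LinearManager(model=None)
variables = manager.init((2,))
print(manager.apply(variables["params"], [[3.0, 4.0]]))  # [[1.5, 2.0]]
```

Instantiating SketchModelManager directly raises a TypeError, which is how abc enforces that subclasses implement every abstract method.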

property rng: RandomNumberGenerator

Return the random number generator object.

Returns:

The random number generator object.

Return type:

RandomNumberGenerator

class fortuna.model.model_manager.state.ModelManagerState(params, mutable=None)

A model manager state class.

Parameters:
  • params (Params) – The random parameters of the probabilistic model.

  • mutable (Optional[Mutable]) – The mutable objects used to evaluate the models.

classmethod init_from_dict(d)

Initialize the model manager state from a dictionary with the same structure as the output of init().

Parameters:

d (Union[Dict, FrozenDict]) – A dictionary with the same structure as the output of init().

Returns:

A model manager state.

Return type:

ModelManagerState
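Flax stores a module's variables as a dictionary of collections, with the trainable parameters under "params" and any other state (for example "batch_stats") under further keys. Assuming the output of init() holds one such variables dictionary per model, building a state amounts to separating the "params" collection from the remaining, mutable collections. A minimal sketch with plain dicts; `split_state_dict`, the model names, and the rule that an empty mutable part becomes None are assumptions for illustration, not the library's implementation.

```python
from typing import Any, Dict, Optional, Tuple


def split_state_dict(
    d: Dict[str, Dict[str, Any]]
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Split {model_name: variables} into (params, mutable)."""
    params = {name: {"params": v["params"]} for name, v in d.items()}
    mutable = {
        name: {k: c for k, c in v.items() if k != "params"}
        for name, v in d.items()
    }
    # If no model has non-parameter collections, report no mutable objects at all.
    if all(len(m) == 0 for m in mutable.values()):
        return params, None
    return params, mutable


# Hypothetical init() output for a manager with two models.
d = {
    "mean_model": {"params": {"w": [0.1, 0.2]}, "batch_stats": {"mean": [0.0, 0.0]}},
    "log_var_model": {"params": {"w": [0.0]}},
}
params, mutable = split_state_dict(d)
print(mutable)  # {'mean_model': {'batch_stats': {'mean': [0.0, 0.0]}}, 'log_var_model': {}}
```

The two parts correspond to the params and mutable arguments of ModelManagerState.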