Deep Ensemble

class fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator.DeepEnsemblePosteriorApproximator(ensemble_size=5)

Deep ensemble posterior approximator. It defines how the posterior distribution is approximated.

Parameters:

ensemble_size (int) – The number of models in the ensemble.
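A deep ensemble approximates the posterior with `ensemble_size` copies of the same model, each trained from a different random initialization, and forms predictions by averaging the members' predictive probabilities. The plain-Python sketch below illustrates only that averaging step; `softmax`, `ensemble_predictive` and the lambda "members" are hypothetical stand-ins, not Fortuna objects, and Fortuna's own prediction code may differ.

```python
import math
from typing import Callable, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable softmax over a single vector of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predictive(
    members: Sequence[Callable[[float], List[float]]], x: float
) -> List[float]:
    # Average the class probabilities of every member:
    # p(y | x) ~= (1 / M) * sum_m softmax(f_m(x)).
    probs = [softmax(member(x)) for member in members]
    n_classes = len(probs[0])
    return [sum(p[k] for p in probs) / len(probs) for k in range(n_classes)]


# Hypothetical members: three "models" that disagree on a 2-class problem.
members = [
    lambda x: [x, 0.0],
    lambda x: [0.0, x],
    lambda x: [0.5 * x, 0.0],
]
print(ensemble_predictive(members, 2.0))
```

Disagreement between members pulls the averaged probabilities toward uniform, which is how an ensemble expresses predictive uncertainty.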

property posterior_method_kwargs: Dict[str, Any]

class fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_posterior.DeepEnsemblePosterior(joint, posterior_approximator)

Bases: Posterior

Deep ensemble approximate posterior class.

Parameters:
  • joint (Joint) – Joint distribution.

  • posterior_approximator (DeepEnsemblePosteriorApproximator) – Deep ensemble posterior approximator.

fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), map_fit_config=None, **kwargs)

Fit the posterior distribution. A posterior state will be internally stored.

Parameters:
  • train_data_loader (DataLoader) – Training data loader.

  • val_data_loader (Optional[DataLoader]) – Validation data loader.

  • fit_config (FitConfig) – A configuration object.

Returns:

A status including metrics describing the fitting process.

Return type:

Status
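Conceptually, fitting a deep ensemble means training each member independently, starting from its own random initialization. The sketch below does this for a toy 1-D linear regression in plain Python and returns a status-like dictionary of per-member final losses. `train_member`, `fit_ensemble` and the `{"final_loss": ...}` format are hypothetical illustrations, not Fortuna's `fit` internals or its `Status` object.

```python
import random
from typing import Dict, List, Tuple

Params = Tuple[float, float]


def train_member(
    data: List[Tuple[float, float]], seed: int, n_epochs: int = 500, lr: float = 0.05
) -> Tuple[Params, float]:
    # Each member gets its own random initialization via a distinct seed.
    rng = random.Random(seed)
    w, b = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    for _ in range(n_epochs):
        # Full-batch gradient descent on the mean squared error of y ~ w * x + b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / len(data)
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / len(data)
        w -= lr * grad_w
        b -= lr * grad_b
    loss = sum((w * x + b - y) ** 2 for x, y in data) / len(data)
    return (w, b), loss


def fit_ensemble(
    data: List[Tuple[float, float]], ensemble_size: int = 5
) -> Tuple[List[Params], Dict[str, List[float]]]:
    params, losses = [], []
    for m in range(ensemble_size):
        p, loss = train_member(data, seed=m)
        params.append(p)
        losses.append(loss)
    # Hypothetical status: metrics describing how each member's training ended.
    return params, {"final_loss": losses}


data = [(x / 10, 2 * (x / 10) + 1) for x in range(-10, 11)]
params, status = fit_ensemble(data, ensemble_size=5)
print(status["final_loss"])
```

Because this toy loss is convex, every member converges to the same parameters; in real deep networks the non-convex loss landscape is what makes differently initialized members diverse.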

load_state(checkpoint_dir)

Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the current probabilistic model.

Parameters:

checkpoint_dir (Path) – Path to the checkpoint file or directory to restore.

Return type:

None
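The compatibility requirement on loading can be illustrated with a save/load round trip. The sketch below stores every member's parameters in one JSON file inside a checkpoint directory and refuses to restore a checkpoint whose ensemble size differs from the expected one. The file layout (`checkpoint.json`) and the helpers `save_members`/`load_members` are hypothetical; Fortuna's real checkpoint format is not shown here.

```python
import json
import tempfile
from pathlib import Path
from typing import List


def save_members(checkpoint_dir: Path, members: List[List[float]]) -> None:
    # Hypothetical layout: one JSON file holding all members' parameters.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    (checkpoint_dir / "checkpoint.json").write_text(json.dumps({"members": members}))


def load_members(checkpoint_dir: Path, expected_size: int) -> List[List[float]]:
    state = json.loads((checkpoint_dir / "checkpoint.json").read_text())
    members = state["members"]
    # Reject a checkpoint that does not match the current model's ensemble size.
    if len(members) != expected_size:
        raise ValueError(f"expected {expected_size} members, found {len(members)}")
    return members


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "ensemble"
    save_members(ckpt, [[2.0, 1.0], [1.9, 1.1]])
    restored = load_members(ckpt, expected_size=2)
print(restored)  # [[2.0, 1.0], [1.9, 1.1]]
```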

sample(rng=None, **kwargs)

Sample from the posterior distribution.

Parameters:

rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is taken from the attributes of this class.

Returns:

A sample from the posterior distribution.

Return type:

JointState
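A deep ensemble's approximate posterior can be viewed as a uniform mixture of point masses at the members' parameters, so drawing one posterior sample amounts to choosing one member uniformly at random. The sketch below shows that idea with Python's `random.Random` in place of a JAX `PRNGKeyArray`, and falls back to a generator stored on the object when none is passed. `EnsemblePosteriorSketch` is a hypothetical stand-in, not Fortuna's class, and the real `sample` returns a `JointState` rather than a tuple.

```python
import random
from typing import Generic, List, Optional, Sequence, TypeVar

T = TypeVar("T")


class EnsemblePosteriorSketch(Generic[T]):
    """Hypothetical stand-in for a deep ensemble posterior."""

    def __init__(self, members: Sequence[T], seed: int = 0) -> None:
        self.members: List[T] = list(members)
        # Stored generator, used whenever `sample` is called without one.
        self.rng = random.Random(seed)

    def sample(self, rng: Optional[random.Random] = None) -> T:
        rng = rng if rng is not None else self.rng
        # Equal posterior mass on every member: pick one uniformly at random.
        return self.members[rng.randrange(len(self.members))]


members = [(2.0, 1.0), (1.9, 1.1), (2.1, 0.9)]
posterior = EnsemblePosteriorSketch(members)
rng = random.Random(42)
draws = [posterior.sample(rng) for _ in range(3000)]
print({m: draws.count(m) for m in members})
```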

save_state(checkpoint_dir, keep_top_n_checkpoints=1)

Save the state of the posterior distribution to a checkpoint directory.

Parameters:
  • checkpoint_dir (Path) – Directory where the checkpoint is saved.

  • keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.

Return type:

None
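The effect of `keep_top_n_checkpoints` can be sketched as "write the new checkpoint, then delete the oldest ones until only N remain". The version below assumes that "top" means the most recent checkpoints by step, matching the "past checkpoint files" wording; the `checkpoint_<step>` naming and the `save_checkpoint` helper are hypothetical, not Fortuna's on-disk format.

```python
import tempfile
from pathlib import Path
from typing import List


def save_checkpoint(
    checkpoint_dir: Path, step: int, payload: str, keep_top_n: int = 1
) -> List[str]:
    if keep_top_n < 1:
        raise ValueError("keep_top_n must be at least 1")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # Hypothetical naming scheme: one file per step, e.g. checkpoint_3.
    (checkpoint_dir / f"checkpoint_{step}").write_text(payload)
    ckpts = sorted(
        checkpoint_dir.glob("checkpoint_*"), key=lambda p: int(p.name.split("_")[1])
    )
    # Delete everything except the `keep_top_n` most recent checkpoints.
    for old in ckpts[:-keep_top_n]:
        old.unlink()
    return [p.name for p in ckpts[-keep_top_n:]]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step in range(1, 6):
        kept = save_checkpoint(ckpt_dir, step, payload=f"state at step {step}", keep_top_n=2)
    remaining = sorted(p.name for p in ckpt_dir.iterdir())
print(kept)  # ['checkpoint_4', 'checkpoint_5']
```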