Deep Ensemble
- class fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator.DeepEnsemblePosteriorApproximator(ensemble_size=5)
Deep ensemble posterior approximator. It is responsible for defining how the posterior distribution is approximated.
- Parameters:
ensemble_size (int) – The size of the ensemble, i.e. the number of models that are trained.
- property posterior_method_kwargs: Dict[str, Any]
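As a rough illustration of the idea behind this approximator, the sketch below trains `ensemble_size` copies of a toy linear model that differ only in their random initialization, then averages their predictions with equal weights. It uses only the Python standard library; `train_member`, `deep_ensemble` and `predictive_mean` are hypothetical helpers, not Fortuna functions, and plain SGD on a linear model stands in for training a neural network.

```python
import random
from statistics import fmean
from typing import List, Sequence, Tuple

Params = Tuple[float, float]  # (slope w, intercept b) of a toy linear model


def train_member(xs: Sequence[float], ys: Sequence[float], seed: int,
                 epochs: int = 500, lr: float = 0.05) -> Params:
    """Hypothetical helper: fit y ~ w * x + b by per-example SGD from a seed-dependent init."""
    rng = random.Random(seed)
    w, b = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            err = (w * x + b) - y  # gradient of 0.5 * err**2 w.r.t. the prediction
            w -= lr * err * x
            b -= lr * err
    return w, b


def deep_ensemble(xs: Sequence[float], ys: Sequence[float], ensemble_size: int = 5) -> List[Params]:
    """Hypothetical helper: train `ensemble_size` members that differ only in their random seed."""
    return [train_member(xs, ys, seed=i) for i in range(ensemble_size)]


def predictive_mean(members: Sequence[Params], x: float) -> float:
    """Equal-weight average of the member predictions at x."""
    return fmean(w * x + b for w, b in members)


xs = [0.0, 0.5, 1.0, 1.5, 2.0]
ys = [2.0 * x + 1.0 for x in xs]  # noiseless data from y = 2x + 1
members = deep_ensemble(xs, ys, ensemble_size=5)
print(predictive_mean(members, 3.0))  # close to 7.0
```

Because this toy problem is convex, all members converge to nearly the same solution; with neural networks, different initializations typically end up in different solutions, and that disagreement is what the ensemble uses to express uncertainty.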
- class fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_posterior.DeepEnsemblePosterior(joint, posterior_approximator)
Bases: Posterior
Deep ensemble approximate posterior class.
- Parameters:
joint (Joint) – Joint distribution.
posterior_approximator (DeepEnsemblePosteriorApproximator) – Deep ensemble posterior approximator.
- fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), map_fit_config=None, **kwargs)
Fit the posterior distribution. A posterior state will be internally stored.
- Parameters:
train_data_loader (DataLoader) – Training data loader.
val_data_loader (Optional[DataLoader]) – Validation data loader.
fit_config (FitConfig) – A configuration object.
- Returns:
A status including metrics describing the fitting process.
- Return type:
Status
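The fitting step can be pictured as a loop that trains each ensemble member independently and records per-member metrics, loosely mirroring the status returned by `fit`. This is a minimal sketch under stated assumptions: each member is a one-parameter constant predictor fit by gradient descent, plain lists stand in for the data loaders, and `fit_member`, `fit_ensemble` and the dictionary layout of the status are hypothetical, not Fortuna's actual structures.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple


def mse(c: float, ys: Sequence[float]) -> float:
    """Mean squared error of the constant prediction c."""
    return sum((c - y) ** 2 for y in ys) / len(ys)


def fit_member(train: Sequence[float], seed: int, epochs: int = 100, lr: float = 0.1) -> float:
    """Hypothetical member: a constant predictor c fit by full-batch gradient descent."""
    c = random.Random(seed).uniform(-10.0, 10.0)  # seed-dependent initialization
    for _ in range(epochs):
        grad = 2.0 * sum(c - y for y in train) / len(train)  # d(mse)/dc
        c -= lr * grad
    return c


def fit_ensemble(train: Sequence[float], val: Optional[Sequence[float]] = None,
                 ensemble_size: int = 5) -> Tuple[List[float], Dict[str, List[float]]]:
    """Train each member independently and collect per-member metrics in a status dict."""
    params: List[float] = []
    status: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for i in range(ensemble_size):
        c = fit_member(train, seed=i)
        params.append(c)
        status["train_loss"].append(mse(c, train))
        if val is not None:  # validation metrics only when validation data is given
            status["val_loss"].append(mse(c, val))
    return params, status


params, status = fit_ensemble([1.0, 2.0, 3.0], val=[2.0, 2.0], ensemble_size=3)
```

Keeping one metric list per member makes it easy to spot a member whose training went wrong, which an ensemble-level average would hide.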
- load_state(checkpoint_dir)
Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the current probabilistic model.
- Parameters:
checkpoint_dir (Path) – Path to the checkpoint file or directory to restore.
- Return type:
None
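Restoring an ensemble amounts to reading back one set of parameters per member and checking that the checkpoint matches the current model. The sketch below assumes a hypothetical on-disk layout, one JSON file `member_<i>.json` per member inside `checkpoint_dir`, and uses the member count as the compatibility check; the layout and `load_ensemble_state` are illustrative assumptions and do not reproduce Fortuna's own checkpoint format.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


def load_ensemble_state(checkpoint_dir: Path, ensemble_size: int) -> List[Dict[str, Any]]:
    """Restore one parameter dict per member from the hypothetical layout
    checkpoint_dir/member_<i>.json, rejecting checkpoints with the wrong member count."""
    files = sorted(Path(checkpoint_dir).glob("member_*.json"),
                   key=lambda p: int(p.stem.split("_")[1]))  # numeric order: member_2 before member_10
    if len(files) != ensemble_size:
        raise ValueError(
            f"checkpoint has {len(files)} members, but the model expects {ensemble_size}"
        )
    return [json.loads(f.read_text()) for f in files]


# Usage: write a toy 3-member checkpoint to a temporary directory, then restore it.
mismatch_error = ""
with tempfile.TemporaryDirectory() as tmp:
    for i in range(3):
        (Path(tmp) / f"member_{i}.json").write_text(json.dumps({"w": float(i)}))
    restored = load_ensemble_state(Path(tmp), ensemble_size=3)
    try:
        load_ensemble_state(Path(tmp), ensemble_size=5)  # incompatible with the checkpoint
    except ValueError as err:
        mismatch_error = str(err)
```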
- sample(rng=None, **kwargs)
Sample from the posterior distribution.
- Parameters:
rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is taken from the attributes of this class.
- Returns:
A sample from the posterior distribution.
- Return type:
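A deep ensemble can be viewed as approximating the posterior with an equal-weight mixture of point masses at the trained members, so drawing a sample can be sketched as picking one member uniformly at random. In the sketch below, `EnsembleSampler` is a hypothetical stand-in for the posterior class, and Python's `random.Random` replaces the JAX `PRNGKeyArray` used by Fortuna; as in the behavior documented above, a generator stored on the object is used when no `rng` is passed.

```python
import random
from typing import Any, Dict, Optional, Sequence


class EnsembleSampler:
    """Hypothetical stand-in for a deep-ensemble posterior: an equal-weight
    mixture of point masses, one per trained member."""

    def __init__(self, members: Sequence[Dict[str, Any]], seed: int = 0):
        self.members = list(members)
        self.rng = random.Random(seed)  # fallback generator stored as an attribute

    def sample(self, rng: Optional[random.Random] = None) -> Dict[str, Any]:
        rng = rng if rng is not None else self.rng
        # Each member is returned with probability 1 / len(self.members).
        return self.members[rng.randrange(len(self.members))]


posterior = EnsembleSampler([{"w": 0.0}, {"w": 1.0}, {"w": 2.0}])
draws = [posterior.sample() for _ in range(300)]  # uses the stored generator
```

Passing an explicit generator makes a draw reproducible, while omitting it advances the generator held by the object.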
- save_state(checkpoint_dir, keep_top_n_checkpoints=1)
Save the state of the posterior distribution to a checkpoint directory.
- Parameters:
checkpoint_dir (Path) – Path to the checkpoint directory where the state is saved.
keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.
- Return type:
None
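The pruning implied by `keep_top_n_checkpoints` can be sketched as follows: each call writes the members into a new `checkpoint_<step>` subdirectory and then deletes all but the newest `keep_top_n_checkpoints` of them. The directory naming, the JSON files, the `step` argument, `save_ensemble_state` itself and ranking checkpoints by step number are assumptions made for illustration; they are not taken from Fortuna.

```python
import json
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List


def save_ensemble_state(checkpoint_dir: Path, step: int, members: List[Dict[str, float]],
                        keep_top_n_checkpoints: int = 1) -> Path:
    """Write members to the hypothetical layout checkpoint_dir/checkpoint_<step>/member_<i>.json,
    then delete all but the newest `keep_top_n_checkpoints` checkpoints (ranked by step)."""
    if keep_top_n_checkpoints < 1:
        raise ValueError("keep_top_n_checkpoints must be at least 1")
    target = Path(checkpoint_dir) / f"checkpoint_{step}"
    target.mkdir(parents=True, exist_ok=True)
    for i, params in enumerate(members):
        (target / f"member_{i}.json").write_text(json.dumps(params))
    existing = sorted(Path(checkpoint_dir).glob("checkpoint_*"),
                      key=lambda p: int(p.name.split("_")[1]))
    for old in existing[:-keep_top_n_checkpoints]:  # everything except the newest n
        shutil.rmtree(old)
    return target


# Usage: three saves with keep_top_n_checkpoints=2 leave only the last two checkpoints.
with tempfile.TemporaryDirectory() as tmp:
    for step in (1, 2, 3):
        save_ensemble_state(Path(tmp), step, [{"w": float(step)}, {"w": -float(step)}],
                            keep_top_n_checkpoints=2)
    kept = sorted(p.name for p in Path(tmp).iterdir())
    files_in_latest = sorted(p.name for p in (Path(tmp) / "checkpoint_3").iterdir())
```

Rejecting values below 1 matters here because the slice `existing[:-0]` is empty, so a value of 0 would silently keep every checkpoint instead of none.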