Posterior
This section covers the posterior distribution of the model parameters given the training data and the calibration parameters. We support several posterior approximations.
- class fortuna.prob_model.posterior.base.Posterior(joint, posterior_approximator)
Posterior distribution class. This refers to \(p(w|\mathcal{D}, \phi)\), where \(w\) are the random model parameters, \(\mathcal{D}\) is a training data set and \(\phi\) are calibration parameters.
- Parameters:
joint (Joint) – A joint distribution object.
posterior_approximator (PosteriorApproximator) – A posterior approximator.
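For reference, Bayes' rule relates \(p(w|\mathcal{D}, \phi)\) to a likelihood and a prior. The factorization below assumes, as is common, that the prior \(p(w)\) does not depend on the calibration parameters \(\phi\), so \(\phi\) enters only through the likelihood; this page does not state that assumption explicitly.

```latex
p(w \mid \mathcal{D}, \phi)
  = \frac{p(\mathcal{D} \mid w, \phi)\, p(w)}
         {\int p(\mathcal{D} \mid w', \phi)\, p(w')\, \mathrm{d}w'}
  \;\propto\; p(\mathcal{D} \mid w, \phi)\, p(w)
```

The numerator is presumably what the `joint` argument encodes. The normalizing integral over \(w'\) is generally intractable for neural networks, which is why a posterior approximator is supplied alongside it.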
- abstract fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), **kwargs)
Fit the posterior distribution. A posterior state will be internally stored.
- Parameters:
train_data_loader (DataLoader) – Training data loader.
val_data_loader (Optional[DataLoader]) – Validation data loader.
fit_config (FitConfig) – A configuration object.
- Returns:
A status including metrics describing the fitting process.
- Return type:
Status
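The contract of `fit` (consume a training loader, store the posterior state internally, return a status with metrics) and of `sample` documented below (use the given generator, or fall back to one held by the object) can be illustrated with a toy model whose posterior is exact. This is a sketch under stated assumptions, not Fortuna's implementation: `ToyGaussianMeanPosterior` is hypothetical, the loader is any iterable of `(inputs, targets)` batches, the status is a plain dict rather than a `Status` object, and Python's `random.Random` stands in for the JAX `PRNGKeyArray` that `sample` actually accepts.

```python
import math
import random
from typing import Dict, Iterable, Optional, Tuple

# Hypothetical stand-in for a data loader: any iterable of (inputs, targets) batches.
Batch = Tuple[list, list]


class ToyGaussianMeanPosterior:
    """Exact posterior over a scalar mean w (illustration only, not Fortuna's API).

    Assumed model: y_i ~ N(w, noise_std**2), prior w ~ N(0, prior_std**2).
    """

    def __init__(self, prior_std: float = 1.0, noise_std: float = 1.0, seed: int = 0):
        self.prior_std = prior_std
        self.noise_std = noise_std
        self._rng = random.Random(seed)  # fallback generator used when sample() gets no rng
        self.state: Optional[Dict[str, float]] = None  # filled in by fit()

    def fit(self, train_data_loader: Iterable[Batch]) -> Dict[str, float]:
        # Accumulate sufficient statistics batch by batch, as a loader would stream them.
        n, total = 0, 0.0
        for _, targets in train_data_loader:
            n += len(targets)
            total += sum(targets)
        # Conjugate update: posterior precision = prior precision + n / noise variance.
        precision = 1.0 / self.prior_std**2 + n / self.noise_std**2
        mean = (total / self.noise_std**2) / precision
        # Store the state internally instead of returning it, mirroring the documented fit().
        self.state = {"mean": mean, "std": math.sqrt(1.0 / precision)}
        # Return a status-like dict with metrics describing the fit.
        return {"n_examples": n, "posterior_std": self.state["std"]}

    def sample(self, rng: Optional[random.Random] = None) -> float:
        if self.state is None:
            raise RuntimeError("call fit() before sample()")
        rng = rng if rng is not None else self._rng  # fall back to the stored generator
        return rng.gauss(self.state["mean"], self.state["std"])


loader = [([0, 1], [1.0, 3.0]), ([2], [2.0])]
posterior = ToyGaussianMeanPosterior(prior_std=1.0, noise_std=1.0)
status = posterior.fit(loader)
draw = posterior.sample(random.Random(42))  # explicit generator
fallback_draw = posterior.sample()  # uses the generator stored on the object
```

With three targets summing to 6 and unit variances, the posterior precision is 1 + 3 = 4, so the posterior mean is 6 / 4 = 1.5 and the standard deviation is 0.5.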
- load_state(checkpoint_path)
Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the current probabilistic model.
- Parameters:
checkpoint_path (Path) – Path to checkpoint file or directory to restore.
- Return type:
None
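The page only says that the checkpoint must be compatible with the current probabilistic model. One plausible reading, assumed here, is that the saved parameter tree must have the same keys and array shapes as the model's. The helper `incompatibilities` below is hypothetical and not part of Fortuna; leaves are shape tuples to keep the sketch dependency-free, whereas a JAX parameter tree would hold arrays whose `.shape` you would compare.

```python
from typing import Any, List


def incompatibilities(ckpt: Any, model: Any, path: str = "params") -> List[str]:
    """List mismatches between a checkpoint tree and the model's tree (hypothetical helper)."""
    model_is_tree = isinstance(model, dict)
    ckpt_is_tree = isinstance(ckpt, dict)
    if model_is_tree != ckpt_is_tree:
        return [f"{path}: subtree/leaf mismatch"]
    if not model_is_tree:
        # Leaves: compare shapes directly.
        return [] if ckpt == model else [f"{path}: shape {ckpt} != {model}"]
    problems: List[str] = []
    for key in sorted(set(model) | set(ckpt)):
        child = f"{path}/{key}"
        if key not in ckpt:
            problems.append(f"{child}: missing from checkpoint")
        elif key not in model:
            problems.append(f"{child}: unexpected in checkpoint")
        else:
            problems.extend(incompatibilities(ckpt[key], model[key], child))
    return problems


model_shapes = {"dense": {"kernel": (4, 2), "bias": (2,)}}
good_ckpt = {"dense": {"kernel": (4, 2), "bias": (2,)}}
bad_ckpt = {"dense": {"kernel": (4, 3)}}

print(incompatibilities(good_ckpt, model_shapes))  # []
print(incompatibilities(bad_ckpt, model_shapes))
```

Reporting every mismatch, rather than failing on the first, makes it easier to tell a wrong architecture from a checkpoint that belongs to a different model entirely.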
- abstract sample(rng=None, *args, **kwargs)
Sample from the posterior distribution.
- Parameters:
rng (Optional[PRNGKeyArray]) – A random number generator. If not passed, this will be taken from the attributes of this class.
- Returns:
A sample from the posterior distribution.
- Return type:
- save_state(checkpoint_path, keep_top_n_checkpoints=1)
Save the state of the posterior distribution to a checkpoint directory.
- Parameters:
checkpoint_path (Path) – Path to the checkpoint directory where the state will be saved.
keep_top_n_checkpoints (int) – Number of past checkpoint files to keep.
- Return type:
None
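The on-disk layout Fortuna uses is not described on this page. The sketch below assumes JSON files named `checkpoint_<step>.json` and reads `keep_top_n_checkpoints` as "keep the n most recent files", which is how the `keep` argument of Flax's `flax.training.checkpoints.save_checkpoint` behaves; whether Fortuna ranks checkpoints by recency or by a metric is not stated here. Both functions are hypothetical stand-ins, not Fortuna's `save_state`/`load_state`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def _step_of(path: Path) -> int:
    # "checkpoint_3.json" -> stem "checkpoint_3" -> 3
    return int(path.stem.split("_")[1])


def save_state(state: Dict[str, Any], checkpoint_dir: Path, step: int,
               keep_top_n_checkpoints: int = 1) -> Path:
    """Write `state` as JSON, then delete all but the newest `keep_top_n_checkpoints` files."""
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    path = checkpoint_dir / f"checkpoint_{step}.json"
    path.write_text(json.dumps(state))
    # Rank existing checkpoints by step, newest first, and prune the tail.
    existing = sorted(checkpoint_dir.glob("checkpoint_*.json"), key=_step_of, reverse=True)
    for old in existing[keep_top_n_checkpoints:]:
        old.unlink()
    return path


def load_state(checkpoint_dir: Path) -> Dict[str, Any]:
    """Restore the newest checkpoint found in `checkpoint_dir`."""
    latest = max(checkpoint_dir.glob("checkpoint_*.json"), key=_step_of)
    return json.loads(latest.read_text())


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step in (1, 2, 3):
        save_state({"step": step, "params": [0.5 * step]}, ckpt_dir, step,
                   keep_top_n_checkpoints=2)
    remaining = sorted(p.name for p in ckpt_dir.iterdir())  # only the two newest survive
    restored = load_state(ckpt_dir)
```

Pruning right after each write bounds disk usage to `keep_top_n_checkpoints` files at any time.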
- class fortuna.prob_model.posterior.base.PosteriorApproximator
A posterior approximator abstract class.
- property posterior_method_kwargs: Dict[str, Any]
- class fortuna.prob_model.posterior.state.PosteriorState(step, apply_fn, params, tx, opt_state, encoded_name=(80, 111, 115, 116, 101, 114, 105, 111, 114, 83, 116, 97, 116, 101), frozen_params=None, dynamic_scale=None, mutable=None, calib_params=None, calib_mutable=None, grad_accumulated=None)
A posterior distribution state. This includes all the parameters and mutable objects that characterize an approximation of the posterior distribution.
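The default `encoded_name` in the signature is the class name `PosteriorState` written as a tuple of character codes, which can be checked directly:

```python
# Default value of PosteriorState.encoded_name, copied from the signature above.
encoded_name = (80, 111, 115, 116, 101, 114, 105, 111, 114, 83, 116, 97, 116, 101)

# Each integer is a Unicode code point; joining the characters recovers the name.
decoded = "".join(map(chr, encoded_name))
print(decoded)  # PosteriorState

# The reverse direction: encode a name as a tuple of code points.
assert tuple(map(ord, "PosteriorState")) == encoded_name
```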
- dynamic_scale: Optional[dynamic_scale.DynamicScale] = None
- grad_accumulated: Optional[jnp.ndarray] = None
- classmethod init(params, mutable=None, optimizer=None, calib_params=None, calib_mutable=None, grad_accumulated=None, dynamic_scale=None, **kwargs)
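The `dynamic_scale` attribute refers to dynamic loss scaling for mixed-precision training. The general technique multiplies the loss by a scale factor before differentiation so that small float16 gradients do not underflow, divides the resulting gradients back by the same factor, shrinks the scale and skips the update when any gradient overflows, and grows the scale again after a run of finite steps. The stdlib sketch below illustrates that technique only: `ToyDynamicScale` and its default constants are hypothetical and are not taken from Fortuna or Flax, and overflow is simulated with Python floats rather than float16 arrays.

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyDynamicScale:
    """Illustrative dynamic loss scaler; all constants are assumptions for the sketch."""

    scale: float = 2.0**15
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 3  # grow the scale after this many consecutive finite steps
    finite_steps: int = 0

    def update(self, scaled_grads: List[float]) -> Optional[List[float]]:
        """Unscale gradients of (scale * loss); return None to signal a skipped step."""
        grads = [g / self.scale for g in scaled_grads]
        if all(math.isfinite(g) for g in grads):
            self.finite_steps += 1
            if self.finite_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self.finite_steps = 0
            return grads
        # Overflow: shrink the scale and tell the caller to skip this update.
        self.scale *= self.backoff_factor
        self.finite_steps = 0
        return None


ds = ToyDynamicScale(scale=8.0, growth_interval=2)
g1 = ds.update([8.0, -16.0])       # finite: unscaled to [1.0, -2.0]
g2 = ds.update([8.0])              # second finite step in a row: scale grows to 16.0
g3 = ds.update([float("inf")])     # overflow: step skipped, scale backs off to 8.0
```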
Initialize a posterior distribution state.
- Parameters:
params (Params) – The parameters characterizing an approximation of the posterior distribution.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer associated with the posterior state.
mutable (Optional[Mutable]) – The mutable objects characterizing an approximation of the posterior distribution.
calib_params (Optional[CalibParams]) – The calibration parameters characterizing an approximation of the posterior distribution.
calib_mutable (Optional[CalibMutable]) – The calibration mutable objects characterizing an approximation of the posterior distribution.
grad_accumulated (Optional[jnp.ndarray]) – The gradients accumulated in consecutive training steps (used only when gradient_accumulation_steps > 1).
dynamic_scale (Optional[dynamic_scale.DynamicScale]) – Dynamic loss scaling for mixed precision gradients.
- Returns:
A posterior distribution state.
- Return type:
Any
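`grad_accumulated` holds gradients accumulated over consecutive training steps when `gradient_accumulation_steps > 1`. The sketch below shows the usual pattern with plain SGD on Python lists: add each micro-batch gradient to the accumulator, and on every `gradient_accumulation_steps`-th step apply the averaged gradient and reset. The helper `accumulate_step`, averaging rather than summing, and resetting to `None` are assumptions for illustration; the page does not say exactly how Fortuna applies the accumulated gradients.

```python
from typing import List, Optional, Tuple


def accumulate_step(
    params: List[float],
    grads: List[float],
    grad_accumulated: Optional[List[float]],
    step: int,
    gradient_accumulation_steps: int,
    learning_rate: float,
) -> Tuple[List[float], Optional[List[float]]]:
    """One micro-step of gradient accumulation with plain SGD (hypothetical helper)."""
    if grad_accumulated is None:
        grad_accumulated = [0.0] * len(grads)
    # Add this micro-batch's gradients to the running sum.
    grad_accumulated = [a + g for a, g in zip(grad_accumulated, grads)]
    if (step + 1) % gradient_accumulation_steps != 0:
        # Not yet time to update: keep parameters, carry the accumulator forward.
        return params, grad_accumulated
    # Apply the averaged gradient, then reset the accumulator.
    mean_grad = [a / gradient_accumulation_steps for a in grad_accumulated]
    new_params = [p - learning_rate * g for p, g in zip(params, mean_grad)]
    return new_params, None


params, acc = [1.0, 1.0], None
micro_grads = [[1.0, 2.0], [3.0, 4.0], [2.0, 0.0], [0.0, 2.0]]
for step, g in enumerate(micro_grads):
    params, acc = accumulate_step(params, g, acc, step,
                                  gradient_accumulation_steps=2, learning_rate=0.5)
```

With two accumulation steps, the four micro-batches produce two parameter updates, using mean gradients `[2.0, 3.0]` and then `[1.0, 1.0]`, which moves the parameters from `[1.0, 1.0]` to `[0.0, -0.5]` and then to `[-0.5, -1.0]`.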
- classmethod init_from_dict(d, optimizer=None, **kwargs)
Initialize a posterior distribution state from a dictionary.
- Parameters:
d (Dict) – A dictionary including attributes of the posterior state.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer to assign to the posterior state.
- Returns:
A posterior state.
- Return type: