Laplace approximation
- class fortuna.prob_model.posterior.laplace.laplace_approximator.LaplacePosteriorApproximator(tune_prior_log_variance=True)
Laplace posterior approximator.
- property posterior_method_kwargs: Dict[str, Any]
- class fortuna.prob_model.posterior.laplace.laplace_posterior.LaplacePosterior(joint, posterior_approximator)
Bases: Posterior
Laplace approximation posterior class.
- Parameters:
joint (Joint) – A joint distribution object.
posterior_approximator (LaplacePosteriorApproximator) – A Laplace posterior approximator.
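A Laplace approximation replaces the posterior with a Gaussian centred at the MAP estimate. The sketch below assumes a diagonal Gaussian (one mean and one standard deviation per parameter, in line with the Hessian diagonal passed to `LaplaceState.convert_from_map_state`) and draws samples by reparameterization, θ = μ + σ·ε with ε ~ N(0, 1). The helper `sample_diagonal_gaussian` and the flat-list parameter layout are hypothetical; Fortuna itself works on parameter pytrees.

```python
import random
from typing import List


def sample_diagonal_gaussian(
    mean: List[float], std: List[float], n_samples: int, seed: int = 0
) -> List[List[float]]:
    """Draw samples from N(mean, diag(std**2)) as mean + std * eps (hypothetical helper)."""
    if len(mean) != len(std):
        raise ValueError("mean and std must have the same length")
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        # One independent standard-normal draw per parameter, scaled and shifted.
        samples.append([m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)])
    return samples


# Toy MAP estimate (posterior mean) and per-parameter posterior standard deviations.
map_params = [0.5, -1.2, 2.0]
posterior_std = [0.1, 0.3, 0.05]
draws = sample_diagonal_gaussian(map_params, posterior_std, n_samples=5000)
```

Because the covariance is diagonal, each parameter is sampled independently, so the cost grows linearly with the number of parameters.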
- fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), map_fit_config=None, **kwargs)
Fit the posterior distribution. A posterior state will be internally stored.
- Parameters:
train_data_loader (DataLoader) – Training data loader.
val_data_loader (Optional[DataLoader]) – Validation data loader.
fit_config (FitConfig) – A configuration object.
- Returns:
A status including metrics describing the fitting process.
- Return type:
Status
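The ingredients of a Laplace fit mirror the inputs of `LaplaceState.convert_from_map_state`: a MAP estimate, the curvature of the likelihood at that estimate, and a prior log-variance. The self-contained sketch below illustrates that idea on a one-parameter logistic regression, assuming a N(0, exp(prior_log_var)) prior, Newton's method for the MAP step, and the exact 1-D Hessian. It does not call Fortuna, whose `fit` operates on neural-network parameters, data loaders and an approximated Hessian diagonal; `laplace_logistic_1d` is a hypothetical helper.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def laplace_logistic_1d(
    xs: List[float], ys: List[int], prior_log_var: float = 0.0, n_iter: int = 50
) -> Tuple[float, float]:
    """Return (MAP weight, Laplace posterior std) for y ~ Bernoulli(sigmoid(w * x))."""
    prior_prec = math.exp(-prior_log_var)  # prior precision = 1 / prior variance
    w = 0.0
    for _ in range(n_iter):
        # Gradient and Hessian of the negative log-posterior (likelihood + Gaussian prior).
        grad = sum((sigmoid(w * x) - y) * x for x, y in zip(xs, ys)) + prior_prec * w
        hess = sum(sigmoid(w * x) * (1 - sigmoid(w * x)) * x * x for x in xs) + prior_prec
        w -= grad / hess  # Newton step on a strictly convex objective
    # Curvature of the negative log-likelihood at the MAP estimate.
    hess_lik = sum(sigmoid(w * x) * (1 - sigmoid(w * x)) * x * x for x in xs)
    # Gaussian approximation: precision = likelihood curvature + prior precision.
    std = (hess_lik + prior_prec) ** -0.5
    return w, std


xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
ys = [0, 0, 1, 0, 1, 1]
w_map, w_std = laplace_logistic_1d(xs, ys, prior_log_var=0.0)
```

A smaller `prior_log_var` raises the prior precision and therefore shrinks the posterior standard deviation, which is why the prior log-variance is worth tuning.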
- prior_log_variance_tuning(val_data_loader, n_posterior_samples=10, mode='cv', min_prior_log_var=-3, max_prior_log_var=3, grid_size=20, distribute=False)
- Return type:
Array
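The arguments `min_prior_log_var`, `max_prior_log_var` and `grid_size` describe a one-dimensional grid search over the prior log-variance. The sketch below assumes an evenly spaced grid with both endpoints included and picks the value with the highest validation score. `validation_score` is a hypothetical stand-in for whatever criterion Fortuna evaluates on `val_data_loader` (possibly using `n_posterior_samples` posterior samples); the `mode` and `distribute` options are not modelled.

```python
from typing import Callable, List, Tuple


def make_grid(lo: float, hi: float, size: int) -> List[float]:
    """Evenly spaced grid from lo to hi, both endpoints included."""
    if size < 2:
        return [lo]
    step = (hi - lo) / (size - 1)
    return [lo + i * step for i in range(size)]


def tune_prior_log_var(
    validation_score: Callable[[float], float],
    min_prior_log_var: float = -3.0,
    max_prior_log_var: float = 3.0,
    grid_size: int = 20,
) -> Tuple[float, float]:
    """Return (best prior log-variance, its score) by exhaustive grid search."""
    grid = make_grid(min_prior_log_var, max_prior_log_var, grid_size)
    scores = [validation_score(v) for v in grid]
    best = max(range(len(grid)), key=lambda i: scores[i])
    return grid[best], scores[best]


# Toy validation score that peaks at a prior log-variance of 1.0.
best_log_var, best_score = tune_prior_log_var(lambda v: -(v - 1.0) ** 2)
```

With the defaults, the grid spacing is 6/19, so the selected value is only guaranteed to lie within half a spacing of the true optimum; a larger `grid_size` trades compute for resolution.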
- class fortuna.prob_model.posterior.laplace.laplace_state.LaplaceState(step, apply_fn, params, tx, opt_state, encoded_name=(76, 97, 112, 108, 97, 99, 101, 83, 116, 97, 116, 101), frozen_params=None, dynamic_scale=None, mutable=None, calib_params=None, calib_mutable=None, grad_accumulated=None, prior_log_var=0.0, _encoded_which_params=None)
Bases: PosteriorState
- prior_log_var
Prior log-variance value.
- Type:
float
- encoded_name
Laplace state name encoded as an array.
- Type:
jnp.ndarray
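The default `encoded_name` in the `LaplaceState` signature is a tuple of Unicode code points. Decoding it with plain Python, as below, shows that it spells the class name.

```python
# Default value of `encoded_name`, copied from the LaplaceState signature.
encoded_name = (76, 97, 112, 108, 97, 99, 101, 83, 116, 97, 116, 101)

# Each integer is a Unicode code point; chr() maps it back to a character.
decoded = "".join(chr(c) for c in encoded_name)
print(decoded)  # LaplaceState

# ord() is the inverse mapping and reproduces the tuple.
assert tuple(ord(ch) for ch in "LaplaceState") == encoded_name
```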
- classmethod convert_from_map_state(map_state, hess_lik_diag, prior_log_var, which_params)
Convert a MAP state into a Laplace state.
- Parameters:
map_state (MAPState) – A MAP state.
hess_lik_diag (Union[Params, Tuple[Params, ...]]) – Diagonal of the approximated Hessian of the likelihood.
prior_log_var (float) – Prior log-variance value. If None, initialize it to 100.
which_params (Tuple[List[AnyKey], ...]) – Sequences of keys pointing to the parameters over which the standard deviation (std) is defined. If which_params is None, the std must be defined for all parameters.
- Returns:
A Laplace state instance.
- Return type:
LaplaceState
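A plausible reading of the conversion, sketched under assumptions: `hess_lik_diag` holds the (non-negative) curvature of the negative log-likelihood, the prior precision is `exp(-prior_log_var)`, and each selected parameter gets `std = (hess_lik_diag + exp(-prior_log_var)) ** -0.5`, with the MAP values serving as the posterior mean. Nested dictionaries stand in for Fortuna's parameter pytrees, and `laplace_std` with its private helpers is hypothetical, not part of Fortuna.

```python
import math
from typing import Any, Callable, Dict, Optional, Sequence

Params = Dict[str, Any]  # nested dict of floats standing in for a parameter pytree


def _map_leaves(fn: Callable[[float], float], tree: Any) -> Any:
    """Apply fn to every non-dict leaf of a nested dict."""
    if isinstance(tree, dict):
        return {k: _map_leaves(fn, v) for k, v in tree.items()}
    return fn(tree)


def _get(tree: Params, path: Sequence[str]) -> Any:
    for key in path:
        tree = tree[key]
    return tree


def _set(tree: Params, path: Sequence[str], value: Any) -> None:
    for key in path[:-1]:
        tree = tree.setdefault(key, {})
    tree[path[-1]] = value


def laplace_std(
    hess_lik_diag: Params,
    prior_log_var: float,
    which_params: Optional[Sequence[Sequence[str]]] = None,
) -> Params:
    """Per-parameter std = (h + exp(-prior_log_var)) ** -0.5 over the selected sub-trees."""
    prior_prec = math.exp(-prior_log_var)

    def to_std(h: float) -> float:
        return (h + prior_prec) ** -0.5

    if which_params is None:
        # No selection: define the std for every parameter.
        return _map_leaves(to_std, hess_lik_diag)
    std: Params = {}
    for path in which_params:
        _set(std, path, _map_leaves(to_std, _get(hess_lik_diag, path)))
    return std


hess = {"dense1": {"kernel": 3.0, "bias": 0.0}, "dense2": {"kernel": 8.0}}
full = laplace_std(hess, prior_log_var=0.0)
last_layer = laplace_std(hess, prior_log_var=0.0, which_params=[("dense2",)])
```

Under this formula, a parameter with zero likelihood curvature (the `bias` above) falls back to the prior standard deviation `exp(prior_log_var / 2)`, while restricting `which_params` keeps the std tree limited to the selected layers.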
- classmethod init(params, mutable=None, optimizer=None, calib_params=None, calib_mutable=None, grad_accumulated=None, dynamic_scale=None, **kwargs)
Initialize a posterior distribution state.
- Parameters:
params (Params) – The parameters characterizing an approximation of the posterior distribution.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer associated with the posterior state.
mutable (Optional[Mutable]) – The mutable objects characterizing an approximation of the posterior distribution.
calib_params (Optional[CalibParams]) – The calibration parameters characterizing an approximation of the posterior distribution.
calib_mutable (Optional[CalibMutable]) – The calibration mutable objects characterizing an approximation of the posterior distribution.
grad_accumulated (Optional[jnp.ndarray]) – The gradients accumulated in consecutive training steps (used only when gradient_accumulation_steps > 1).
dynamic_scale (Optional[dynamic_scale.DynamicScale]) – Dynamic loss scaling for mixed precision gradients.
- Returns:
A posterior distribution state.
- Return type:
Any
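`grad_accumulated` stores gradients gathered over consecutive training steps when `gradient_accumulation_steps > 1`. The generic pattern is sketched below with plain Python lists rather than Fortuna's internals: each mini-batch gradient is added to a buffer, and one update is applied every `gradient_accumulation_steps` steps. Averaging (rather than summing) the buffer and using plain SGD with learning rate 0.1 are assumptions of this sketch; `GradAccumulator` is a hypothetical class.

```python
from typing import List, Optional


class GradAccumulator:
    """Accumulate per-step gradients and release their mean every k steps (hypothetical)."""

    def __init__(self, gradient_accumulation_steps: int) -> None:
        if gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")
        self.k = gradient_accumulation_steps
        self.buffer: Optional[List[float]] = None
        self.count = 0

    def add(self, grad: List[float]) -> Optional[List[float]]:
        """Add one gradient; return the averaged gradient when an update is due, else None."""
        if self.buffer is None:
            self.buffer = [0.0] * len(grad)
        self.buffer = [b + g for b, g in zip(self.buffer, grad)]
        self.count += 1
        if self.count < self.k:
            return None
        mean = [b / self.k for b in self.buffer]
        self.buffer, self.count = None, 0  # reset for the next accumulation window
        return mean


params = [1.0, 1.0]
acc = GradAccumulator(gradient_accumulation_steps=2)
for grad in ([0.2, 0.4], [0.4, 0.0], [1.0, 1.0], [1.0, 1.0]):
    update = acc.add(grad)
    if update is not None:
        # Plain SGD step, applied once per accumulation window.
        params = [p - 0.1 * u for p, u in zip(params, update)]
```

Accumulating over k steps gives an update based on k mini-batches without holding them in memory at once, at the cost of k forward and backward passes per parameter update.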
- classmethod init_from_dict(d, optimizer=None, **kwargs)
Initialize a posterior distribution state from a dictionary.
- Parameters:
d (Dict) – A dictionary including attributes of the posterior state.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer to assign to the posterior state.
- Returns:
A posterior state.
- Return type: