Stochastic Gradient Markov Chain Monte Carlo (SG-MCMC)#
SG-MCMC procedures approximate the posterior with the stationary distribution of a Markov chain Monte Carlo sampler that uses noisy estimates of the gradient computed on minibatches of data.
Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)#
SGHMC [Chen T. et al., 2014] is a popular MCMC algorithm that uses stochastic gradient estimates to scale to large datasets.
- class fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator.SGHMCPosteriorApproximator(n_samples=10, n_thinning=1, burnin_length=1000, momentum_decay=0.01, step_schedule=1e-5, preconditioner=identity_preconditioner())[source]#
SGHMC posterior approximator. It defines how the posterior distribution is approximated.
The total number of available posterior samples depends on the number of training steps, burnin_length, and n_thinning parameters:
n_available_samples = (n_training_steps - burnin_length) // n_thinning
Setting the desired number of samples n_samples larger than n_available_samples will result in an exception.
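As a quick illustration of the rule above, the arithmetic can be sketched in plain Python; the helper name is illustrative and not part of Fortuna's API, and note that thinning is applied with integer division:

```python
def n_available_samples(n_training_steps, burnin_length, n_thinning):
    """Posterior samples available after burn-in and thinning (illustrative)."""
    return (n_training_steps - burnin_length) // n_thinning
```

For example, 2,000 training steps with a 1,000-step burn-in and n_thinning=10 leave 100 available samples, so requesting n_samples=10 (the default) is safe, while n_samples=200 would raise an exception.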
- Parameters:
n_samples (int) – The desired number of posterior samples.
n_thinning (int) – If n_thinning > 1, keep only every n_thinning-th sample during the sampling phase.
burnin_length (int) – Length of the initial burn-in phase, in steps.
momentum_decay (float) – The “friction” term that counteracts the noise of stochastic gradient estimates. Setting this argument to zero recovers the overdamped Langevin dynamics.
step_schedule (Union[StepSchedule, float]) – Either a constant float step size or a schedule function.
preconditioner (Preconditioner) – A Preconditioner instance that preconditions the approximator with information about the posterior distribution, if available.
- property posterior_method_kwargs: Dict[str, Any]#
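For intuition, the SGHMC transition of [Chen T. et al., 2014] can be sketched for a single scalar parameter as below. This is a minimal illustration of the update equations under a hypothetical scalar interface, not Fortuna's internal implementation; the friction term momentum_decay damps the momentum and compensates for the injected noise.

```python
import math
import random

def sghmc_step(theta, momentum, grad, step_size, momentum_decay, rng):
    """One SGHMC update: friction (momentum_decay) counteracts the extra
    noise introduced by minibatch gradient estimates."""
    # Injected Gaussian noise with variance 2 * momentum_decay * step_size.
    noise = rng.gauss(0.0, math.sqrt(2.0 * momentum_decay * step_size))
    momentum = (1.0 - momentum_decay) * momentum - step_size * grad + noise
    theta = theta + momentum
    return theta, momentum
```

Iterating this update on, e.g., a quadratic potential keeps the chain fluctuating around the mode, which is where posterior samples are then collected.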
- class fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior.SGHMCPosterior(joint, posterior_approximator)[source]#
Bases:
SGMCMCPosterior
Stochastic Gradient Hamiltonian Monte Carlo approximate posterior class.
- Parameters:
joint (Joint) – A Joint distribution object.
posterior_approximator (SGHMCPosteriorApproximator) – An SGHMC posterior approximator.
- fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), map_fit_config=None, **kwargs)[source]#
Fit the posterior distribution. A posterior state will be internally stored.
- Parameters:
train_data_loader (DataLoader) – Training data loader.
val_data_loader (Optional[DataLoader]) – Validation data loader.
fit_config (FitConfig) – A configuration object.
- Returns:
A status including metrics describing the fitting process.
- Return type:
Status
- class fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state.SGHMCState(step, apply_fn, params, tx, opt_state, encoded_name=(83, 71, 72, 77, 67, 83, 116, 97, 116, 101), frozen_params=None, dynamic_scale=None, mutable=None, calib_params=None, calib_mutable=None, grad_accumulated=None, _encoded_which_params=None)[source]#
Bases:
PosteriorState
- encoded_name#
SGHMC state name encoded as an array.
- Type:
jnp.ndarray
- classmethod convert_from_map_state(map_state, optimizer, which_params)[source]#
Convert a MAP state into an SGHMC state.
- Parameters:
map_state (MAPState) – A MAP posterior state.
optimizer (OptaxOptimizer) – An Optax optimizer.
which_params (Tuple[List[AnyKey], ...]) – Sequences of keys pointing to the stochastic parameters.
- Returns:
An SGHMC state.
- Return type:
SGHMCState
- classmethod init(params, mutable=None, optimizer=None, calib_params=None, calib_mutable=None, grad_accumulated=None, dynamic_scale=None, **kwargs)#
Initialize a posterior distribution state.
- Parameters:
params (Params) – The parameters characterizing an approximation of the posterior distribution.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer associated with the posterior state.
mutable (Optional[Mutable]) – The mutable objects characterizing an approximation of the posterior distribution.
calib_params (Optional[CalibParams]) – The calibration parameters characterizing an approximation of the posterior distribution.
calib_mutable (Optional[CalibMutable]) – The calibration mutable objects characterizing an approximation of the posterior distribution.
grad_accumulated (Optional[jnp.ndarray]) – The gradients accumulated in consecutive training steps (used only when gradient_accumulation_steps > 1).
dynamic_scale (Optional[dynamic_scale.DynamicScale]) – Dynamic loss scaling for mixed precision gradients.
- Returns:
A posterior distribution state.
- Return type:
Any
- classmethod init_from_dict(d, optimizer=None, **kwargs)#
Initialize a posterior distribution state from a dictionary.
- Parameters:
d (Dict) – A dictionary including attributes of the posterior state.
optimizer (Optional[OptaxOptimizer]) – An optax optimizer to assign to the posterior state.
- Returns:
A posterior state.
- Return type:
Any
Cyclical Stochastic Gradient Langevin Dynamics (CyclicalSGLD)#
The cyclical SGLD method [Zhang R. et al., 2019] is a simple and automatic procedure that adopts a cyclical cosine step-size schedule and alternates between exploration and sampling stages to better explore the multimodal posteriors of deep neural networks.
- class fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator.CyclicalSGLDPosteriorApproximator(n_samples=10, n_thinning=1, cycle_length=1000, init_step_size=1e-5, exploration_ratio=0.25, preconditioner=identity_preconditioner())[source]#
Cyclical SGLD posterior approximator. It defines how the posterior distribution is approximated.
The total number of available posterior samples depends on the number of training steps, cycle_length, exploration_ratio, and n_thinning parameters. If the number of training steps divides evenly by the cycle length, it can be calculated as follows:
n_cycles = n_training_steps // cycle_length
n_sampling_steps = (n_cycles * cycle_length) * (1 - exploration_ratio)
n_available_samples = n_sampling_steps // n_thinning
Setting the desired number of samples n_samples larger than n_available_samples will result in an exception.
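The cyclical sample-count arithmetic can be sketched in plain Python; the helper name is illustrative and not part of Fortuna's API:

```python
def cyclical_n_available_samples(
    n_training_steps, cycle_length, exploration_ratio, n_thinning
):
    """Posterior samples available across full cycles (illustrative)."""
    n_cycles = n_training_steps // cycle_length
    # Only the sampling portion of each cycle yields posterior samples.
    n_sampling_steps = int(n_cycles * cycle_length * (1 - exploration_ratio))
    return n_sampling_steps // n_thinning
```

For example, 10,000 training steps with cycle_length=1000 and exploration_ratio=0.25 yield 7,500 sampling steps, or 750 available samples with n_thinning=10.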
- Parameters:
n_samples (int) – The desired number of posterior samples.
n_thinning (int) – If n_thinning > 1, keep only every n_thinning-th sample during the sampling phase.
cycle_length (int) – The length of each exploration/sampling cycle, in steps.
init_step_size (float) – The initial step size.
exploration_ratio (float) – The fraction of steps to allocate to the mode exploration phase.
preconditioner (Preconditioner) – A Preconditioner instance that preconditions the approximator with information about the posterior distribution, if available.
- property posterior_method_kwargs: Dict[str, Any]#
- class fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior.CyclicalSGLDPosterior(joint, posterior_approximator)[source]#
Bases:
SGMCMCPosterior
Cyclical Stochastic Gradient Langevin Dynamics (SGLD) approximate posterior class.
- Parameters:
joint (Joint) – A Joint distribution object.
posterior_approximator (CyclicalSGLDPosteriorApproximator) – A cyclical SGLD posterior approximator.
- fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), map_fit_config=None, **kwargs)[source]#
Fit the posterior distribution. A posterior state will be internally stored.
- Parameters:
train_data_loader (DataLoader) – Training data loader.
val_data_loader (Optional[DataLoader]) – Validation data loader.
fit_config (FitConfig) – A configuration object.
- Returns:
A status including metrics describing the fitting process.
- Return type:
Status
- class fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_state.CyclicalSGLDState(step, apply_fn, params, tx, opt_state, encoded_name=(67, 121, 99, 108, 105, 99, 97, 108, 83, 71, 76, 68, 83, 116, 97, 116, 101), frozen_params=None, dynamic_scale=None, mutable=None, calib_params=None, calib_mutable=None, grad_accumulated=None, _encoded_which_params=None)[source]#
Bases:
PosteriorState
- encoded_name#
Cyclical SGLD state name encoded as an array.
- Type:
jnp.ndarray
- classmethod convert_from_map_state(map_state, optimizer, which_params)[source]#
Convert a MAP state into a Cyclical SGLD state.
- Parameters:
map_state (MAPState) – A MAP posterior state.
optimizer (OptaxOptimizer) – An Optax optimizer.
which_params (Tuple[List[AnyKey], ...]) – Sequences of keys pointing to the stochastic parameters.
- Returns:
A Cyclical SGLD state.
- Return type:
CyclicalSGLDState
- classmethod init(params, mutable=None, optimizer=None, calib_params=None, calib_mutable=None, grad_accumulated=None, dynamic_scale=None, **kwargs)#
Initialize a posterior distribution state.
- Parameters:
params (Params) – The parameters characterizing an approximation of the posterior distribution.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer associated with the posterior state.
mutable (Optional[Mutable]) – The mutable objects characterizing an approximation of the posterior distribution.
calib_params (Optional[CalibParams]) – The calibration parameters characterizing an approximation of the posterior distribution.
calib_mutable (Optional[CalibMutable]) – The calibration mutable objects characterizing an approximation of the posterior distribution.
grad_accumulated (Optional[jnp.ndarray]) – The gradients accumulated in consecutive training steps (used only when gradient_accumulation_steps > 1).
dynamic_scale (Optional[dynamic_scale.DynamicScale]) – Dynamic loss scaling for mixed precision gradients.
- Returns:
A posterior distribution state.
- Return type:
Any
- classmethod init_from_dict(d, optimizer=None, **kwargs)#
Initialize a posterior distribution state from a dictionary.
- Parameters:
d (Dict) – A dictionary including attributes of the posterior state.
optimizer (Optional[OptaxOptimizer]) – An optax optimizer to assign to the posterior state.
- Returns:
A posterior state.
- Return type:
Any
Step schedules#
Fortuna supports various step schedules for SG-MCMC algorithms. A StepSchedule is a function that takes the step count as input and returns a float step size.
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.constant_schedule(init_step_size)[source]#
Create a constant step schedule.
- Parameters:
init_step_size (float) – The step size.
- Returns:
schedule_fn
- Return type:
StepSchedule
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.constant_schedule_with_cosine_burnin(init_step_size, final_step_size, burnin_steps)[source]#
Create a constant schedule with cosine burn-in.
- Parameters:
init_step_size (float) – The initial step size.
final_step_size (float) – The desired final step size.
burnin_steps (int) – The length of burn-in, in steps.
- Returns:
schedule_fn
- Return type:
StepSchedule
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.cosine_schedule(init_step_size, total_steps)[source]#
Create a cosine step schedule.
- Parameters:
init_step_size (float) – The initial step size.
total_steps (int) – The cycle length, in steps.
- Returns:
schedule_fn
- Return type:
StepSchedule
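A common formulation of the cosine schedule, decaying from init_step_size to zero over total_steps, can be sketched as follows; Fortuna's exact implementation may differ in details such as cycling behavior:

```python
import math

def cosine_schedule(init_step_size, total_steps):
    def schedule_fn(step):
        # Fraction of the cycle completed, clamped to [0, 1].
        t = min(step, total_steps) / total_steps
        return 0.5 * init_step_size * (1.0 + math.cos(math.pi * t))
    return schedule_fn
```

The step size starts at init_step_size, reaches half that value at the midpoint, and decays to zero at total_steps.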
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.cyclical_cosine_schedule_with_const_burnin(init_step_size, burnin_steps, cycle_length)[source]#
Create a cyclical cosine schedule with constant burn-in.
- Parameters:
init_step_size (float) – The initial step size.
burnin_steps (int) – The length of burn-in, in steps.
cycle_length (int) – The length of the cosine cycle, in steps.
- Returns:
schedule_fn
- Return type:
StepSchedule
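A sketch of this schedule under one plausible formulation: the step size stays constant for burnin_steps, then follows a cosine decay that restarts every cycle_length steps. Fortuna's exact implementation may differ in details:

```python
import math

def cyclical_cosine_schedule_with_const_burnin(
    init_step_size, burnin_steps, cycle_length
):
    def schedule_fn(step):
        if step < burnin_steps:
            return init_step_size  # constant during burn-in
        t = (step - burnin_steps) % cycle_length  # position within the cycle
        return 0.5 * init_step_size * (1.0 + math.cos(math.pi * t / cycle_length))
    return schedule_fn
```

Each cycle restarts at init_step_size, matching the exploration/sampling alternation used by cyclical SGLD.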
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.polynomial_schedule(a=1.0, b=1.0, gamma=0.55)[source]#
Create a polynomial step schedule.
- Parameters:
a (float) – Scale of all step sizes.
b (float) – The stabilization constant.
gamma (float) – The decay rate \(\gamma \in (0.5, 1.0]\).
- Returns:
schedule_fn
- Return type:
StepSchedule
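A plausible reading of these parameters is the classic polynomial decay \(\epsilon_k = a(b + k)^{-\gamma}\) used in stochastic gradient Langevin methods; a sketch (the exact parameterization in Fortuna may differ):

```python
def polynomial_schedule(a=1.0, b=1.0, gamma=0.55):
    def schedule_fn(step):
        # a scales all step sizes; b stabilizes early steps; gamma sets decay.
        return a * (b + step) ** -gamma
    return schedule_fn
```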
Preconditioners#
Fortuna provides implementations of preconditioners to improve the efficacy of samplers.
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner.identity_preconditioner()[source]#
Create an instance of the no-op identity preconditioner.
- Returns:
preconditioner – An instance of identity preconditioner.
- Return type:
Preconditioner
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner.rmsprop_preconditioner(running_average_factor=0.99, eps=1.0e-7)[source]#
Create an instance of the adaptive RMSProp preconditioner.
- Parameters:
running_average_factor (float) – The decay factor for the squared gradients moving average.
eps (float) – \(\epsilon\) constant for numerical stability.
- Returns:
preconditioner – An instance of the RMSProp preconditioner.
- Return type:
Preconditioner
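The core RMSProp preconditioning logic, sketched for a single scalar gradient; the names are illustrative, and Fortuna's implementation operates on parameter PyTrees inside the sampler:

```python
import math

def rmsprop_precondition(v, grad, running_average_factor=0.99, eps=1e-7):
    """Update the squared-gradient moving average and rescale the gradient."""
    v = running_average_factor * v + (1.0 - running_average_factor) * grad ** 2
    # eps guards against division by zero for small moving averages.
    return v, grad / (math.sqrt(v) + eps)
```

Dividing by the root of the squared-gradient average equalizes the effective step size across parameters with very different gradient scales.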
Diagnostics#
The library includes tooling for diagnosing the convergence of SG-MCMC sampling procedures.
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic.effective_sample_size(samples, filter_threshold=0.0)[source]#
Estimate the effective sample size of a sequence.
For a sequence of length \(N\), the effective sample size is defined as
\(ESS(N) = N \Big/ \left[1 + 2\left(\frac{N-1}{N} R_1 + \dots + \frac{1}{N} R_{N-1}\right)\right]\)
where \(R_k\) is the auto-correlation sequence, \(R_k := \mathrm{Cov}(X_1, X_{1+k}) / \mathrm{Var}(X_1)\).
- Parameters:
samples (List[PyTree]) – A list of PyTrees, each representing an MCMC sample.
filter_threshold (Optional[float]) – The cut-off value to truncate the sequence at the first index where the estimated auto-correlation is less than the threshold.
- Return type:
Any
- Returns:
- ESS: PyTree
Parameter-wise estimates of the effective sample size.
- fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic.kernel_stein_discrepancy_imq(samples, grads, c=1.0, beta=-0.5)[source]#
Kernel Stein Discrepancy with the Inverse Multiquadric (IMQ) kernel.
See Gorham J. and Mackey L., 2017 for more details.
- Parameters:
samples (List[PyTree]) – A list of PyTrees, each representing an MCMC sample.
grads (List[PyTree]) – A list of the corresponding log-density gradients.
c (float) – \(c > 0\) kernel bias hyperparameter.
beta (float) – \(\beta < 0\) kernel exponent hyperparameter.
- Return type:
float
- Returns:
- ksd_imq: float
The kernel Stein discrepancy value.
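For intuition, the kernel Stein discrepancy with the IMQ base kernel \(k(x, y) = (c^2 + \lVert x - y \rVert^2)^{\beta}\) can be sketched in one dimension as below. This is a hypothetical scalar re-derivation, not Fortuna's implementation (which operates on lists of PyTrees); grads holds the log-density gradients (scores) at the samples:

```python
import math

def ksd_imq_1d(samples, grads, c=1.0, beta=-0.5):
    """V-statistic estimate of the kernel Stein discrepancy (1-D sketch)."""
    n = len(samples)
    total = 0.0
    for x, sx in zip(samples, grads):
        for y, sy in zip(samples, grads):
            d = x - y
            u = c ** 2 + d ** 2
            k = u ** beta
            dk_dx = 2.0 * beta * d * u ** (beta - 1.0)
            dk_dy = -dk_dx
            d2k = (
                -2.0 * beta * u ** (beta - 1.0)
                - 4.0 * beta * (beta - 1.0) * d ** 2 * u ** (beta - 2.0)
            )
            # Stein kernel combining the base kernel and the scores.
            total += d2k + dk_dx * sy + dk_dy * sx + k * sx * sy
    # Clamp against tiny negative values from floating-point error.
    return math.sqrt(max(total, 0.0) / n ** 2)
```

Smaller values indicate that the samples are more consistent with the density whose score was supplied.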