SWAG

class fortuna.prob_model.posterior.swag.swag_approximator.SWAGPosteriorApproximator(rank=5)

SWAG posterior approximator. It defines how the posterior distribution is approximated.

Parameters:

rank (int) – SWAG approximates the posterior with a Gaussian distribution whose covariance matrix combines a diagonal matrix and a low-rank empirical approximation. This argument sets the rank of the low-rank empirical covariance approximation and must be at least 2.
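
The original SWAG paper (Maddox et al., 2019) forms the covariance as Σ = ½ (diag(σ²) + D Dᵀ / (K − 1)), where σ holds the diagonal standard deviations and the d × K matrix D holds the K most recent deviation vectors. The NumPy sketch below builds that matrix densely for a toy parameter vector. It assumes that `rank` plays the role of K, the helper `swag_covariance` is hypothetical, and it is not Fortuna's implementation (a dense d × d matrix is only practical for tiny models). Note that the 1 / (K − 1) factor is undefined for K = 1, which is consistent with the requirement that rank be at least 2.

```python
import numpy as np


def swag_covariance(std: np.ndarray, dev: np.ndarray) -> np.ndarray:
    """Hypothetical helper: dense SWAG covariance 0.5 * (diag(std**2) + dev @ dev.T / (K - 1)).

    std: shape (d,), diagonal standard deviations.
    dev: shape (d, K), deviation columns; K is assumed to equal `rank`.
    """
    _, k = dev.shape
    if k < 2:
        # The low-rank term is normalized by K - 1, so at least two columns are needed.
        raise ValueError("rank must be at least 2")
    diag_part = np.diag(std ** 2)
    low_rank_part = dev @ dev.T / (k - 1)
    return 0.5 * (diag_part + low_rank_part)


rng = np.random.default_rng(0)
std = np.abs(rng.normal(size=4)) + 0.1  # strictly positive diagonal
dev = rng.normal(size=(4, 2))           # K = 2 deviation columns
cov = swag_covariance(std, dev)
print(cov.shape)                # (4, 4)
print(np.allclose(cov, cov.T))  # True
```

The low-rank part D Dᵀ has rank at most K, so the diagonal term is what keeps the covariance full rank.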

property posterior_method_kwargs: Dict[str, Any]
class fortuna.prob_model.posterior.swag.swag_posterior.SWAGPosterior(joint, posterior_approximator)

Bases: Posterior

SWAG approximate posterior class.

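Parameters:
  • joint (Joint) – A joint distribution object.

  • posterior_approximator (SWAGPosteriorApproximator) – A SWAG posterior approximator.
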
fit(train_data_loader, val_data_loader=None, fit_config=FitConfig(), map_fit_config=None, **kwargs)

Fit the posterior distribution. The resulting posterior state is stored internally.

Parameters:
  • train_data_loader (DataLoader) – Training data loader.

  • val_data_loader (Optional[DataLoader]) – Validation data loader.

  • fit_config (FitConfig) – A configuration object.

Returns:

A status including metrics describing the fitting process.

Return type:

Status
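
For intuition about the statistics a SWAG fit collects, the sketch below follows the moment-collection scheme of the original SWAG paper: running first and second moments of parameter snapshots (for example, SGD iterates taken once per epoch), plus a buffer of the K most recent deviations from the running mean. The helper `collect_swag_moments` is hypothetical and works on flat NumPy vectors; Fortuna's fit method operates on JAX arrays and may differ in when and how often snapshots are taken.

```python
import numpy as np


def collect_swag_moments(snapshots, rank):
    """Hypothetical helper: accumulate SWAG statistics from flat parameter snapshots.

    Returns (mean, std, dev) with shapes (d,), (d,) and (d, K), where K <= rank.
    """
    mean = sq_mean = None
    deviations = []
    for n, theta in enumerate(snapshots):
        theta = np.asarray(theta, dtype=float)
        if mean is None:
            mean = np.zeros_like(theta)
            sq_mean = np.zeros_like(theta)
        # Running first and second moments over the n + 1 snapshots seen so far.
        mean = (n * mean + theta) / (n + 1)
        sq_mean = (n * sq_mean + theta ** 2) / (n + 1)
        # Keep only the `rank` most recent deviations from the running mean.
        deviations.append(theta - mean)
        if len(deviations) > rank:
            deviations.pop(0)
    if mean is None:
        raise ValueError("at least one snapshot is required")
    # Clip tiny negative variances caused by floating-point round-off.
    var = np.clip(sq_mean - mean ** 2, 0.0, None)
    dev = np.stack(deviations, axis=1)
    return mean, np.sqrt(var), dev


rng = np.random.default_rng(0)
iterates = [np.array([1.0, 2.0]) + 0.1 * rng.normal(size=2) for _ in range(10)]
mean, std, dev = collect_swag_moments(iterates, rank=5)
print(mean.shape, std.shape, dev.shape)  # (2,) (2,) (2, 5)
```

The three returned arrays correspond conceptually to the mean, std and dev attributes of SWAGState.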

sample(rng=None, inputs_loader=None, inputs=None, **kwargs)

Sample from the posterior distribution.

Parameters:
  • rng (Optional[PRNGKeyArray]) – A random number generator key. If not passed, it is taken from the attributes of this class.

  • inputs_loader (Optional[InputsLoader]) – Input data loader. Either this or inputs is required if the posterior state includes mutable objects.

  • inputs (Optional[Array]) – Input variables. Either this or inputs_loader is required if the posterior state includes mutable objects.

Returns:

A sample from the posterior distribution.

Return type:

JointState
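
In the original SWAG paper, a sample is drawn as θ = θ̄ + σ ⊙ z₁ / √2 + D z₂ / √(2 (K − 1)), with z₁ ~ N(0, I_d) and z₂ ~ N(0, I_K). This matches the covariance ½ (diag(σ²) + D Dᵀ / (K − 1)) without ever forming the d × d matrix. The sketch below is a hypothetical `sample_swag` helper that uses a NumPy Generator instead of a JAX PRNG key and returns flat parameter vectors, whereas Fortuna's sample method returns a JointState.

```python
import numpy as np


def sample_swag(mean, std, dev, rng, num_samples=1):
    """Hypothetical helper: draw flat parameter vectors from the SWAG Gaussian.

    mean, std: shape (d,); dev: shape (d, K); rng: numpy.random.Generator.
    Returns an array of shape (num_samples, d).
    """
    d, k = dev.shape
    z1 = rng.standard_normal((num_samples, d))  # drives the diagonal term
    z2 = rng.standard_normal((num_samples, k))  # drives the low-rank term
    diag_term = std * z1 / np.sqrt(2.0)
    low_rank_term = z2 @ dev.T / np.sqrt(2.0 * (k - 1))
    return mean + diag_term + low_rank_term


rng = np.random.default_rng(0)
mean = np.array([0.0, 1.0, -1.0])
std = np.array([0.5, 1.0, 1.5])
dev = np.array([[0.3, -0.2], [0.1, 0.4], [-0.5, 0.2]])
samples = sample_swag(mean, std, dev, rng, num_samples=5)
print(samples.shape)  # (5, 3)
```

The cost per sample is O(d K) rather than the O(d²) a dense covariance factorization would need.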

class fortuna.prob_model.posterior.swag.swag_state.SWAGState(step, apply_fn, params, tx, opt_state, encoded_name=(83, 87, 65, 71, 83, 116, 97, 116, 101), frozen_params=None, dynamic_scale=None, mutable=None, calib_params=None, calib_mutable=None, grad_accumulated=None, mean=None, std=None, dev=None, _encoded_which_params=None)

Bases: PosteriorState

encoded_name

SWAG state name encoded as an array.

Type:

jnp.ndarray
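
The default value of encoded_name in the signature, (83, 87, 65, 71, 83, 116, 97, 116, 101), is the sequence of ASCII code points of the string "SWAGState". The short snippet below shows the round trip in plain Python; it only illustrates the encoding and does not call any Fortuna code.

```python
# Default value of encoded_name taken from the SWAGState signature.
encoded_name = (83, 87, 65, 71, 83, 116, 97, 116, 101)

# Decode the code points back into a string.
name = "".join(chr(code) for code in encoded_name)
print(name)  # SWAGState

# Encoding goes the other way with ord().
assert tuple(ord(ch) for ch in name) == encoded_name
```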

mean

Mean of the posterior approximation.

Type:

Optional[jnp.ndarray]

std

Diagonal standard deviation of the posterior approximation.

Type:

Optional[jnp.ndarray]

dev

Deviation term of the covariance matrix of the posterior approximation.

Type:

Optional[jnp.ndarray]

classmethod convert_from_map_state(map_state, optimizer)

Convert a MAP state into a SWAG state.

Parameters:
  • map_state (MAPState) – A MAP posterior state.

  • optimizer (OptaxOptimizer) – An Optax optimizer.

Returns:

A SWAG state.

Return type:

SWAGState

classmethod init(params, mutable=None, optimizer=None, calib_params=None, calib_mutable=None, grad_accumulated=None, dynamic_scale=None, **kwargs)

Initialize a posterior distribution state.

Parameters:
  • params (Params) – The parameters characterizing an approximation of the posterior distribution.

  • optimizer (Optional[OptaxOptimizer]) – An Optax optimizer associated with the posterior state.

  • mutable (Optional[Mutable]) – The mutable objects characterizing an approximation of the posterior distribution.

  • calib_params (Optional[CalibParams]) – The calibration parameters characterizing an approximation of the posterior distribution.

  • calib_mutable (Optional[CalibMutable]) – The calibration mutable objects characterizing an approximation of the posterior distribution.

  • grad_accumulated (Optional[jnp.ndarray]) – The gradients accumulated in consecutive training steps (used only when gradient_accumulation_steps > 1).

  • dynamic_scale (Optional[dynamic_scale.DynamicScale]) – Dynamic loss scaling for mixed precision gradients.

Returns:

A posterior distribution state.

Return type:

Any

classmethod init_from_dict(d, optimizer=None, **kwargs)

Initialize a posterior distribution state from a dictionary.

Parameters:
  • d (Dict) – A dictionary including attributes of the posterior state.

  • optimizer (Optional[OptaxOptimizer]) – An Optax optimizer to assign to the posterior state.

Returns:

A posterior state.

Return type:

PosteriorState

update(variables)

Update the SWAG state.

Parameters:

variables (Dict[str, Any]) – The attributes to update and their values.

Returns:

Updated SWAG state.

Return type:

SWAGState
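
The entry above only states that update returns an updated SWAG state. The sketch below assumes it follows the immutable "replace fields and return a new object" pattern common to Flax-style training states, and illustrates that pattern with a hypothetical frozen dataclass `ToySWAGState` holding only the SWAG-specific fields. It is not Fortuna's class, and the real method may differ.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional

import numpy as np


@dataclass(frozen=True)
class ToySWAGState:
    """Hypothetical stand-in holding only the SWAG-specific fields."""

    mean: Optional[np.ndarray] = None
    std: Optional[np.ndarray] = None
    dev: Optional[np.ndarray] = None

    def update(self, variables: Dict[str, Any]) -> "ToySWAGState":
        # Return a new state with the given fields replaced; self is left unchanged.
        return replace(self, **variables)


state = ToySWAGState()
new_state = state.update({"mean": np.zeros(3), "std": np.ones(3)})
print(state.mean is None)  # True
print(new_state.std)       # [1. 1. 1.]
```

Under this assumption, passing a key that is not a field of the state raises a TypeError rather than silently adding a new attribute.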