Predictive distribution

The predictive distribution is the component of the probabilistic model responsible for computing predictive statistics. We provide a classification predictive for classification tasks and a regression predictive for regression tasks; their references follow below.

class fortuna.prob_model.predictive.classification.ClassificationPredictive(posterior)[source]

Classification predictive distribution class.

Parameters:

posterior (Posterior) – A posterior distribution object.
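
In practice, a ClassificationPredictive is usually obtained from a trained probabilistic classifier rather than constructed by hand. The sketch below is illustrative only: it assumes Fortuna's quickstart-style entry points (ProbClassifier, MLP, DataLoader, InputsLoader) and synthetic data; these names are assumptions beyond this page. Later classification sketches on this page reuse the predictive, train_data_loader and test_inputs_loader defined here.

    # Illustrative setup (assumed quickstart-style entry points; not authoritative).
    import numpy as np
    from fortuna.data import DataLoader, InputsLoader
    from fortuna.prob_model import ProbClassifier
    from fortuna.model import MLP

    # Toy 3-class problem on 4 features.
    x_train = np.random.normal(size=(512, 4)).astype("float32")
    y_train = np.random.randint(0, 3, size=(512,))

    train_data_loader = DataLoader.from_array_data(
        (x_train, y_train), batch_size=128, shuffle=True
    )
    test_inputs_loader = InputsLoader.from_array_inputs(x_train[:64], batch_size=32)

    # Training fits an (approximate) posterior distribution.
    prob_model = ProbClassifier(model=MLP(output_dim=3))
    status = prob_model.train(train_data_loader=train_data_loader)

    # The object documented here; equivalent to ClassificationPredictive(prob_model.posterior).
    predictive = prob_model.predictive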

aleatoric_entropy(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)[source]

Estimate the predictive aleatoric entropy, that is

\[-\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive aleatoric entropy for each input.

Return type:

jnp.ndarray

aleatoric_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)[source]

Estimate the predictive aleatoric variance of the one-hot encoded target variable, that is

\[\text{Var}_{W|\mathcal{D}}[\mathbb{E}_{\tilde{Y}|W, x}[\tilde{Y}]],\]
where:
  • \(x\) is an observed input variable;

  • \(\tilde{Y}\) is a one-hot encoded random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive aleatoric variance for each input.

Return type:

jnp.ndarray

conformal_set(train_data_loader, test_inputs_loader, n_posterior_samples=30, error=0.05, rng=None, distribute=True, return_ess=False)[source]

Estimate conformal sets for the target variable.

Parameters:
  • train_data_loader (DataLoader) – A training data loader.

  • test_inputs_loader (InputsLoader) – A test inputs loader.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • error (float) – The set error. This must be a number between 0 and 1, extremes included. For example, error=0.05 corresponds to a 95% level of confidence.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

  • return_ess (bool) – Whether to compute effective sample size of importance weights or not.

Returns:

A list of conformal sets for each test input.

Return type:

List[List[int]]
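
A hedged usage sketch, continuing from the classification setup above; error=0.05 requests 95% coverage sets.

    # Continuing from the classification setup sketch above.
    conformal_sets = predictive.conformal_set(
        train_data_loader=train_data_loader,
        test_inputs_loader=test_inputs_loader,
        n_posterior_samples=30,
        error=0.05,  # 95% confidence level
    )
    # One list of candidate class indices per test input.
    print(len(conformal_sets), conformal_sets[0])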

credible_set(inputs_loader, n_posterior_samples=30, error=0.05, rng=None, distribute=True)[source]

Estimate credible sets for the target variable. This is done by sorting the class probabilities in descending order and including classes until their cumulative sum exceeds 1 - error.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of posterior samples to draw for each input.

  • error (float) – The set error. This must be a number between 0 and 1, extremes included. For example, error=0.05 corresponds to a 95% level of credibility.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

A credibility set for each of the inputs.

Return type:

jnp.ndarray
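
Unlike conformal_set above, this method needs only an inputs loader. A sketch continuing the classification setup above:

    # Continuing from the classification setup sketch above.
    credible_sets = predictive.credible_set(
        inputs_loader=test_inputs_loader,
        error=0.05,  # keep classes until their cumulative probability exceeds 0.95
    )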

ensemble_log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)

Compute the log-likelihood at each posterior sample, that is

\[\log p(y|x, \theta^{(i)}),\]
where:
  • \(x\) is an observed input variable;

  • \(y\) is an observed target variable;

  • \(\theta^{(i)}\) is a sample from the posterior.

Parameters:
  • data_loader (DataLoader) – A data loader.

  • n_posterior_samples (int) – Number of posterior samples to draw in order to compute the log-likelihood.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An array of log-likelihood values at each posterior sample for each data point.

Return type:

jnp.ndarray

entropy(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)[source]

Estimate the predictive entropy, that is

\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive entropy for each input.

Return type:

jnp.ndarray

epistemic_entropy(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)[source]

Estimate the predictive epistemic entropy, that is

\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})] + \mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Note that the epistemic entropy above is defined as the difference between the predictive entropy and the aleatoric predictive entropy.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive epistemic entropy for each input.

Return type:

jnp.ndarray
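
Since the epistemic entropy is the total predictive entropy minus the aleatoric entropy, the three estimators can be cross-checked, up to Monte Carlo error. A sketch continuing the classification setup above:

    # Continuing from the classification setup sketch above.
    total = predictive.entropy(inputs_loader=test_inputs_loader, n_posterior_samples=30)
    aleatoric = predictive.aleatoric_entropy(inputs_loader=test_inputs_loader, n_posterior_samples=30)
    epistemic = predictive.epistemic_entropy(inputs_loader=test_inputs_loader, n_posterior_samples=30)
    # Up to Monte Carlo error, epistemic ≈ total - aleatoric.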

epistemic_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)[source]

Estimate the predictive epistemic variance of the one-hot encoded target variable, that is

\[\mathbb{E}_{W|\mathcal{D}}[\text{Var}_{\tilde{Y}|W, x}[\tilde{Y}]],\]
where:
  • \(x\) is an observed input variable;

  • \(\tilde{Y}\) is a one-hot encoded random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive epistemic variance for each input.

Return type:

jnp.ndarray

log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)

Estimate the predictive log-probability density function (a.k.a. log-pdf), that is

\[\log p(y|x, \mathcal{D}),\]
where:
  • \(x\) is an observed input variable;

  • \(y\) is an observed target variable;

  • \(\mathcal{D}\) is the observed training data set.

Parameters:
  • data_loader (DataLoader) – A data loader.

  • n_posterior_samples (int) – Number of posterior samples to draw in order to approximate the predictive log-pdf.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive log-pdf for each data point.

Return type:

jnp.ndarray
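
A sketch continuing the classification setup above; log_prob needs labeled data, so a DataLoader is used instead of an InputsLoader. For comparison, ensemble_log_prob (documented earlier) returns the per-posterior-sample values rather than the averaged estimate.

    # Continuing from the classification setup sketch above.
    test_data_loader = DataLoader.from_array_data((x_train[:64], y_train[:64]), batch_size=32)
    log_pdfs = predictive.log_prob(data_loader=test_data_loader)              # one value per data point
    per_sample = predictive.ensemble_log_prob(data_loader=test_data_loader)   # one value per posterior sample and data point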

mean(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)[source]

Estimate the predictive mean of the one-hot encoded target variable, that is

\[\mathbb{E}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}],\]
where:
  • \(x\) is an observed input variable;

  • \(\tilde{Y}\) is a one-hot encoded random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive mean for each input.

Return type:

jnp.ndarray

mode(inputs_loader, n_posterior_samples=30, means=None, rng=None, distribute=True)[source]

Estimate the predictive mode of the target variable, that is

\[\text{argmax}_y\ p(y|x, \mathcal{D}),\]
where:
  • \(x\) is an observed input variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(y\) is the target variable to optimize upon.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • means (Optional[jnp.ndarray]) – An estimate of the predictive mean.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive mode for each input.

Return type:

jnp.ndarray
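
A sketch continuing the classification setup above; for classification, the mean gives class probabilities and the mode the predicted labels, and precomputed means can be passed to mode to avoid recomputation.

    # Continuing from the classification setup sketch above.
    means = predictive.mean(inputs_loader=test_inputs_loader)                # class probabilities per input
    modes = predictive.mode(inputs_loader=test_inputs_loader, means=means)   # predicted labels, reusing the means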

property rng: RandomNumberGenerator

Invoke the random number generator object.

Return type:

The random number generator object.

sample(inputs_loader, n_target_samples=1, return_aux=None, rng=None, distribute=True, **kwargs)

Sample from an approximation of the predictive distribution for each input data point, that is

\[y^{(i)}\sim p(\cdot|x, \mathcal{D}),\]
where:
  • \(x\) is an observed input variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(y^{(i)}\) is a sample of the target variable for the input \(x\).

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_target_samples (int) – Number of target samples to sample for each input data point.

  • return_aux (Optional[List[str]]) – Return auxiliary objects. We currently support ‘outputs’.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

Samples for each input data point. Optionally, an auxiliary object is returned.

Return type:

Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]
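
A sketch continuing the classification setup above. When return_aux=["outputs"] is passed, the documented return type indicates that a dictionary of auxiliary arrays is returned alongside the samples (presumably keyed by the requested name).

    # Continuing from the classification setup sketch above.
    samples = predictive.sample(inputs_loader=test_inputs_loader, n_target_samples=5)
    samples, aux = predictive.sample(
        inputs_loader=test_inputs_loader,
        n_target_samples=5,
        return_aux=["outputs"],  # also return the model outputs behind the samples
    )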

sample_calibrated_outputs(inputs_loader, n_output_samples=1, rng=None, distribute=True)

Sample parameters from the posterior distribution state and compute calibrated outputs.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_output_samples (int) – Number of output samples to draw for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

Samples of calibrated outputs.

Return type:

jnp.ndarray

std(inputs_loader, n_posterior_samples=30, variances=None, rng=None, distribute=True)[source]

Estimate the predictive standard deviation of the one-hot encoded target variable, that is

\[\sqrt{\text{Var}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}]},\]
where:
  • \(x\) is an observed input variable;

  • \(\tilde{Y}\) is a one-hot encoded random target variable;

  • \(\mathcal{D}\) is the observed training data set.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • variances (Optional[jnp.ndarray]) – An estimate of the predictive variance.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive standard deviation for each input.

Return type:

jnp.ndarray

variance(inputs_loader, n_posterior_samples=30, aleatoric_variances=None, epistemic_variances=None, rng=None, distribute=True)[source]

Estimate the predictive variance of the one-hot encoded target variable, that is

\[\text{Var}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}],\]
where:
  • \(x\) is an observed input variable;

  • \(\tilde{Y}\) is a one-hot encoded random target variable;

  • \(\mathcal{D}\) is the observed training data set.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • aleatoric_variances (Optional[jnp.ndarray]) – An estimate of the aleatoric predictive variance.

  • epistemic_variances (Optional[jnp.ndarray]) – An estimate of the epistemic predictive variance.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive variance for each input.

Return type:

jnp.ndarray
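
A sketch continuing the classification setup above; precomputed aleatoric and epistemic components can be passed to variance, and a precomputed variance to std, to avoid redundant sampling.

    # Continuing from the classification setup sketch above.
    aleatoric = predictive.aleatoric_variance(inputs_loader=test_inputs_loader)
    epistemic = predictive.epistemic_variance(inputs_loader=test_inputs_loader)
    total = predictive.variance(
        inputs_loader=test_inputs_loader,
        aleatoric_variances=aleatoric,
        epistemic_variances=epistemic,
    )
    stds = predictive.std(inputs_loader=test_inputs_loader, variances=total)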

class fortuna.prob_model.predictive.regression.RegressionPredictive(posterior)[source]

Regression predictive distribution class.

Parameters:

posterior (Posterior) – A posterior distribution object.
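
As for classification, a RegressionPredictive is usually obtained from a trained probabilistic regressor. The sketch below is illustrative only: it assumes Fortuna's quickstart-style entry points (ProbRegressor, MLP, DataLoader, InputsLoader) and synthetic data; these names are assumptions beyond this page. Later regression sketches reuse the predictive and test_inputs_loader defined here.

    # Illustrative setup (assumed quickstart-style entry points; not authoritative).
    import numpy as np
    from fortuna.data import DataLoader, InputsLoader
    from fortuna.prob_model import ProbRegressor
    from fortuna.model import MLP

    # Toy scalar regression on 4 features.
    x_train = np.random.normal(size=(512, 4)).astype("float32")
    y_train = np.random.normal(size=(512, 1)).astype("float32")

    train_data_loader = DataLoader.from_array_data(
        (x_train, y_train), batch_size=128, shuffle=True
    )
    test_inputs_loader = InputsLoader.from_array_inputs(x_train[:64], batch_size=32)

    # One model for the likelihood mean and one for its log-variance
    # (argument names assumed from Fortuna's quickstart).
    prob_model = ProbRegressor(
        model=MLP(output_dim=1),
        likelihood_log_variance_model=MLP(output_dim=1),
    )
    status = prob_model.train(train_data_loader=train_data_loader)

    predictive = prob_model.predictive  # a RegressionPredictive instance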

aleatoric_entropy(inputs_loader, n_posterior_samples=30, n_target_samples=30, rng=None, distribute=True)[source]

Estimate the predictive aleatoric entropy, that is

\[-\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_target_samples (int) – Number of target samples to draw for each input.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive aleatoric entropy for each input.

Return type:

jnp.ndarray

aleatoric_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)

Estimate the predictive aleatoric variance of the target variable, that is

\[\text{Var}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[Y]],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive aleatoric variance for each input.

Return type:

jnp.ndarray

credible_interval(inputs_loader, n_target_samples=30, error=0.05, interval_type='two-tailed', rng=None, distribute=True)[source]

Estimate credible intervals for the target variable. This is supported only if the target variable is scalar.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_target_samples (int) – Number of target samples to draw for each input.

  • error (float) – The interval error. This must be a number between 0 and 1, extremes included. For example, error=0.05 corresponds to a 95% level of credibility.

  • interval_type (str) – The interval type. We support “two-tailed” (default), “right-tailed” and “left-tailed”.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

A credibility interval for each of the inputs.

Return type:

jnp.ndarray
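
A sketch continuing the regression setup above; the target must be scalar, and error=0.05 requests 95% credible intervals. The returned array presumably stacks the interval bounds per input.

    # Continuing from the regression setup sketch above.
    intervals = predictive.credible_interval(
        inputs_loader=test_inputs_loader,
        n_target_samples=30,
        error=0.05,                   # 95% credibility
        interval_type="two-tailed",
    )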

ensemble_log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)

Compute the log-likelihood at each posterior sample, that is

\[\log p(y|x, \theta^{(i)}),\]
where:
  • \(x\) is an observed input variable;

  • \(y\) is an observed target variable;

  • \(\theta^{(i)}\) is a sample from the posterior.

Parameters:
  • data_loader (DataLoader) – A data loader.

  • n_posterior_samples (int) – Number of posterior samples to draw in order to compute the log-likelihood.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An array of log-likelihood values at each posterior sample for each data point.

Return type:

jnp.ndarray

entropy(inputs_loader, n_posterior_samples=30, n_target_samples=30, rng=None, distribute=True)[source]

Estimate the predictive entropy, that is

\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_target_samples (int) – Number of target samples to draw for each input.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive entropy for each input.

Return type:

jnp.ndarray

epistemic_entropy(inputs_loader, n_posterior_samples=30, n_target_samples=30, rng=None, distribute=True)[source]

Estimate the predictive epistemic entropy, that is

\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})] + \mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Note that the epistemic entropy above is defined as the difference between the predictive entropy and the aleatoric predictive entropy.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • n_target_samples (int) – Number of target samples to draw for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive epistemic entropy for each input.

Return type:

jnp.ndarray

epistemic_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)

Estimate the predictive epistemic variance of the target variable, that is

\[\mathbb{E}_{W|\mathcal{D}}[\text{Var}_{Y|W, x}[Y]],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive epistemic variance for each input.

Return type:

jnp.ndarray

log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)

Estimate the predictive log-probability density function (a.k.a. log-pdf), that is

\[\log p(y|x, \mathcal{D}),\]
where:
  • \(x\) is an observed input variable;

  • \(y\) is an observed target variable;

  • \(\mathcal{D}\) is the observed training data set.

Parameters:
  • data_loader (DataLoader) – A data loader.

  • n_posterior_samples (int) – Number of posterior samples to draw in order to approximate the predictive log-pdf.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive log-pdf for each data point.

Return type:

jnp.ndarray

mean(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)

Estimate the predictive mean of the target variable, that is

\[\mathbb{E}_{Y|x, \mathcal{D}}[Y],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(W\) denotes the random model parameters.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive mean for each input.

Return type:

jnp.ndarray

mode(inputs_loader, n_posterior_samples=30, means=None, rng=None, distribute=True)[source]

Estimate the predictive mode of the target variable, that is

\[\text{argmax}_y\ p(y|x, \mathcal{D}),\]
where:
  • \(x\) is an observed input variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(y\) is the target variable to optimize upon.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • means (Optional[jnp.ndarray]) – An estimate of the predictive mean.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive mode for each input.

Return type:

jnp.ndarray

quantile(q, inputs_loader, n_target_samples=30, rng=None, distribute=True)[source]

Estimate the q-th quantiles of the predictive probability density function.

Parameters:
  • q (Union[float, Array, List]) – Quantile or sequence of quantiles to compute. Each of these must be between 0 and 1, extremes included.

  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_target_samples (int) – Number of target samples to sample for each input data point.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

Quantile estimate for each quantile and each input. If multiple quantiles q are given, the result’s first axis is over different quantiles.

Return type:

jnp.ndarray
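
A sketch continuing the regression setup above; passing a list of quantiles returns an array whose first axis ranges over the requested quantiles.

    # Continuing from the regression setup sketch above.
    qs = predictive.quantile(
        q=[0.05, 0.5, 0.95],
        inputs_loader=test_inputs_loader,
        n_target_samples=30,
    )
    # qs[0], qs[1], qs[2]: 5th, 50th and 95th percentile estimates per input.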

property rng: RandomNumberGenerator

Invoke the random number generator object.

Return type:

The random number generator object.

sample(inputs_loader, n_target_samples=1, return_aux=None, rng=None, distribute=True, **kwargs)

Sample from an approximation of the predictive distribution for each input data point, that is

\[y^{(i)}\sim p(\cdot|x, \mathcal{D}),\]
where:
  • \(x\) is an observed input variable;

  • \(\mathcal{D}\) is the observed training data set;

  • \(y^{(i)}\) is a sample of the target variable for the input \(x\).

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_target_samples (int) – Number of target samples to sample for each input data point.

  • return_aux (Optional[List[str]]) – Return auxiliary objects. We currently support ‘outputs’.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

Samples for each input data point. Optionally, an auxiliary object is returned.

Return type:

Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]

sample_calibrated_outputs(inputs_loader, n_output_samples=1, rng=None, distribute=True)

Sample parameters from the posterior distribution state and compute calibrated outputs.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_output_samples (int) – Number of output samples to draw for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

Samples of calibrated outputs.

Return type:

jnp.ndarray

std(inputs_loader, n_posterior_samples=30, variances=None, rng=None, distribute=True)

Estimate the predictive standard deviation of the target variable, that is

\[\sqrt{\text{Var}_{Y|x, \mathcal{D}}[Y]},\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • variances (Optional[jnp.ndarray]) – An estimate of the predictive variance.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive standard deviation for each input.

Return type:

jnp.ndarray

variance(inputs_loader, n_posterior_samples=30, aleatoric_variances=None, epistemic_variances=None, rng=None, distribute=True)

Estimate the predictive variance of the target variable, that is

\[\text{Var}_{Y|x, \mathcal{D}}[Y],\]
where:
  • \(x\) is an observed input variable;

  • \(Y\) is a random target variable;

  • \(\mathcal{D}\) is the observed training data set.

Note that the predictive variance above corresponds to the sum of its aleatoric and epistemic components.

Parameters:
  • inputs_loader (InputsLoader) – A loader of input data points.

  • n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.

  • aleatoric_variances (Optional[jnp.ndarray]) – An estimate of the aleatoric predictive variance for each input.

  • epistemic_variances (Optional[jnp.ndarray]) – An estimate of the epistemic predictive variance for each input.

  • rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.

  • distribute (bool) – Whether to distribute computation over multiple devices, if available.

Returns:

An estimate of the predictive variance for each input.

Return type:

jnp.ndarray
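
Since the predictive variance is the sum of its aleatoric and epistemic components, and the standard deviation its square root, the estimators can be cross-checked up to Monte Carlo error. A sketch continuing the regression setup above:

    # Continuing from the regression setup sketch above.
    aleatoric = predictive.aleatoric_variance(inputs_loader=test_inputs_loader)
    epistemic = predictive.epistemic_variance(inputs_loader=test_inputs_loader)
    total = predictive.variance(
        inputs_loader=test_inputs_loader,
        aleatoric_variances=aleatoric,
        epistemic_variances=epistemic,
    )
    stds = predictive.std(inputs_loader=test_inputs_loader, variances=total)
    # Up to Monte Carlo error: total ≈ aleatoric + epistemic, and stds ≈ sqrt(total).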