Predictive distribution
The predictive distribution is the component of the probabilistic model responsible for computing predictive statistics. Fortuna provides a classification predictive for classification tasks and a regression predictive for regression tasks; their references follow.
- class fortuna.prob_model.predictive.classification.ClassificationPredictive(posterior)
Classification predictive distribution class.
- Parameters:
posterior (Posterior) – A posterior distribution object.
- aleatoric_entropy(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive aleatoric entropy, that is
\[-\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive aleatoric entropy for each input.
- Return type:
jnp.ndarray
- aleatoric_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive aleatoric variance of the one-hot encoded target variable, that is
\[\text{Var}_{W|\mathcal{D}}[\mathbb{E}_{\tilde{Y}|W, x}[\tilde{Y}]],\]
where:
\(x\) is an observed input variable;
\(\tilde{Y}\) is a one-hot encoded random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive aleatoric variance for each input.
- Return type:
jnp.ndarray
- conformal_set(train_data_loader, test_inputs_loader, n_posterior_samples=30, error=0.05, rng=None, distribute=True, return_ess=False)
Estimate conformal sets for the target variable.
- Parameters:
train_data_loader (DataLoader) – A training data loader.
test_inputs_loader (InputsLoader) – A test inputs loader.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
error (float) – The set error. This must be a number between 0 and 1, extremes included. For example, error=0.05 corresponds to a 95% level of confidence.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
return_ess (bool) – Whether to compute the effective sample size of the importance weights.
- Returns:
A list of conformal sets for each test input.
- Return type:
List[List[int]]
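The reference does not spell out how `conformal_set` builds its sets, and its `train_data_loader` and `return_ess` parameters suggest a posterior-based, importance-weighted construction. As a hedged illustration of the general idea only, the sketch below implements plain split conformal prediction, a different and simpler technique: non-conformity scores \(1 - p(y|x)\) are computed on held-out calibration points, and a test class is kept when its score does not exceed the finite-sample corrected quantile of those scores. The helper `split_conformal_sets` and its inputs (precomputed class probabilities as nested lists) are hypothetical and not part of Fortuna's API.

```python
import math
from typing import List, Sequence


def split_conformal_sets(
    calib_probs: Sequence[Sequence[float]],  # class probabilities on held-out calibration inputs
    calib_labels: Sequence[int],             # observed labels of those calibration inputs
    test_probs: Sequence[Sequence[float]],   # class probabilities on test inputs
    error: float = 0.05,
) -> List[List[int]]:
    """Hypothetical helper (not Fortuna's API): split conformal sets, score 1 - p(true class)."""
    # Non-conformity score of each calibration point, sorted ascending.
    scores = sorted(1.0 - p[y] for p, y in zip(calib_probs, calib_labels))
    n = len(scores)
    # Finite-sample corrected rank of the threshold score.
    rank = math.ceil((n + 1) * (1.0 - error))
    if rank < 1:
        threshold = -math.inf  # error = 1: empty sets are allowed
    elif rank > n:
        threshold = math.inf   # too few calibration points: keep every class
    else:
        threshold = scores[rank - 1]
    # A class enters the set if its score does not exceed the threshold.
    return [[k for k, pk in enumerate(p) if 1.0 - pk <= threshold] for p in test_probs]
```

Under exchangeability of calibration and test points, split conformal sets contain the true label with probability at least 1 - error.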
- credible_set(inputs_loader, n_posterior_samples=30, error=0.05, rng=None, distribute=True)
Estimate credible sets for the target variable. Classes are sorted by predictive probability in descending order and added to the set until their cumulative probability exceeds 1 - error.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of posterior samples to draw for each input.
error (float) – The set error. This must be a number between 0 and 1, extremes included. For example, error=0.05 corresponds to a 95% level of credibility.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
A credible set for each input.
- Return type:
jnp.ndarray
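The sorting rule described for `credible_set` can be sketched for a single input as follows. This is a minimal sketch assuming the posterior-averaged class probabilities for that input are already available as a plain list; `credible_set_from_probs` is a hypothetical name, not Fortuna's implementation.

```python
from typing import List, Sequence


def credible_set_from_probs(probs: Sequence[float], error: float = 0.05) -> List[int]:
    """Hypothetical helper: credible set of class indices for one input."""
    # Visit classes from most to least probable.
    order = sorted(range(len(probs)), key=lambda k: probs[k], reverse=True)
    selected: List[int] = []
    cumulative = 0.0
    for k in order:
        selected.append(k)
        cumulative += probs[k]
        # Stop as soon as the accumulated probability exceeds 1 - error.
        if cumulative > 1.0 - error:
            break
    return selected
```

For example, with probabilities `[0.1, 0.6, 0.3]` and `error=0.2`, classes 1 and 2 are selected, since their combined probability of 0.9 is the first cumulative sum above 0.8.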
- ensemble_log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)
Compute the log-likelihood at each posterior sample, that is
\[\log p(y|x, \theta^{(i)}),\]
where:
\(x\) is an observed input variable;
\(y\) is an observed target variable;
\(\theta^{(i)}\) is a sample from the posterior.
- Parameters:
data_loader (DataLoader) – A data loader.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution state in order to compute the log-likelihood.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An array of log-likelihood values at each posterior sample for each data point.
- Return type:
jnp.ndarray
- entropy(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive entropy, that is
\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive entropy for each input.
- Return type:
jnp.ndarray
- epistemic_entropy(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive epistemic entropy, that is
\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})] + \mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
Note that the epistemic entropy above is defined as the difference between the predictive entropy and the aleatoric predictive entropy.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive epistemic entropy for each input.
- Return type:
jnp.ndarray
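Given class probabilities \(p(Y|w_i, x)\) for posterior samples \(w_1\) to \(w_N\), the three entropies above admit simple Monte Carlo estimates: the predictive entropy is the entropy of the averaged probabilities, the aleatoric entropy is the average of the per-sample entropies, and the epistemic entropy is their difference. A minimal sketch in plain Python, assuming the per-sample probabilities for one input are given as nested lists; the helper names are hypothetical and this is not necessarily how Fortuna computes these quantities.

```python
import math
from typing import Sequence, Tuple


def categorical_entropy(p: Sequence[float]) -> float:
    # Shannon entropy in nats; zero-probability classes contribute nothing.
    return -sum(pk * math.log(pk) for pk in p if pk > 0.0)


def entropy_decomposition(sample_probs: Sequence[Sequence[float]]) -> Tuple[float, float, float]:
    """Hypothetical helper: (predictive, aleatoric, epistemic) entropy for one input.

    sample_probs[i][k] = p(Y = k | W = w_i, x) for posterior sample w_i.
    """
    n = len(sample_probs)
    # p(Y | x, D) approximated by averaging over posterior samples.
    mean_probs = [sum(col) / n for col in zip(*sample_probs)]
    predictive = categorical_entropy(mean_probs)
    # Average of the per-sample entropies, i.e. the expected conditional entropy.
    aleatoric = sum(categorical_entropy(p) for p in sample_probs) / n
    # Epistemic entropy: difference between predictive and aleatoric entropy.
    return predictive, aleatoric, predictive - aleatoric
```

The epistemic term is the mutual information between \(Y\) and \(W\); by concavity of the entropy, this plug-in estimate is never negative. Two confident but disagreeing samples such as `[[1.0, 0.0], [0.0, 1.0]]` give zero aleatoric entropy and an epistemic entropy of \(\log 2\).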
- epistemic_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)
Estimate the predictive epistemic variance of the one-hot encoded target variable, that is
\[\mathbb{E}_{W|\mathcal{D}}[\text{Var}_{\tilde{Y}|W, x}[\tilde{Y}]],\]
where:
\(x\) is an observed input variable;
\(\tilde{Y}\) is a one-hot encoded random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive epistemic variance for each input.
- Return type:
jnp.ndarray
- log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)
Estimate the predictive log-probability density function (a.k.a. log-pdf), that is
\[\log p(y|x, \mathcal{D}),\]
where:
\(x\) is an observed input variable;
\(y\) is an observed target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
data_loader (DataLoader) – A data loader.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution state in order to approximate the predictive log-pdf.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive log-pdf for each data point.
- Return type:
jnp.ndarray
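The two log-probability methods are related: averaging the per-sample likelihoods returned by `ensemble_log_prob` gives the usual Monte Carlo estimate \(\log p(y|x, \mathcal{D}) \approx \log \frac{1}{N}\sum_{i=1}^{N} p(y|x, \theta^{(i)})\). A minimal sketch of that reduction for a single data point, using the log-sum-exp trick for numerical stability; the function name is hypothetical, and whether `log_prob` is computed exactly this way is an assumption.

```python
import math
from typing import Sequence


def predictive_log_prob(ensemble_log_probs: Sequence[float]) -> float:
    """Hypothetical helper: reduce log p(y | x, theta_i), i = 1..N, to log p(y | x, D)."""
    # Shift by the maximum so that every exp() term is at most 1 and cannot overflow.
    top = max(ensemble_log_probs)
    total = sum(math.exp(lp - top) for lp in ensemble_log_probs)
    # log((1/N) * sum_i exp(lp_i)) = top + log(sum_i exp(lp_i - top)) - log(N)
    return top + math.log(total) - math.log(len(ensemble_log_probs))
```

Working in log space matters because per-sample log-likelihoods can be very negative: exponentiating `-1000.0` directly underflows to zero, while the shifted form still returns `-1000.0` for two such samples.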
- mean(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive mean of the one-hot encoded target variable, that is
\[\mathbb{E}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}],\]
where:
\(x\) is an observed input variable;
\(\tilde{Y}\) is a one-hot encoded random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive mean for each input.
- Return type:
jnp.ndarray
- mode(inputs_loader, n_posterior_samples=30, means=None, rng=None, distribute=True)
Estimate the predictive mode of the target variable, that is
\[\text{argmax}_y\ p(y|x, \mathcal{D}),\]
where:
\(x\) is an observed input variable;
\(\mathcal{D}\) is the observed training data set;
\(y\) is the target value over which the maximization is performed.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
means (Optional[jnp.ndarray] = None) – An estimate of the predictive mean.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive mode for each input.
- Return type:
jnp.ndarray
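Because the expectation of a one-hot encoded target is its vector of class probabilities, the predictive mean reduces to averaging \(p(\tilde{Y}|W, x)\) over posterior samples, and the predictive mode is the class with the largest averaged probability. A minimal sketch under that reasoning, with hypothetical helper names; passing a precomputed mean to `predictive_mode` plays a role similar to the `means` argument of `mode`.

```python
from typing import List, Sequence


def predictive_mean(sample_probs: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical helper: posterior average of p(Y = k | W = w_i, x) for each class k."""
    n = len(sample_probs)
    return [sum(col) / n for col in zip(*sample_probs)]


def predictive_mode(means: Sequence[float]) -> int:
    """Hypothetical helper: argmax_y p(y | x, D), given the predictive mean."""
    return max(range(len(means)), key=lambda k: means[k])
```

For example, averaging `[0.2, 0.8]` and `[0.6, 0.4]` gives a predictive mean of about `[0.4, 0.6]`, so the mode is class 1.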
- property rng: RandomNumberGenerator
Get the random number generator object.
- Returns:
The random number generator object.
- sample(inputs_loader, n_target_samples=1, return_aux=None, rng=None, distribute=True, **kwargs)
Sample from an approximation of the predictive distribution for each input data point, that is
\[y^{(i)}\sim p(\cdot|x, \mathcal{D}),\]
where:
\(x\) is an observed input variable;
\(\mathcal{D}\) is the observed training data set;
\(y^{(i)}\) is a sample of the target variable for the input \(x\).
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_target_samples (int) – Number of target samples to draw for each input data point.
return_aux (Optional[List[str]]) – Return auxiliary objects. We currently support ‘outputs’.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
Samples for each input data point. Optionally, an auxiliary object is returned.
- Return type:
Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]
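Sampling from the predictive distribution can be done ancestrally: draw a posterior sample, then draw a class from that sample's likelihood. A minimal sketch for a single input, assuming a finite set of posterior samples represented by their class probabilities; `sample_predictive` is hypothetical and uses Python's `random` module rather than JAX random keys.

```python
import random
from typing import List, Sequence


def sample_predictive(
    sample_probs: Sequence[Sequence[float]],  # p(Y | W = w_i, x), one row per posterior sample
    n_target_samples: int,
    seed: int = 0,
) -> List[int]:
    """Hypothetical helper: draw class labels y ~ p(. | x, D) for one input."""
    rng = random.Random(seed)
    classes = range(len(sample_probs[0]))
    draws = []
    for _ in range(n_target_samples):
        # Ancestral sampling: pick a posterior sample, then a class from its likelihood.
        probs = rng.choice(sample_probs)
        draws.append(rng.choices(classes, weights=probs, k=1)[0])
    return draws
```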
- sample_calibrated_outputs(inputs_loader, n_output_samples=1, rng=None, distribute=True)
Sample parameters from the posterior distribution state and compute calibrated outputs.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_output_samples (int) – Number of output samples to draw for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
Samples of calibrated outputs.
- Return type:
jnp.ndarray
- std(inputs_loader, n_posterior_samples=30, variances=None, rng=None, distribute=True)
Estimate the predictive standard deviation of the one-hot encoded target variable, that is
\[\sqrt{\text{Var}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}]},\]
where:
\(x\) is an observed input variable;
\(\tilde{Y}\) is a one-hot encoded random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
variances (Optional[jnp.ndarray]) – An estimate of the predictive variance.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive standard deviation for each input.
- Return type:
jnp.ndarray
- variance(inputs_loader, n_posterior_samples=30, aleatoric_variances=None, epistemic_variances=None, rng=None, distribute=True)
Estimate the predictive variance of the one-hot encoded target variable, that is
\[\text{Var}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}],\]
where:
\(x\) is an observed input variable;
\(\tilde{Y}\) is a one-hot encoded random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
aleatoric_variances (Optional[jnp.ndarray]) – An estimate of the aleatoric predictive variance.
epistemic_variances (Optional[jnp.ndarray]) – An estimate of the epistemic predictive variance.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive variance for each input.
- Return type:
jnp.ndarray
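For each class \(k\), the \(k\)-th entry of the one-hot target is Bernoulli with mean \(p_k\), so by the law of total variance the predictive variance splits into \(\mathbb{E}_{W|\mathcal{D}}[\text{Var}_{\tilde{Y}|W, x}[\tilde{Y}]]\) and \(\text{Var}_{W|\mathcal{D}}[\mathbb{E}_{\tilde{Y}|W, x}[\tilde{Y}]]\), the two variance components documented above; the `aleatoric_variances` and `epistemic_variances` arguments of `variance` presumably let precomputed components be reused. A minimal sketch with a hypothetical helper name that computes both terms and the total for one input:

```python
from typing import List, Sequence, Tuple


def one_hot_variance_terms(
    sample_probs: Sequence[Sequence[float]],  # sample_probs[i][k] = p(Y = k | W = w_i, x)
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical helper: per-class variance components of the one-hot target."""
    n = len(sample_probs)
    expected_var, var_of_mean, total = [], [], []
    for col in zip(*sample_probs):  # col holds p_k across posterior samples for one class k
        m = sum(col) / n
        # E_W[Var(Y~_k | W, x)]: the k-th one-hot entry is Bernoulli(p_k), variance p_k (1 - p_k).
        expected_var.append(sum(p * (1.0 - p) for p in col) / n)
        # Var_W[E(Y~_k | W, x)]: spread of p_k across posterior samples.
        var_of_mean.append(sum((p - m) ** 2 for p in col) / n)
        # Var(Y~_k | x, D): variance of a Bernoulli with the averaged probability m.
        total.append(m * (1.0 - m))
    return expected_var, var_of_mean, total
```

For every class, `expected_var[k] + var_of_mean[k]` equals `total[k]` up to rounding, and the predictive standard deviation is the element-wise square root of `total`.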
- class fortuna.prob_model.predictive.regression.RegressionPredictive(posterior)
Regression predictive distribution class.
- Parameters:
posterior (Posterior) – A posterior distribution object.
- aleatoric_entropy(inputs_loader, n_posterior_samples=30, n_target_samples=30, rng=None, distribute=True)
Estimate the predictive aleatoric entropy, that is
\[-\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_target_samples (int) – Number of target samples to draw for each input.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive aleatoric entropy for each input.
- Return type:
jnp.ndarray
- aleatoric_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive aleatoric variance of the target variable, that is
\[\text{Var}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[Y]],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive aleatoric variance for each input.
- Return type:
jnp.ndarray
- credible_interval(inputs_loader, n_target_samples=30, error=0.05, interval_type='two-tailed', rng=None, distribute=True)
Estimate credible intervals for the target variable. This is supported only if the target variable is scalar.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_target_samples (int) – Number of target samples to draw for each input.
error (float) – The interval error. This must be a number between 0 and 1, extremes included. For example, error=0.05 corresponds to a 95% level of credibility.
interval_type (str) – The interval type. We support “two-tailed” (default), “right-tailed” and “left-tailed”.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
A credible interval for each input.
- Return type:
jnp.ndarray
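A two-tailed credible interval at level 1 - error can be read off the empirical quantiles at error/2 and 1 - error/2 of the target samples, which is also the kind of estimate `quantile` returns. A minimal sketch, assuming the target samples for one input are already available as a list of floats; the helpers are hypothetical, the quantile uses linear interpolation between order statistics (one common convention), and the one-sided variants are omitted because the reference does not say which tail each of them refers to.

```python
import math
from typing import Sequence, Tuple


def empirical_quantile(samples: Sequence[float], q: float) -> float:
    """Hypothetical helper: q-th quantile with linear interpolation between order statistics."""
    xs = sorted(samples)
    pos = q * (len(xs) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (pos - lo) * (xs[hi] - xs[lo])


def two_tailed_credible_interval(
    target_samples: Sequence[float], error: float = 0.05
) -> Tuple[float, float]:
    """Hypothetical helper: put error / 2 of the probability mass in each tail."""
    return (
        empirical_quantile(target_samples, error / 2),
        empirical_quantile(target_samples, 1 - error / 2),
    )
```

With the 101 samples `0.0, 1.0, ..., 100.0` and `error=0.05`, the interval is approximately `(2.5, 97.5)`.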
- ensemble_log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)
Compute the log-likelihood at each posterior sample, that is
\[\log p(y|x, \theta^{(i)}),\]
where:
\(x\) is an observed input variable;
\(y\) is an observed target variable;
\(\theta^{(i)}\) is a sample from the posterior.
- Parameters:
data_loader (DataLoader) – A data loader.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution state in order to compute the log-likelihood.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An array of log-likelihood values at each posterior sample for each data point.
- Return type:
jnp.ndarray
- entropy(inputs_loader, n_posterior_samples=30, n_target_samples=30, rng=None, distribute=True)
Estimate the predictive entropy, that is
\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_target_samples (int) – Number of target samples to draw for each input.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive entropy for each input.
- Return type:
jnp.ndarray
- epistemic_entropy(inputs_loader, n_posterior_samples=30, n_target_samples=30, rng=None, distribute=True)
Estimate the predictive epistemic entropy, that is
\[-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})] + \mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
Note that the epistemic entropy above is defined as the difference between the predictive entropy and the aleatoric predictive entropy.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
n_target_samples (int) – Number of target samples to draw for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive epistemic entropy for each input.
- Return type:
jnp.ndarray
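For regression, the same entropy decomposition holds, but the predictive entropy generally has no closed form and must be estimated from target samples, which is presumably what `n_target_samples` controls. A minimal sketch, assuming each posterior sample \(w_i\) yields a Gaussian likelihood \(\mathcal{N}(\mu_i, \sigma_i^2)\) (an assumption; the reference does not fix the likelihood): the aleatoric entropy averages the closed-form Gaussian entropies \(\tfrac{1}{2}\log(2\pi e \sigma_i^2)\), the predictive entropy is a Monte Carlo estimate under the resulting equal-weight mixture, and the epistemic entropy is their difference. All names are hypothetical.

```python
import math
import random
from typing import Sequence, Tuple


def gaussian_log_pdf(y: float, mu: float, var: float) -> float:
    # log N(y; mu, var)
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - mu) ** 2 / var)


def regression_entropy_decomposition(
    mus: Sequence[float],        # mean of Y given each posterior sample w_i (assumed Gaussian)
    variances: Sequence[float],  # variance of Y given each posterior sample w_i
    n_target_samples: int = 10_000,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Hypothetical helper: (predictive, aleatoric, epistemic) entropy for one input."""
    rng = random.Random(seed)
    n = len(mus)
    # Aleatoric: average of the closed-form Gaussian entropies 0.5 * log(2 pi e var).
    aleatoric = sum(0.5 * math.log(2.0 * math.pi * math.e * v) for v in variances) / n
    # Predictive: Monte Carlo estimate of -E[log p(Y | x, D)] under the equal-weight mixture.
    neg_log_sum = 0.0
    for _ in range(n_target_samples):
        i = rng.randrange(n)                            # pick a posterior sample
        y = rng.gauss(mus[i], math.sqrt(variances[i]))  # draw a target from its likelihood
        logs = [gaussian_log_pdf(y, m, v) for m, v in zip(mus, variances)]
        top = max(logs)
        # log((1/n) * sum_j N(y; mu_j, var_j)), computed with log-sum-exp
        log_mix = top + math.log(sum(math.exp(lp - top) for lp in logs)) - math.log(n)
        neg_log_sum -= log_mix
    predictive = neg_log_sum / n_target_samples
    return predictive, aleatoric, predictive - aleatoric
```

When all components coincide, the epistemic estimate is close to zero; for two well-separated components of equal weight, it approaches \(\log 2\), since the mixture density is then about half of each component's density near that component.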
- epistemic_variance(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive epistemic variance of the target variable, that is
\[\mathbb{E}_{W|\mathcal{D}}[\text{Var}_{Y|W, x}[Y]],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set;
\(W\) denotes the random model parameters.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive epistemic variance for each input.
- Return type:
jnp.ndarray
- log_prob(data_loader, n_posterior_samples=30, rng=None, distribute=True, **kwargs)
Estimate the predictive log-probability density function (a.k.a. log-pdf), that is
\[\log p(y|x, \mathcal{D}),\]
where:
\(x\) is an observed input variable;
\(y\) is an observed target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
data_loader (DataLoader) – A data loader.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution state in order to approximate the predictive log-pdf.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive log-pdf for each data point.
- Return type:
jnp.ndarray
- mean(inputs_loader, n_posterior_samples=30, rng=None, distribute=True)
Estimate the predictive mean of the target variable, that is
\[\mathbb{E}_{Y|x, \mathcal{D}}[Y],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive mean for each input.
- Return type:
jnp.ndarray
- mode(inputs_loader, n_posterior_samples=30, means=None, rng=None, distribute=True)
Estimate the predictive mode of the target variable, that is
\[\text{argmax}_y\ p(y|x, \mathcal{D}),\]
where:
\(x\) is an observed input variable;
\(\mathcal{D}\) is the observed training data set;
\(y\) is the target value over which the maximization is performed.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
means (Optional[jnp.ndarray] = None) – An estimate of the predictive mean.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive mode for each input.
- Return type:
jnp.ndarray
- quantile(q, inputs_loader, n_target_samples=30, rng=None, distribute=True)
Estimate the q-th quantiles of the predictive distribution of the target variable.
- Parameters:
q (Union[float, Array, List]) – Quantile or sequence of quantiles to compute. Each of these must be between 0 and 1, extremes included.
inputs_loader (InputsLoader) – A loader of input data points.
n_target_samples (int) – Number of target samples to draw for each input data point.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
Quantile estimate for each quantile and each input. If multiple quantiles q are given, the result’s first axis is over different quantiles.
- Return type:
jnp.ndarray
- property rng: RandomNumberGenerator
Get the random number generator object.
- Returns:
The random number generator object.
- sample(inputs_loader, n_target_samples=1, return_aux=None, rng=None, distribute=True, **kwargs)
Sample from an approximation of the predictive distribution for each input data point, that is
\[y^{(i)}\sim p(\cdot|x, \mathcal{D}),\]
where:
\(x\) is an observed input variable;
\(\mathcal{D}\) is the observed training data set;
\(y^{(i)}\) is a sample of the target variable for the input \(x\).
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_target_samples (int) – Number of target samples to draw for each input data point.
return_aux (Optional[List[str]]) – Return auxiliary objects. We currently support ‘outputs’.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
Samples for each input data point. Optionally, an auxiliary object is returned.
- Return type:
Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]
- sample_calibrated_outputs(inputs_loader, n_output_samples=1, rng=None, distribute=True)
Sample parameters from the posterior distribution state and compute calibrated outputs.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_output_samples (int) – Number of output samples to draw for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
Samples of calibrated outputs.
- Return type:
jnp.ndarray
- std(inputs_loader, n_posterior_samples=30, variances=None, rng=None, distribute=True)
Estimate the predictive standard deviation of the target variable, that is
\[\sqrt{\text{Var}_{Y|x, \mathcal{D}}[Y]},\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
variances (Optional[jnp.ndarray]) – An estimate of the predictive variance.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive standard deviation for each input.
- Return type:
jnp.ndarray
- variance(inputs_loader, n_posterior_samples=30, aleatoric_variances=None, epistemic_variances=None, rng=None, distribute=True)
Estimate the predictive variance of the target variable, that is
\[\text{Var}_{Y|x, \mathcal{D}}[Y],\]
where:
\(x\) is an observed input variable;
\(Y\) is a random target variable;
\(\mathcal{D}\) is the observed training data set.
Note that the predictive variance above corresponds to the sum of its aleatoric and epistemic components.
- Parameters:
inputs_loader (InputsLoader) – A loader of input data points.
n_posterior_samples (int) – Number of samples to draw from the posterior distribution for each input.
aleatoric_variances (Optional[jnp.ndarray]) – An estimate of the aleatoric predictive variance for each input.
epistemic_variances (Optional[jnp.ndarray]) – An estimate of the epistemic predictive variance for each input.
rng (Optional[jax.Array]) – A random number generator. If not passed, this will be taken from the attributes of this class.
distribute (bool) – Whether to distribute computation over multiple devices, if available.
- Returns:
An estimate of the predictive variance for each input.
- Return type:
jnp.ndarray
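The note on the sum of the aleatoric and epistemic components follows from the law of total variance. Assuming, as a hypothetical simplification, that each posterior sample \(w_i\) provides a mean \(\mu_i\) and a variance \(\sigma_i^2\) of the target, the two components are the average of the \(\sigma_i^2\) and the spread of the \(\mu_i\), and the predictive standard deviation is the square root of their sum. A minimal sketch with a hypothetical helper name:

```python
import math
from typing import Sequence, Tuple


def regression_variance_terms(
    mus: Sequence[float],        # E[Y | W = w_i, x] for each posterior sample (assumed given)
    variances: Sequence[float],  # Var[Y | W = w_i, x] for each posterior sample
) -> Tuple[float, float, float, float]:
    """Hypothetical helper returning both components, their sum and the predictive std."""
    n = len(mus)
    mean_mu = sum(mus) / n
    expected_var = sum(variances) / n                       # E_W[Var(Y | W, x)]
    var_of_mean = sum((m - mean_mu) ** 2 for m in mus) / n  # Var_W[E(Y | W, x)]
    total = expected_var + var_of_mean                      # Var(Y | x, D), law of total variance
    return expected_var, var_of_mean, total, math.sqrt(total)
```

For two posterior samples with means 1 and 3 and unit variances, both components equal 1, the predictive variance is 2, and the predictive standard deviation is \(\sqrt{2}\).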