from typing import Optional
from jax import (
jit,
lax,
vmap,
)
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from fortuna.data.loader import (
ConcatenatedLoader,
DataLoader,
InputsLoader,
)
from fortuna.prob_model.posterior.base import Posterior
from fortuna.prob_model.predictive.base import Predictive
class ClassificationPredictive(Predictive):
def __init__(self, posterior: Posterior):
"""
Classification predictive distribution class.
Parameters
----------
posterior : Posterior
A posterior distribution object.
"""
super().__init__(posterior)
    def mean(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive mean of the one-hot encoded target variable, that is
.. math::
\mathbb{E}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}],
where:
- :math:`x` is an observed input variable;
- :math:`\tilde{Y}` is a one-hot encoded random target variable;
        - :math:`\mathcal{D}` is the observed training data set.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive mean for each input.
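        Examples
        --------
        A minimal usage sketch, with hypothetical names: assume `prob_model` is a trained
        probabilistic classifier and `test_inputs_loader` is an :class:`~fortuna.data.loader.InputsLoader`.
        >>> probs = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
        >>> # `probs` stacks one class-probability vector per input: row i estimates
        >>> # p(y|x_i, D), so each row sums to 1 up to numerical error.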
"""
return super().mean(inputs_loader, n_posterior_samples, rng, distribute)
    def mode(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
means: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
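        r"""
        Estimate the predictive mode of the target variable, that is the class maximizing the
        predictive mean of the one-hot encoded target variable,
        .. math::
            \text{argmax}_y\, p(y|x, \mathcal{D}),
        where:
        - :math:`x` is an observed input variable;
        - :math:`y` ranges over the target classes;
        - :math:`\mathcal{D}` is the observed training data set.
        Parameters
        ----------
        inputs_loader : InputsLoader
            A loader of input data points.
        n_posterior_samples : int
            Number of samples to draw from the posterior distribution for each input.
        means : Optional[jnp.ndarray]
            An estimate of the predictive mean. If not passed, it is computed from the inputs.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
        Returns
        -------
        jnp.ndarray
            An estimate of the predictive mode for each input.
        """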
if means is None:
means = self.mean(
inputs_loader=inputs_loader,
n_posterior_samples=n_posterior_samples,
rng=rng,
distribute=distribute,
)
return jnp.argmax(means, -1)
    def aleatoric_variance(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive aleatoric variance of the one-hot encoded target variable, that is
.. math::
            \mathbb{E}_{W|\mathcal{D}}[\text{Var}_{\tilde{Y}|W, x}[\tilde{Y}]],
where:
- :math:`x` is an observed input variable;
- :math:`\tilde{Y}` is a one-hot encoded random target variable;
- :math:`\mathcal{D}` is the observed training data set;
- :math:`W` denotes the random model parameters.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
rng : Optional[PRNGKeyArray]
A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive aleatoric variance for each input.
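        Notes
        -----
        Since :math:`\tilde{Y}` is one-hot encoded, each component of the inner variance is a Bernoulli
        variance. Writing :math:`p_{s,c} = p(y=c|w_s, x)` for :math:`S` posterior samples
        :math:`w_1, \dots, w_S`, a Monte Carlo estimate of the :math:`c`-th component is
        .. math::
            \frac{1}{S}\sum_{s=1}^S p_{s,c}(1 - p_{s,c}).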
"""
return super().aleatoric_variance(
inputs_loader, n_posterior_samples, rng, distribute
)
    def epistemic_variance(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
**kwargs,
) -> jnp.ndarray:
r"""
Estimate the predictive epistemic variance of the one-hot encoded target variable, that is
.. math::
            \text{Var}_{W|\mathcal{D}}[\mathbb{E}_{\tilde{Y}|W, x}[\tilde{Y}]],
where:
- :math:`x` is an observed input variable;
- :math:`\tilde{Y}` is a one-hot encoded random target variable;
- :math:`\mathcal{D}` is the observed training data set;
- :math:`W` denotes the random model parameters.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
rng : Optional[PRNGKeyArray]
A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive epistemic variance for each input.
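        Notes
        -----
        Writing :math:`p_{s,c} = p(y=c|w_s, x)` for :math:`S` posterior samples, a Monte Carlo estimate
        of the :math:`c`-th component is the sample variance of the class probabilities across posterior
        samples,
        .. math::
            \frac{1}{S}\sum_{s=1}^S \big(p_{s,c} - \bar{p}_c\big)^2,
            \qquad \bar{p}_c = \frac{1}{S}\sum_{s=1}^S p_{s,c}.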
"""
return super().epistemic_variance(
inputs_loader, n_posterior_samples, rng, distribute
)
    def variance(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
aleatoric_variances: Optional[jnp.ndarray] = None,
epistemic_variances: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive variance of the one-hot encoded target variable, that is
.. math::
            \text{Var}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}],
where:
- :math:`x` is an observed input variable;
- :math:`\tilde{Y}` is a one-hot encoded random target variable;
- :math:`\mathcal{D}` is the observed training data set.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
        aleatoric_variances : Optional[jnp.ndarray]
            An estimate of the aleatoric predictive variance.
        epistemic_variances : Optional[jnp.ndarray]
            An estimate of the epistemic predictive variance.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive variance for each input.
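        Examples
        --------
        By the law of total variance, the predictive variance is the sum of its aleatoric and epistemic
        components. A minimal sketch, with hypothetical names (`prob_model` a trained probabilistic
        classifier, `test_inputs_loader` an :class:`~fortuna.data.loader.InputsLoader`):
        >>> predictive = prob_model.predictive
        >>> alea = predictive.aleatoric_variance(inputs_loader=test_inputs_loader)
        >>> epis = predictive.epistemic_variance(inputs_loader=test_inputs_loader)
        >>> total = predictive.variance(
        ...     inputs_loader=test_inputs_loader,
        ...     aleatoric_variances=alea,
        ...     epistemic_variances=epis,
        ... )
        >>> # Up to Monte Carlo noise, `total` matches `alea + epis`.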
"""
return super().variance(
inputs_loader,
n_posterior_samples,
aleatoric_variances,
epistemic_variances,
rng,
distribute,
)
    def std(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
variances: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive standard deviation of the one-hot encoded target variable, that is
.. math::
            \sqrt{\text{Var}_{\tilde{Y}|x, \mathcal{D}}[\tilde{Y}]},
where:
- :math:`x` is an observed input variable;
- :math:`\tilde{Y}` is a one-hot encoded random target variable;
- :math:`\mathcal{D}` is the observed training data set.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
        variances : Optional[jnp.ndarray]
            An estimate of the predictive variance.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive standard deviation for each input.
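        Examples
        --------
        If the predictive variance has already been estimated, it can be reused here. A minimal sketch,
        with hypothetical names (`prob_model` a trained probabilistic classifier):
        >>> variances = prob_model.predictive.variance(inputs_loader=test_inputs_loader)
        >>> stds = prob_model.predictive.std(
        ...     inputs_loader=test_inputs_loader, variances=variances
        ... )
        >>> # `stds` is the element-wise square root of `variances`.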
"""
return super().std(
inputs_loader, n_posterior_samples, variances, rng, distribute
)
    def aleatoric_entropy(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive aleatoric entropy, that is
.. math::
-\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],
where:
- :math:`x` is an observed input variable;
- :math:`Y` is a random target variable;
- :math:`\mathcal{D}` is the observed training data set;
- :math:`W` denotes the random model parameters.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
rng : Optional[PRNGKeyArray]
A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive aleatoric entropy for each input.
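        Notes
        -----
        With :math:`S` posterior samples :math:`w_1, \dots, w_S`, the implementation below uses the
        Monte Carlo estimate
        .. math::
            -\sum_{y} \frac{1}{S}\sum_{s=1}^S p(y|w_s, x)\log p(y|w_s, x),
        where the outer sum runs over all target classes.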
"""
ensemble_outputs = self.sample_calibrated_outputs(
inputs_loader=inputs_loader,
n_output_samples=n_posterior_samples,
rng=rng,
distribute=distribute,
)
n_classes = ensemble_outputs.shape[-1]
        @jit
        def _entropy_term(i: int):
            # Log-likelihood of class `i` for every input, under each posterior sample.
            targets = i * jnp.ones(ensemble_outputs.shape[1])
            def _log_lik_fun(outputs):
                return self.likelihood.prob_output_layer.log_prob(outputs, targets)
            log_liks = vmap(_log_lik_fun)(ensemble_outputs)
            # Average p(i|w, x) * log(p(i|w, x)) over posterior samples.
            return jnp.mean(jnp.exp(log_liks) * log_liks, 0)
        return -jnp.sum(vmap(_entropy_term)(jnp.arange(n_classes)), 0)
    def epistemic_entropy(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive epistemic entropy, that is
.. math::
-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})] +
\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],
where:
- :math:`x` is an observed input variable;
- :math:`Y` is a random target variable;
- :math:`\mathcal{D}` is the observed training data set;
- :math:`W` denotes the random model parameters.
Note that the epistemic entropy above is defined as the difference between the predictive entropy and the
aleatoric predictive entropy.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
rng : Optional[PRNGKeyArray]
A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive epistemic entropy for each input.
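        Notes
        -----
        With :math:`S` posterior samples, the implementation below estimates the predictive probabilities
        as :math:`\hat{p}(y|x) = \frac{1}{S}\sum_{s=1}^S p(y|w_s, x)` and returns the Monte Carlo estimate
        .. math::
            -\sum_{y} \hat{p}(y|x)\log\hat{p}(y|x)
            + \sum_{y} \frac{1}{S}\sum_{s=1}^S p(y|w_s, x)\log p(y|w_s, x),
        i.e. the estimated predictive entropy minus the estimated aleatoric entropy.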
"""
ensemble_outputs = self.sample_calibrated_outputs(
inputs_loader=inputs_loader,
n_output_samples=n_posterior_samples,
rng=rng,
distribute=distribute,
)
n_classes = ensemble_outputs.shape[-1]
        @jit
        def _entropy_term(i: int):
            # Log-likelihood of class `i` for every input, under each posterior sample.
            targets = i * jnp.ones(ensemble_outputs.shape[1])
            def _log_lik_fun(outputs):
                return self.likelihood.prob_output_layer.log_prob(outputs, targets)
            log_liks = vmap(_log_lik_fun)(ensemble_outputs)
            # Log of the Monte Carlo predictive probability, log(1/S * sum_s p(i|w_s, x)).
            log_preds = jsp.special.logsumexp(log_liks, 0) - jnp.log(n_posterior_samples)
            # Predictive term minus aleatoric term; summing over classes and negating below
            # gives the predictive entropy minus the aleatoric entropy.
            return jnp.exp(log_preds) * log_preds - jnp.mean(jnp.exp(log_liks) * log_liks, 0)
        return -jnp.sum(vmap(_entropy_term)(jnp.arange(n_classes)), 0)
    def entropy(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Estimate the predictive entropy, that is
.. math::
-\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})],
where:
- :math:`x` is an observed input variable;
- :math:`Y` is a random target variable;
        - :math:`\mathcal{D}` is the observed training data set.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
n_posterior_samples : int
Number of samples to draw from the posterior distribution for each input.
rng : Optional[PRNGKeyArray]
A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
An estimate of the predictive entropy for each input.
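        Examples
        --------
        The predictive entropy decomposes into its aleatoric and epistemic parts. A minimal sketch, with
        hypothetical names (`prob_model` a trained probabilistic classifier, `test_inputs_loader` an
        :class:`~fortuna.data.loader.InputsLoader`):
        >>> predictive = prob_model.predictive
        >>> total = predictive.entropy(inputs_loader=test_inputs_loader)
        >>> alea = predictive.aleatoric_entropy(inputs_loader=test_inputs_loader)
        >>> epis = predictive.epistemic_entropy(inputs_loader=test_inputs_loader)
        >>> # Up to Monte Carlo noise, `total` is close to `alea + epis`.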
"""
ensemble_outputs = self.sample_calibrated_outputs(
inputs_loader=inputs_loader,
n_output_samples=n_posterior_samples,
rng=rng,
distribute=distribute,
)
n_classes = ensemble_outputs.shape[-1]
        @jit
        def _entropy_term(i: int):
            # Log-likelihood of class `i` for every input, under each posterior sample.
            targets = i * jnp.ones(ensemble_outputs.shape[1])
            def _log_lik_fun(outputs):
                return self.likelihood.prob_output_layer.log_prob(outputs, targets)
            log_liks = vmap(_log_lik_fun)(ensemble_outputs)
            # Log of the Monte Carlo predictive probability, log(1/S * sum_s p(i|w_s, x)).
            log_preds = jsp.special.logsumexp(log_liks, 0) - jnp.log(n_posterior_samples)
            return jnp.exp(log_preds) * log_preds
        return -jnp.sum(vmap(_entropy_term)(jnp.arange(n_classes)), 0)
    def credible_set(
self,
inputs_loader: InputsLoader,
n_posterior_samples: int = 30,
error: float = 0.05,
rng: Optional[PRNGKeyArray] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
        Estimate credible sets for the target variable. For each input, the estimated class probabilities
        are sorted in descending order, and classes are included until their cumulative sum exceeds
        :math:`1 - \text{error}`.
Parameters
----------
inputs_loader : InputsLoader
A loader of input data points.
        n_posterior_samples : int
            Number of posterior samples to draw for each input.
        error : float
            The set error. This must be a number between 0 and 1, extremes included. For example,
            `error=0.05` corresponds to a 95% level of credibility.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if available.
Returns
-------
jnp.ndarray
A credibility set for each of the inputs.
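        Examples
        --------
        A minimal sketch, with hypothetical names (`prob_model` a trained probabilistic classifier,
        `test_inputs_loader` an :class:`~fortuna.data.loader.InputsLoader`):
        >>> sets = prob_model.predictive.credible_set(
        ...     inputs_loader=test_inputs_loader, error=0.05
        ... )
        >>> # `sets[i]` lists the most probable classes for input i, just enough of them for their
        >>> # estimated probability mass to exceed 0.95.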
"""
p_classes = self.mean(
inputs_loader=inputs_loader,
n_posterior_samples=n_posterior_samples,
rng=rng,
distribute=distribute,
)
n_classes = jnp.shape(p_classes)[1]
labels = jnp.argsort(p_classes, axis=-1)[:, ::-1]
p_classes_sorted = jnp.sort(p_classes, axis=-1)[:, ::-1]
region_classes = jnp.cumsum(p_classes_sorted, axis=1) > (1 - error)
        # Convert the credible-region mask into per-input sets of labels.
index_true = jnp.argmax(region_classes, axis=1) # first index where True
credible_set = np.zeros(len(index_true), dtype=object)
for s in np.arange(n_classes):
idx = np.where(index_true == s)[0]
credible_set[idx] = labels[idx, : s + 1].tolist()
return credible_set