Source code for fortuna.prob_model.predictive.regression

from typing import (
    List,
    Optional,
    Union,
)

from jax import (
    lax,
    random,
    vmap,
)
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
import jax.scipy as jsp

from fortuna.data.loader import InputsLoader
from fortuna.prob_model.posterior.base import Posterior
from fortuna.prob_model.predictive.base import Predictive
from fortuna.typing import Array


class RegressionPredictive(Predictive):
    def __init__(self, posterior: Posterior):
        """
        Regression predictive distribution class.

        Parameters
        ----------
        posterior : Posterior
             A posterior distribution object.
        """
        super().__init__(posterior)

    def mode(
        self,
        inputs_loader: InputsLoader,
        n_posterior_samples: int = 30,
        means: Optional[jnp.ndarray] = None,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
    ) -> jnp.ndarray:
        """
        Estimate the predictive mode.

        The predictive mean is used as the mode estimate. Note that this
        coincides with the true mode only when the predictive density is
        symmetric and unimodal.

        Parameters
        ----------
        inputs_loader : InputsLoader
            A loader of input data points.
        n_posterior_samples : int
            Number of samples to draw from the posterior distribution for each
            input. Ignored when `means` is given.
        means : Optional[jnp.ndarray]
            Pre-computed predictive means. If given, they are returned as they
            are and nothing is computed.
        rng : Optional[PRNGKeyArray]
            A random number generator, forwarded to `mean`. Ignored when
            `means` is given.
        distribute: bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            An estimate of the predictive mode for each input.
        """
        if means is not None:
            return means
        return self.mean(
            inputs_loader=inputs_loader,
            n_posterior_samples=n_posterior_samples,
            rng=rng,
            distribute=distribute,
        )

    def _sample_outputs_and_targets(
        self,
        inputs_loader: InputsLoader,
        n_posterior_samples: int,
        n_target_samples: int,
        rng: Optional[PRNGKeyArray],
        distribute: bool,
    ):
        """
        Draw calibrated outputs from the posterior and, for each of them,
        target samples from the likelihood's output layer.

        Shared by the entropy estimators so that they all consume random keys
        in exactly the same way.

        Returns
        -------
        Tuple[jnp.ndarray, jnp.ndarray]
            `ensemble_outputs`, whose leading axis runs over the
            `n_posterior_samples` posterior samples, and
            `ensemble_target_samples`, which shares that leading axis. Its
            second axis presumably runs over the `n_target_samples` target
            samples (TODO confirm against `prob_output_layer.sample`).
        """
        if rng is None:
            rng = self.rng.get()
        key_outputs, *keys_targets = random.split(rng, 1 + n_posterior_samples)
        ensemble_outputs = self.sample_calibrated_outputs(
            inputs_loader=inputs_loader,
            n_output_samples=n_posterior_samples,
            rng=key_outputs,
            distribute=distribute,
        )
        # One fresh key per posterior sample, so that target draws are
        # independent across ensemble members.
        ensemble_target_samples = lax.map(
            lambda variables: self.likelihood.prob_output_layer.sample(
                n_target_samples, variables[0], rng=variables[1]
            ),
            (ensemble_outputs, jnp.array(keys_targets)),
        )
        return ensemble_outputs, ensemble_target_samples

    def _log_predictive(
        self,
        ensemble_outputs: jnp.ndarray,
        target_samples: jnp.ndarray,
        n_posterior_samples: int,
    ) -> jnp.ndarray:
        """
        Monte Carlo estimate of :math:`\\log p(y|x, \\mathcal{D})` for each
        target sample :math:`y` along the leading axis of `target_samples`.

        The estimate is the log of the average over the posterior samples of
        :math:`p(y|W, x)`. It is computed with `logsumexp` so that it stays
        numerically stable.
        """

        @vmap
        def _log_pred_fun(target_sample: jnp.ndarray):
            logps = self.likelihood.prob_output_layer.log_prob(
                ensemble_outputs, target_sample
            )
            return jsp.special.logsumexp(logps, 0) - jnp.log(n_posterior_samples)

        return _log_pred_fun(target_samples)

    @staticmethod
    def _average_over_posterior_samples(fun, n_posterior_samples: int) -> jnp.ndarray:
        """
        Run `fun(i, carry)` for `i = 0, ..., n_posterior_samples - 1`, then
        divide the final carry by `n_posterior_samples`.

        The first iteration is evaluated eagerly with a scalar `0.0` carry.
        This gives `lax.fori_loop` an initial carry whose shape and dtype
        already match what `fun` returns, because `fori_loop` requires the
        carry type to stay fixed across iterations.
        """
        curr_sum = fun(0, 0.0)
        curr_sum = lax.fori_loop(1, n_posterior_samples, fun, curr_sum)
        return curr_sum / n_posterior_samples

    def aleatoric_entropy(
        self,
        inputs_loader: InputsLoader,
        n_posterior_samples: int = 30,
        n_target_samples: int = 30,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
    ) -> jnp.ndarray:
        r"""
        Estimate the predictive aleatoric entropy, that is

        .. math::
            -\mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],

        where:
         - :math:`x` is an observed input variable;
         - :math:`Y` is a random target variable;
         - :math:`\mathcal{D}` is the observed training data set;
         - :math:`W` denotes the random model parameters.

        Parameters
        ----------
        inputs_loader : InputsLoader
            A loader of input data points.
        n_posterior_samples : int
            Number of samples to draw from the posterior distribution for each
            input.
        n_target_samples: int
            Number of target samples to draw for each input.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        distribute: bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            An estimate of the predictive aleatoric entropy for each input.
        """
        ensemble_outputs, ensemble_target_samples = self._sample_outputs_and_targets(
            inputs_loader=inputs_loader,
            n_posterior_samples=n_posterior_samples,
            n_target_samples=n_target_samples,
            rng=rng,
            distribute=distribute,
        )

        def fun(i, _curr_sum):
            # Targets drawn from posterior sample i are scored under that
            # same sample, giving a Monte Carlo estimate of E_{Y|W_i,x}[log p].
            log_liks = self.likelihood.prob_output_layer.log_prob(
                ensemble_outputs[i], ensemble_target_samples[i]
            )
            return _curr_sum - jnp.mean(log_liks, 0)

        return self._average_over_posterior_samples(fun, n_posterior_samples)

    def epistemic_entropy(
        self,
        inputs_loader: InputsLoader,
        n_posterior_samples: int = 30,
        n_target_samples: int = 30,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
    ) -> jnp.ndarray:
        r"""
        Estimate the predictive epistemic entropy, that is

        .. math::
            -\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})] +
            \mathbb{E}_{W|\mathcal{D}}[\mathbb{E}_{Y|W, x}[\log p(Y|W, x)]],

        where:
         - :math:`x` is an observed input variable;
         - :math:`Y` is a random target variable;
         - :math:`\mathcal{D}` is the observed training data set;
         - :math:`W` denotes the random model parameters.

        Note that the epistemic entropy above is defined as the difference
        between the predictive entropy and the aleatoric predictive entropy.

        Parameters
        ----------
        inputs_loader : InputsLoader
            A loader of input data points.
        n_posterior_samples : int
            Number of samples to draw from the posterior distribution for each
            input.
        n_target_samples: int
            Number of target samples to draw for each input.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        distribute: bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            An estimate of the predictive epistemic entropy for each input.
        """
        ensemble_outputs, ensemble_target_samples = self._sample_outputs_and_targets(
            inputs_loader=inputs_loader,
            n_posterior_samples=n_posterior_samples,
            n_target_samples=n_target_samples,
            rng=rng,
            distribute=distribute,
        )

        def fun(i, _curr_sum):
            log_preds = self._log_predictive(
                ensemble_outputs, ensemble_target_samples[i], n_posterior_samples
            )
            log_liks = self.likelihood.prob_output_layer.log_prob(
                ensemble_outputs[i], ensemble_target_samples[i]
            )
            # Predictive entropy term minus aleatoric term, on the same draws.
            return _curr_sum - jnp.mean(log_preds - log_liks, 0)

        return self._average_over_posterior_samples(fun, n_posterior_samples)

    def entropy(
        self,
        inputs_loader: InputsLoader,
        n_posterior_samples: int = 30,
        n_target_samples: int = 30,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
    ) -> jnp.ndarray:
        r"""
        Estimate the predictive entropy, that is

        .. math::
            -\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})],

        where:
         - :math:`x` is an observed input variable;
         - :math:`Y` is a random target variable;
         - :math:`\mathcal{D}` is the observed training data set;
         - :math:`W` denotes the random model parameters.

        Parameters
        ----------
        inputs_loader : InputsLoader
            A loader of input data points.
        n_posterior_samples : int
            Number of samples to draw from the posterior distribution for each
            input.
        n_target_samples: int
            Number of target samples to draw for each input.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        distribute: bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            An estimate of the predictive entropy for each input.
        """
        ensemble_outputs, ensemble_target_samples = self._sample_outputs_and_targets(
            inputs_loader=inputs_loader,
            n_posterior_samples=n_posterior_samples,
            n_target_samples=n_target_samples,
            rng=rng,
            distribute=distribute,
        )

        def fun(i, _curr_sum):
            log_preds = self._log_predictive(
                ensemble_outputs, ensemble_target_samples[i], n_posterior_samples
            )
            return _curr_sum - jnp.mean(log_preds, 0)

        return self._average_over_posterior_samples(fun, n_posterior_samples)

    def credible_interval(
        self,
        inputs_loader: InputsLoader,
        n_target_samples: int = 30,
        error: float = 0.05,
        interval_type: str = "two-tailed",
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
    ) -> jnp.ndarray:
        r"""
        Estimate credible intervals for the target variable. This is supported
        only if the target variable is scalar.

        Parameters
        ----------
        inputs_loader : InputsLoader
            A loader of input data points.
        n_target_samples: int
            Number of target samples to draw for each input.
        error: float
            The interval error. This must be a number between 0 and 1, extremes
            included. For example, `error=0.05` corresponds to a 95% level of
            credibility.
        interval_type: str
            The interval type. We support "two-tailed" (default),
            "right-tailed" and "left-tailed". "two-tailed" splits `error`
            evenly between the two tails. "left-tailed" returns the
            `error`-quantile, i.e. the lower bound. "right-tailed" returns the
            `(1 - error)`-quantile, i.e. the upper bound.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        distribute: bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            A credibility interval for each of the inputs. For "two-tailed",
            the last axis has size 2 and holds the lower and upper bounds.
            Otherwise, the single bound is returned with a trailing target
            axis of size 1.

        Raises
        ------
        ValueError
            If `interval_type` is not supported, if `error` is not in
            [0, 1], or if the target variable is not scalar.
        """
        supported_types = ["two-tailed", "right-tailed", "left-tailed"]
        if interval_type not in supported_types:
            raise ValueError(
                "`interval_type={}` not recognised. Please choose among the following supported types: {}.".format(
                    interval_type, supported_types
                )
            )
        if not 0 <= error <= 1:
            raise ValueError(
                "`error` must be between 0 and 1, extremes included; got {}.".format(
                    error
                )
            )

        if interval_type == "two-tailed":
            q = jnp.array([0.5 * error, 1 - 0.5 * error])
        elif interval_type == "left-tailed":
            q = error
        else:  # "right-tailed"
            q = 1 - error

        qq = self.quantile(
            q=q,
            inputs_loader=inputs_loader,
            n_target_samples=n_target_samples,
            rng=rng,
            distribute=distribute,
        )
        if qq.shape[-1] != 1:
            raise ValueError(
                """Credibility intervals are only supported for scalar target variables."""
            )
        if interval_type == "two-tailed":
            # `qq` has a leading axis over the two quantiles. Drop the scalar
            # target axis and place lower/upper bounds side by side per input.
            # This is vectorised, unlike zipping device arrays element-wise, and
            # keeps a (0, 2) shape when there are no inputs.
            lq, uq = qq.squeeze(2)
            return jnp.stack((lq, uq), axis=1)
        return qq

    def quantile(
        self,
        q: Union[float, Array, List],
        inputs_loader: InputsLoader,
        n_target_samples: Optional[int] = 30,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
    ) -> Union[float, jnp.ndarray]:
        r"""
        Estimate the `q`-th quantiles of the predictive probability density
        function.

        Parameters
        ----------
        q : Union[float, Array, List]
            Quantile or sequence of quantiles to compute. Each of these must be
            between 0 and 1, extremes included. Python lists and tuples are
            converted to arrays.
        inputs_loader : InputsLoader
            A loader of input data points.
        n_target_samples : int
            Number of target samples to sample for each input data point.
        rng: Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        distribute: bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            Quantile estimate for each quantile and each input. If multiple
            quantiles `q` are given, the result's first axis is over different
            quantiles.
        """
        if isinstance(q, (list, tuple)):
            # JAX array functions reject Python sequences as array arguments.
            q = jnp.array(q)
        samples = self.sample(
            inputs_loader=inputs_loader,
            n_target_samples=n_target_samples,
            rng=rng,
            distribute=distribute,
        )
        # Quantiles are taken across the target-sample axis.
        return jnp.quantile(samples, q, axis=0)