Source code for fortuna.likelihood.regression

from typing import (
    Optional,
    Union,
)

from jax import vmap
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
import numpy as np

from fortuna.data.loader import InputsLoader
from fortuna.likelihood.base import Likelihood
from fortuna.model.model_manager.regression import RegressionModelManager
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.prob_output_layer.regression import RegressionProbOutputLayer
from fortuna.typing import (
    Array,
    CalibMutable,
    CalibParams,
    Mutable,
    Params,
)


class RegressionLikelihood(Likelihood):
    def __init__(
        self,
        model_manager: RegressionModelManager,
        prob_output_layer: RegressionProbOutputLayer,
        output_calib_manager: Optional[OutputCalibManager] = None,
    ):
        """
        A regression likelihood function class. In this class, the likelihood
        function is additionally assumed to be a probability density function,
        i.e. positive and integrating to 1. The likelihood is formed by three
        objects applied in sequence: the model manager, the output calibrator
        and the probabilistic output layer. The model manager maps parameters
        and inputs to outputs. The output calibrator takes outputs and returns
        a calibrated version of them. The probabilistic output layer describes
        the probability distribution of the calibrated outputs.

        Parameters
        ----------
        model_manager : ModelManager
            A model manager. This object orchestrates the evaluation of the
            models.
        prob_output_layer : ProbOutputLayer
            A probabilistic output layer object. This object characterizes the
            probability distribution of the target variable given the
            calibrated outputs.
        output_calib_manager : Optional[OutputCalibManager]
            An output calibration manager object. It transforms outputs of the
            model manager into a calibrated version of them.
        """
        super().__init__(
            model_manager, prob_output_layer, output_calib_manager=output_calib_manager
        )

    def _batched_mean(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        **kwargs,
    ) -> jnp.ndarray:
        outputs = super()._get_batched_calibrated_outputs(
            params, inputs, mutable, calib_params, calib_mutable, **kwargs
        )
        # The first half of the output columns encodes the predictive mean.
        return outputs[:, : outputs.shape[1] // 2]

    def _batched_mode(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        **kwargs,
    ) -> jnp.ndarray:
        # For this likelihood, the mode coincides with the mean.
        return self._batched_mean(
            params,
            inputs,
            mutable,
            calib_params=calib_params,
            calib_mutable=calib_mutable,
            **kwargs,
        )

    def _batched_variance(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        **kwargs,
    ) -> jnp.ndarray:
        outputs = super()._get_batched_calibrated_outputs(
            params, inputs, mutable, calib_params, calib_mutable, **kwargs
        )
        # The second half of the output columns encodes the log-variance;
        # exponentiating recovers the variance.
        return jnp.exp(outputs[:, outputs.shape[1] // 2 :])
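
    # Note on the output convention assumed by the private methods above:
    # each (calibrated) output row concatenates the predictive mean and the
    # log-variance along the last axis. For an output of shape (batch, 2 * d),
    # columns [:d] hold the mean and columns [d:] the log-variance, so e.g. a
    # row [0.3, -1.0] with d = 1 encodes mean 0.3 and variance
    # exp(-1.0) ~= 0.368.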

    def entropy(
        self,
        params: Params,
        inputs_loader: InputsLoader,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        n_target_samples: Optional[int] = 30,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Estimate the entropy of the likelihood function for each input data
        point, via Monte Carlo sampling of the target variable.
        """
        samples, aux = self.sample(
            n_target_samples,
            params,
            inputs_loader,
            mutable,
            calib_params=calib_params,
            calib_mutable=calib_mutable,
            return_aux=["outputs"],
            rng=rng,
            distribute=distribute,
        )
        outputs = aux["outputs"]

        # Evaluate the log-likelihood of each target sample, vectorized over
        # the sample axis.
        @vmap
        def _log_lik_fun(sample: jnp.ndarray):
            return self.prob_output_layer.log_prob(outputs, sample, **kwargs)

        return -jnp.mean(_log_lik_fun(samples), 0)
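
    # The estimator in `entropy` is the standard Monte Carlo approximation of
    # the differential entropy: with samples y_1, ..., y_n ~ p(.|x),
    #
    #     H[p(.|x)] = -E[log p(y|x)] ~= -(1/n) * sum_i log p(y_i|x),
    #
    # which is what -jnp.mean(_log_lik_fun(samples), 0) computes per input.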

    def quantile(
        self,
        q: Union[float, jnp.ndarray, np.ndarray],
        params: Optional[Params] = None,
        inputs_loader: Optional[InputsLoader] = None,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        n_target_samples: Optional[int] = 30,
        target_samples: Optional[jnp.ndarray] = None,
        rng: Optional[PRNGKeyArray] = None,
        distribute: bool = True,
        **kwargs,
    ) -> Union[float, jnp.ndarray]:
        """
        Estimate the `q`-th quantiles of the likelihood function.

        Parameters
        ----------
        q : Union[float, jnp.ndarray, np.ndarray]
            Quantile or sequence of quantiles to compute, each of which must
            be between 0 and 1 inclusive.
        params : Optional[Params]
            The random parameters of the probabilistic model.
        inputs_loader : Optional[InputsLoader]
            A loader of input data points.
        mutable : Optional[Mutable]
            The mutable objects used to evaluate the models.
        calib_params : Optional[CalibParams]
            The calibration parameters of the probabilistic model.
        calib_mutable : Optional[CalibMutable]
            The calibration mutable objects used to evaluate the calibrators.
        n_target_samples : Optional[int]
            Number of target samples to draw for each input data point.
        target_samples : Optional[jnp.ndarray]
            Samples of the target variable for each input, used to estimate
            the quantiles. If not given, they are drawn internally.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from
            the attributes of this class.
        distribute : bool
            Whether to distribute computation over multiple devices, if
            available.

        Returns
        -------
        jnp.ndarray
            Quantile estimate for each quantile and each input. If multiple
            quantiles `q` are given, the result's first axis is over the
            different quantiles.
        """
        if target_samples is None:
            if params is None or inputs_loader is None:
                raise ValueError(
                    "If `target_samples` is not passed, then `params` and "
                    "`inputs_loader` must be passed."
                )
            target_samples = self.sample(
                n_target_samples,
                params,
                inputs_loader,
                mutable,
                calib_params=calib_params,
                calib_mutable=calib_mutable,
                rng=rng,
                distribute=distribute,
                **kwargs,
            )
        # Empirical quantiles across the sample axis, one estimate per input.
        return jnp.quantile(target_samples, q, axis=0)
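

# Usage sketch (illustrative only, not part of the library): assuming
# `likelihood` is a constructed RegressionLikelihood and that `params` and
# `inputs_loader` come from a trained Fortuna probabilistic model, quantiles
# and entropies can be estimated as below. The variable names are
# hypothetical.
#
#     lo, hi = likelihood.quantile(
#         q=jnp.array([0.05, 0.95]),
#         params=params,
#         inputs_loader=inputs_loader,
#         n_target_samples=100,
#     )
#     ent = likelihood.entropy(params, inputs_loader)
#
# `quantile` returns one row per requested quantile (here the bounds of a 90%
# credible band per input), while `entropy` returns one Monte Carlo entropy
# estimate per input.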