Source code for fortuna.prob_output_layer.base
import abc
from typing import Optional
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
from fortuna.typing import Array
from fortuna.utils.random import WithRNG
class ProbOutputLayer(WithRNG, abc.ABC):
    r"""
    Abstract probabilistic output layer class.

    It characterizes the distribution of the target variable given the calibrated outputs. It can be seen as
    :math:`p(y|\omega)`, where :math:`y` is a target variable and :math:`\omega` a calibrated output. The
    probabilistic output layer is not joint over different data points, and it acts on them individually.
    """

    @abc.abstractmethod
    def log_prob(self, outputs: Array, targets: Array, **kwargs) -> jnp.ndarray:
        """
        Evaluate the log-probability density function (a.k.a. log-pdf) of target variables for each of the outputs.

        Parameters
        ----------
        outputs : Array
            Calibrated outputs.
        targets : Array
            Target data points.

        Returns
        -------
        jnp.ndarray
            An evaluation of the log-pdf for each output.
        """

    @abc.abstractmethod
    def predict(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Predict target variables starting from the calibrated outputs.

        Parameters
        ----------
        outputs : Array
            Calibrated outputs.

        Returns
        -------
        jnp.ndarray
            A prediction for each output.
        """

    @abc.abstractmethod
    def sample(
        self,
        n_target_samples: int,
        outputs: Array,
        rng: Optional[PRNGKeyArray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Sample target variables for each of the outputs.

        Parameters
        ----------
        n_target_samples : int
            The number of target samples to draw for each of the outputs.
        outputs : Array
            Calibrated outputs.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.

        Returns
        -------
        jnp.ndarray
            Samples of the target variable for each output.
        """

    @abc.abstractmethod
    def mean(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the mean of the target variable given the output with respect to the probabilistic output layer
        distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated mean for each output.
        """

    @abc.abstractmethod
    def mode(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the mode of the target variable given the output with respect to the probabilistic output layer
        distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated mode for each output.
        """

    @abc.abstractmethod
    def variance(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the variance of the target variable given the output with respect to the probabilistic output layer
        distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated variance for each output.
        """

    def std(self, outputs: Array, variances: Optional[Array] = None) -> jnp.ndarray:
        """
        Estimate the standard deviation of the target variable given the output with respect to the probabilistic
        output layer distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.
        variances : Optional[Array]
            Variance for each output. If not passed, it is computed by calling :meth:`variance` on `outputs`.

        Returns
        -------
        jnp.ndarray
            The estimated standard deviation for each output, i.e. the square root of the variance.
        """
        if variances is None:
            variances = self.variance(outputs)
        # The square root is applied in both cases: precomputed variances must
        # also be converted to standard deviations, not returned unchanged.
        return jnp.sqrt(variances)

    @abc.abstractmethod
    def entropy(self, outputs: Array, **kwargs) -> jnp.ndarray:
        """
        Estimate the entropy of the target variable given the output with respect to the probabilistic output layer
        distribution.

        Parameters
        ----------
        outputs : Array
            Model outputs.

        Returns
        -------
        jnp.ndarray
            The estimated entropy for each output.
        """