Source code for fortuna.prob_model.joint.base

from typing import (
    List,
    Optional,
    Tuple,
    Union,
)

from flax.core import FrozenDict
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp

from fortuna.likelihood.base import Likelihood
from fortuna.model.model_manager.state import ModelManagerState
from fortuna.output_calibrator.output_calib_manager.state import OutputCalibManagerState
from fortuna.prob_model.joint.state import JointState
from fortuna.prob_model.prior.base import Prior
from fortuna.typing import (
    Batch,
    CalibMutable,
    CalibParams,
    Mutable,
    Params,
    Shape,
)
from fortuna.utils.data import get_inputs_from_shape
from fortuna.utils.random import WithRNG


class Joint(WithRNG):
    def __init__(self, prior: Prior, likelihood: Likelihood):
        r"""
        Joint distribution class. This is the joint distribution of target variables and random model
        parameters given input variables and calibration parameters. It is given by

        .. math::
            p(y, w|x, \phi),

        where:

        - :math:`x` is an observed input variable;
        - :math:`y` is an observed target variable;
        - :math:`w` denotes the random model parameters;
        - :math:`\phi` denotes the calibration parameters.

        Parameters
        ----------
        prior : Prior
            A prior distribution.
        likelihood : Likelihood
            A likelihood function.
        """
        self.prior = prior
        self.likelihood = likelihood

    def _batched_log_joint_prob(
        self,
        params: Params,
        batch: Batch,
        n_data: int,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        return_aux: Optional[List[str]] = None,
        train: Optional[bool] = False,
        outputs: Optional[jnp.ndarray] = None,
        rng: Optional[PRNGKeyArray] = None,
        **kwargs,
    ) -> Union[float, Tuple[float, dict]]:
        """
        Evaluate the batched log-joint probability density function (a.k.a. log-pdf).

        Parameters
        ----------
        params : Params
            The random parameters of the probabilistic model.
        batch : Batch
            A batch of data points.
        n_data : int
            The total number of data points over which the likelihood is joint. This is used to rescale the
            batched log-likelihood function so that it better approximates the full log-likelihood.
        mutable : Optional[Mutable]
            The mutable objects used to evaluate the models.
        calib_params : Optional[CalibParams]
            The calibration parameters of the probabilistic model.
        calib_mutable : Optional[CalibMutable]
            The calibration mutable objects used to evaluate the calibrators.
        return_aux : Optional[List[str]]
            The auxiliary objects to return. We support 'outputs' and 'calib_mutable'. If this argument is
            not given, no auxiliary object is returned.
        train : Optional[bool]
            Whether the method is called during training.
        outputs : Optional[jnp.ndarray]
            Pre-computed batch of outputs.
        rng : Optional[PRNGKeyArray]
            A random number generator. If not passed, this will be taken from the attributes of this class.

        Returns
        -------
        Union[float, Tuple[float, dict]]
            The evaluation of the batched log-joint pdf. If `return_aux` is given, the corresponding
            auxiliary objects are also returned.
        """
        if return_aux is None:
            return_aux = []
        log_prior = self.prior.log_joint_prob(params)
        outs = self.likelihood._batched_log_joint_prob(
            params,
            batch,
            n_data,
            mutable=mutable,
            calib_params=calib_params,
            calib_mutable=calib_mutable,
            return_aux=return_aux,
            train=train,
            outputs=outputs,
            rng=rng,
            **kwargs,
        )
        if len(return_aux) > 0:
            batched_log_lik, aux = outs
            return batched_log_lik + log_prior, aux
        batched_log_lik = outs
        return batched_log_lik + log_prior

    def _batched_negative_log_joint_prob(
        self,
        params: Params,
        batch: Batch,
        n_data: int,
        mutable: Optional[Mutable] = None,
        calib_params: Optional[CalibParams] = None,
        calib_mutable: Optional[CalibMutable] = None,
        return_aux: Optional[List[str]] = None,
        train: Optional[bool] = False,
        outputs: Optional[jnp.ndarray] = None,
        rng: Optional[PRNGKeyArray] = None,
        **kwargs,
    ) -> Union[float, Tuple[float, dict]]:
        # Normalize the default so the length check below does not fail on None.
        if return_aux is None:
            return_aux = []
        outs = self._batched_log_joint_prob(
            params,
            batch,
            n_data,
            mutable,
            calib_params,
            calib_mutable,
            return_aux,
            train,
            outputs,
            rng,
            **kwargs,
        )
        if len(return_aux) > 0:
            loss, aux = outs
            loss *= -1
            return loss, aux
        return -outs
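    # The two methods above evaluate, for a batch drawn from a dataset of `n_data`
    # points, the (rescaled) batched log-joint density and its negation:
    #
    #     log p(y, w | x, phi) ≈ log p(y | w, x, phi) + log p(w),
    #
    # where the likelihood term is delegated to `self.likelihood._batched_log_joint_prob`
    # (which, per the docstring, uses `n_data` to rescale the batch contribution towards
    # the full log-likelihood) and the prior term to `self.prior.log_joint_prob`. The
    # negative version is the quantity typically minimized as a training loss.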
    def init(self, input_shape: Shape, **kwargs) -> JointState:
        """
        Initialize the state of the joint distribution.

        Parameters
        ----------
        input_shape : Shape
            The shape of the input variable.

        Returns
        -------
        JointState
            A state of the joint distribution.
        """
        oms = ModelManagerState.init_from_dict(
            self.likelihood.model_manager.init(
                input_shape, rng=self.rng.get(), **kwargs
            )
        )
        inputs = get_inputs_from_shape(input_shape)
        outputs = self.likelihood.model_manager.apply(
            params=oms.params, inputs=inputs, mutable=oms.mutable
        )
        output_dim = (
            outputs[0].shape[-1]
            if isinstance(outputs, (list, tuple))
            else outputs.shape[-1]
        )
        ocms = OutputCalibManagerState.init_from_dict(
            FrozenDict(
                output_calibrator=self.likelihood.output_calib_manager.init(
                    output_dim=output_dim
                )
            )
        )
        return JointState.init_from_states(oms, ocms)
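
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library source). It shows
# how `Joint` might be exercised, assuming `prior` and `likelihood` are
# already-constructed `Prior` and `Likelihood` instances and `batch` is an
# `(inputs, targets)` pair of JAX arrays. The helper name below is
# hypothetical, and the `JointState` attribute names used (`params`,
# `mutable`, `calib_params`, `calib_mutable`) are assumed from
# `JointState.init_from_states` above.
def _example_joint_usage(
    prior: Prior, likelihood: Likelihood, batch: Batch
) -> Tuple[float, float]:
    joint = Joint(prior=prior, likelihood=likelihood)

    # Initialize model and output-calibrator states from the shape of a single input.
    state = joint.init(input_shape=batch[0].shape[1:])

    # Batched log-joint density, rescaled towards a dataset of `n_data` points
    # (here taken to be just the batch size, for illustration).
    log_prob = joint._batched_log_joint_prob(
        params=state.params,
        batch=batch,
        n_data=batch[0].shape[0],
        mutable=state.mutable,
        calib_params=state.calib_params,
        calib_mutable=state.calib_mutable,
    )

    # Its negative counterpart is what an optimizer would minimize as a loss.
    loss = joint._batched_negative_log_joint_prob(
        params=state.params,
        batch=batch,
        n_data=batch[0].shape[0],
        mutable=state.mutable,
        calib_params=state.calib_params,
        calib_mutable=state.calib_mutable,
    )
    return log_prob, loss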