from typing import (
Dict,
Optional,
Tuple,
Union,
)
from flax.core import FrozenDict
import flax.linen as nn
from flax.training.checkpoints import PyTree
import jax
from jax import random
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
from fortuna.model.model_manager.base import ModelManager
from fortuna.typing import (
Array,
Mutable,
Params,
)
from fortuna.utils.data import get_inputs_from_shape
class RegressionModelManager(ModelManager):
    def __init__(
        self,
        model: nn.Module,
        likelihood_log_variance_model: nn.Module,
        model_editor: Optional[nn.Module] = None,
    ):
        r"""
        Regression model manager class. It orchestrates the forward pass of the model in the probabilistic model.

        Parameters
        ----------
        model : nn.Module
            A model describing the deterministic relation between inputs and outputs. It characterizes the mean model
            of the likelihood function. The outputs must belong to the same space as the target variables.
            Let :math:`x` be input variables and :math:`w` the random model parameters. Then the model is described by
            a function :math:`\mu(w, x)`.
        likelihood_log_variance_model: nn.Module
            A model characterizing the log-variance of a Gaussian likelihood function. The outputs must belong to the
            same space as the target variables. Let :math:`x` be input variables and :math:`w` the random model
            parameters. Then the model is described by a function :math:`\log\sigma^2(w, x)`.
        model_editor : Optional[nn.Module]
            An optional module passed on to the parent model manager. When it is not None, :meth:`apply` and
            :meth:`init` call it with the combined forward pass (mean and log-variance) as its `apply_fn`.
        """
        super().__init__(model, model_editor)
        self.likelihood_log_variance_model = likelihood_log_variance_model

    def apply(
        self,
        params: Params,
        inputs: Array,
        mutable: Optional[Mutable] = None,
        train: bool = False,
        rng: Optional[PRNGKeyArray] = None,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]:
        """
        Run the forward pass of both the mean model and the likelihood log-variance model.

        Parameters
        ----------
        params : Params
            Parameters, indexed by "model" and "lik_log_var" (and by "model_editor" when a model editor is set).
        inputs : Array
            Input data points, given to both models.
        mutable : Optional[Mutable]
            Mutable objects, looked up under the keys "model" and "lik_log_var". Missing keys are filled in with
            None, in place, on the dictionary that is passed.
        train : bool
            Whether the call happens during training. It is forwarded to both models, and auxiliary mutable
            objects are unpacked from the model outputs only when it is True.
        rng : Optional[PRNGKeyArray]
            A random number generator. If given, it is split into one dropout key per model; otherwise no random
            number generators are passed to the models.

        Returns
        -------
        Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]
            The outputs of the mean model and of the log-variance model, concatenated along the last axis.
            If the auxiliary objects are not empty, they are returned as a second element.
        """
        # setup dropout key: each model gets its own key so that their dropout masks are independent
        if rng is not None:
            _, model_dropout_key, lik_log_var_dropout_key = random.split(rng, 3)
            model_rngs = {"dropout": model_dropout_key}
            lik_log_var_rngs = {"dropout": lik_log_var_dropout_key}
        else:
            model_rngs = None
            lik_log_var_rngs = None

        if mutable is not None:
            # NOTE: this writes both keys back into the caller's dictionary (existing behaviour, kept as is).
            mutable["model"] = mutable.get("model")
            mutable["lik_log_var"] = mutable.get("lik_log_var")
            model_mutable = mutable["model"]
            lik_log_var_mutable = mutable["lik_log_var"]
        else:
            model_mutable = False
            lik_log_var_mutable = False

        # auxiliary mutable objects are unpacked only in training mode and only where they were provided
        model_has_aux = train and mutable is not None and model_mutable is not None
        lik_log_var_has_aux = (
            train and mutable is not None and lik_log_var_mutable is not None
        )

        def apply_fn(p, x, m_mutable, llv_mutable):
            model_outputs = self.model.apply(
                p["model"],
                x,
                train=train,
                mutable=m_mutable,
                rngs=model_rngs,
            )
            lik_log_var_outputs = self.likelihood_log_variance_model.apply(
                p["lik_log_var"],
                x,
                train=train,
                mutable=llv_mutable,
                rngs=lik_log_var_rngs,
            )

            # a tuple output carries (outputs, updated mutable); keep the update only when it is expected
            if model_has_aux:
                model_outputs, m_mutable = model_outputs
            elif isinstance(model_outputs, tuple):
                model_outputs = model_outputs[0]

            if lik_log_var_has_aux:
                lik_log_var_outputs, llv_mutable = lik_log_var_outputs
            elif isinstance(lik_log_var_outputs, tuple):
                lik_log_var_outputs = lik_log_var_outputs[0]

            self._check_outputs(model_outputs, lik_log_var_outputs)

            mutable_aux = {}
            if m_mutable:
                mutable_aux["model"] = m_mutable
            if llv_mutable:
                # store the log-variance model's own mutable objects, not the mean model's
                mutable_aux["lik_log_var"] = llv_mutable
            aux = {"mutable": mutable_aux} if mutable_aux else {}

            return jnp.concatenate(
                (model_outputs, lik_log_var_outputs), axis=-1
            ), FrozenDict(aux)

        if self.model_editor is not None:
            outputs, aux = self.model_editor.apply(
                params["model_editor"],
                apply_fn=lambda p, x: apply_fn(
                    p,
                    x,
                    m_mutable=model_mutable,
                    llv_mutable=lik_log_var_mutable,
                ),
                model_params=params,
                x=inputs,
                has_aux=True,
            )
        else:
            outputs, aux = apply_fn(
                params,
                inputs,
                m_mutable=model_mutable,
                llv_mutable=lik_log_var_mutable,
            )

        if len(aux) > 0:
            return outputs, aux
        return outputs

    def init(
        self, input_shape: Tuple, rng: Optional[PRNGKeyArray] = None, **kwargs
    ) -> Dict[str, FrozenDict]:
        """
        Initialize the parameters of the mean model, of the log-variance model and, if set, of the model editor.

        Parameters
        ----------
        input_shape : Tuple
            The shape of a single input, without the batch dimension. Both models are initialized on an array of
            zeros with shape `(1,) + input_shape`.
        rng : Optional[PRNGKeyArray]
            A random number generator. If None, one is taken from `self.rng.get()`.
        **kwargs
            Additional keyword arguments forwarded to the `init` method of both models.

        Returns
        -------
        Dict[str, FrozenDict]
            The initialized parameters, indexed by "model" and "lik_log_var", plus "model_editor" when a model
            editor is set.
        """
        if rng is None:
            rng = self.rng.get()
        (
            rng,
            model_params_key,
            model_dropout_key,
            lik_log_var_params_key,
            lik_log_var_dropout_key,
        ) = random.split(rng, 5)
        model_rngs = {"params": model_params_key, "dropout": model_dropout_key}
        lik_log_var_rngs = {
            "params": lik_log_var_params_key,
            # dedicated dropout key, so that the params key is not reused for dropout
            "dropout": lik_log_var_dropout_key,
        }

        dummy_inputs = jnp.zeros((1,) + input_shape)
        params = dict(
            model=self.model.init(model_rngs, dummy_inputs, **kwargs),
            lik_log_var=self.likelihood_log_variance_model.init(
                lik_log_var_rngs, dummy_inputs, **kwargs
            ),
        )

        def apply_fn(p, x):
            model_outputs = self.model.apply(
                p["model"],
                x,
                rngs=model_rngs,
            )
            lik_log_var_outputs = self.likelihood_log_variance_model.apply(
                p["lik_log_var"],
                x,
                rngs=lik_log_var_rngs,
            )
            self._check_outputs(model_outputs, lik_log_var_outputs)
            return jnp.concatenate((model_outputs, lik_log_var_outputs), axis=-1)

        if self.model_editor is not None:
            # `rng` is always set at this point, so it can be split directly
            rng, params_key, dropout_key = random.split(rng, 3)
            rngs = {"params": params_key, "dropout": dropout_key}
            # the editor is initialized against the parameters of the two models only
            params["model_editor"] = self.model_editor.init(
                rngs,
                apply_fn=apply_fn,
                model_params=params,
                x=get_inputs_from_shape(input_shape),
                has_aux=False,
            )
        return params

    def _check_outputs(
        self, model_outputs: jnp.ndarray, lik_log_var_outputs: jnp.ndarray
    ) -> None:
        """
        Check that the outputs of the two models can be concatenated along the last axis.

        Parameters
        ----------
        model_outputs : jnp.ndarray
            Outputs of the mean model.
        lik_log_var_outputs : jnp.ndarray
            Outputs of the likelihood log-variance model.

        Raises
        ------
        ValueError
            If the size of the last dimension differs between the two outputs.
        """
        model_dim = model_outputs.shape[-1]
        lik_log_var_dim = lik_log_var_outputs.shape[-1]
        if model_dim != lik_log_var_dim:
            raise ValueError(
                "The output dimensions of `model` and `likelihood_log_variance_model` must be the same. "
                f"However, {model_dim} and {lik_log_var_dim} were found, respectively."
            )