Bring in your own objects

When constructing a probabilistic model, you can bring your own model, prior distribution and output calibrator. Below are some examples.

Bring in your own model

As an example, we show how to construct a custom Convolutional Neural Network (CNN) model as a Flax module.

[1]:
import flax.linen as nn


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
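Because the first Dense layer receives the flattened output of the second pooling stage, its input size depends on the input resolution (Flax infers it when the parameters are initialized). The hypothetical helper below, which is not part of Flax or Fortuna, traces that size under the assumption that nn.Conv keeps its default padding="SAME" with stride 1 and nn.avg_pool its default padding="VALID". For a 64x64 input it gives 16 · 16 · 64 = 16384 features.

```python
# Hypothetical helper (not part of Flax or Fortuna) tracing the flattened
# feature size of the CNN above, assuming Flax defaults: nn.Conv uses
# padding="SAME" with stride 1 (spatial size unchanged) and nn.avg_pool uses
# padding="VALID" (a 2x2 window with stride 2 halves H and W, rounding down).
def flattened_size(height: int, width: int) -> int:
    for _ in range(2):  # two Conv -> ReLU -> AvgPool stages
        height, width = height // 2, width // 2
    channels = 64  # features of the last Conv layer
    return height * width * channels


print(flattened_size(64, 64))  # 16 * 16 * 64 = 16384
```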

We now set the CNN as the model of a probabilistic classifier.

[2]:
from fortuna.prob_model import ProbClassifier

prob_model = ProbClassifier(model=CNN())

Done. Let’s check that it works by initializing its parameters and doing a forward pass.

[3]:
from jax import random
import jax.numpy as jnp

x = jnp.zeros((1, 64, 64, 10))
variables = prob_model.model.init(random.PRNGKey(0), x)
prob_model.model.apply(variables, x)
[3]:
Array([[-2.3025851, -2.3025851, -2.3025851, -2.3025851, -2.3025851,
        -2.3025851, -2.3025851, -2.3025851, -2.3025851, -2.3025851]],      dtype=float32)
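Every entry of the output above equals -log(10) ≈ -2.3025851, i.e. a uniform distribution over the 10 classes. This is expected: with an all-zero input and Flax's default zero-initialized biases, every layer of the CNN outputs zeros, so the final logits are all zero and log_softmax maps them to equal log-probabilities. A minimal check of that last step in plain JAX:

```python
import math

import jax
import jax.numpy as jnp

# All-zero logits, as produced by the CNN on an all-zero input at initialization.
logits = jnp.zeros((1, 10))
log_probs = jax.nn.log_softmax(logits)

# A uniform distribution over 10 classes has log-probability -log(10) per class.
print(log_probs[0, 0], -math.log(10))
```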

Bring in your own prior distribution

As an example, we show how to construct a multi-dimensional uniform prior distribution.

[4]:
from typing import Optional

import jax
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from fortuna.prob_model.prior import Prior
from fortuna.typing import Params
from fortuna.utils.random import generate_rng_like_tree


class Uniform(Prior):
    def log_joint_prob(self, params: Params) -> float:
        # Uniform density on [0, 1]^d: log-density 0 inside the unit box, -inf outside.
        flat_params = ravel_pytree(params)[0]
        inside = jnp.all((flat_params >= 0) & (flat_params <= 1))
        return jnp.where(inside, 0.0, -jnp.inf)

    def sample(self, params_like: Params, rng: Optional[jax.Array] = None) -> Params:
        if rng is None:
            rng = self.rng.get()
        # One independent random key per leaf of the parameter tree.
        keys = generate_rng_like_tree(rng, params_like)
        return tree_map(
            lambda leaf, key: random.uniform(key, leaf.shape, leaf.dtype),
            params_like,
            keys,
        )

In the cell below, we test the uniform prior we just created. To call sample, we set prior.rng to a RandomNumberGenerator object, which generates and updates random keys starting from a given seed. The probabilistic model normally does this for you, so you usually do not need to worry about it; here we set it manually because we are testing a subclass of Prior in isolation.

[5]:
from fortuna.utils.random import RandomNumberGenerator

prior = Uniform()
prior.rng = RandomNumberGenerator(seed=0)
params_in = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([0.5, 1.0]))
params_out = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([3.0, 1.0]))
print(f"log-prob(params_in): {prior.log_joint_prob(params_in)}")
print(f"log-prob(params_out): {prior.log_joint_prob(params_out)}")
print(f"sample: {prior.sample(params_in)}")
log-prob(params_in): 0.0
log-prob(params_out): -inf
sample: {'a': Array([0.1542927], dtype=float32), 'b': Array([[0.10147738]], dtype=float32), 'c': Array([0.5286629 , 0.78815365], dtype=float32)}
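Conceptually, prior.rng acts as a stateful source of fresh random keys. The class below is a hypothetical illustration of that idea in plain JAX, not Fortuna's RandomNumberGenerator implementation: each call to get splits the stored key, keeps one half as the new state and returns the other, so repeated calls yield different keys while the same seed reproduces the same sequence.

```python
import jax


class KeyHolder:
    """Hypothetical illustration of a stateful key source (not Fortuna's RandomNumberGenerator)."""

    def __init__(self, seed: int):
        self._key = jax.random.PRNGKey(seed)

    def get(self) -> jax.Array:
        # Split the stored key: keep one half as the new state, return the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


rng = KeyHolder(seed=0)
first, second = rng.get(), rng.get()
# Different keys give different draws.
print(jax.random.uniform(first), jax.random.uniform(second))
```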

To use your uniform prior in Fortuna, pass it as the prior argument of your ProbClassifier or ProbRegressor.
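For comparison, the sketch below implements the same two operations in plain JAX, outside Fortuna. It assumes the parameters are a pytree of arrays, and split_like_tree is a hypothetical stand-in for generate_rng_like_tree that assigns one independent key per leaf. Testing membership with jnp.all checks directly that every parameter lies in [0, 1], rather than comparing a floating-point mean to 1.0.

```python
from typing import Any

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

Params = Any  # stand-in for fortuna.typing.Params: a pytree of arrays


def uniform_log_prob(params: Params) -> jax.Array:
    # The density of U([0, 1]^d) is 1 inside the unit box, so its log is 0 there.
    flat_params, _ = ravel_pytree(params)
    inside = jnp.all((flat_params >= 0.0) & (flat_params <= 1.0))
    return jnp.where(inside, 0.0, -jnp.inf)


def split_like_tree(key: jax.Array, tree: Params) -> Params:
    # Hypothetical stand-in for generate_rng_like_tree: one independent key per leaf.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    keys = jax.random.split(key, len(leaves))
    return jax.tree_util.tree_unflatten(treedef, list(keys))


def uniform_sample(key: jax.Array, params_like: Params) -> Params:
    # Draw each leaf from U([0, 1)) with the leaf's own shape and dtype.
    keys = split_like_tree(key, params_like)
    return jax.tree_util.tree_map(
        lambda leaf, leaf_key: jax.random.uniform(leaf_key, leaf.shape, leaf.dtype),
        params_like,
        keys,
    )


params_in = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([0.5, 1.0]))
params_out = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([3.0, 1.0]))
print(uniform_log_prob(params_in), uniform_log_prob(params_out))
sample = uniform_sample(jax.random.PRNGKey(0), params_in)
print(sample)
```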

Bring in your own output calibrator

As an example, we show how to construct an MLP output calibrator. Note that an output calibrator can be any Flax model, so you could also use the MLP pre-built in Fortuna. Here, however, we implement one from scratch for educational purposes.

[6]:
import flax.linen as nn
from typing import Tuple


class MLP(nn.Module):
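    # Output sizes of the dense layers; the last entry is the output dimension.
    features: Tuple[int, ...]

    @nn.compact
    def __call__(self, x):
        # Hidden layers: dense followed by ReLU.
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        # Output layer: a single linear dense layer, applied once after the loop.
        x = nn.Dense(self.features[-1])(x)
        return x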

You can now set your MLP as the output calibrator of a probabilistic model or of a calibration model. Here we do it for a calibration regressor.

[7]:
from fortuna.output_calib_model import OutputCalibRegressor

calib_model = OutputCalibRegressor(output_calibrator=MLP(features=(4, 2, 1)))

Done. Let’s check that it works by initializing its parameters and doing a forward pass.

[8]:
from jax import random
import jax.numpy as jnp

x = jnp.ones((1, 10))
variables = calib_model.output_calibrator.init(random.PRNGKey(0), x)
calib_model.output_calibrator.apply(variables, x)
[8]:
Array([[-0.00347538]], dtype=float32)
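To make the control flow of the MLP explicit, here is a framework-free sketch of the same architecture in plain JAX. It is an illustration only: init_mlp and apply_mlp are hypothetical names, and the scaled Gaussian initializer is an arbitrary choice rather than Flax's default. With an input dimension of 10 and features=(4, 2, 1), the three dense layers have (10·4 + 4) + (4·2 + 2) + (2·1 + 1) = 57 parameters.

```python
from typing import List, Sequence, Tuple

import jax
import jax.numpy as jnp

Layer = Tuple[jax.Array, jax.Array]  # (kernel, bias) of one dense layer


def init_mlp(key: jax.Array, in_dim: int, features: Sequence[int]) -> List[Layer]:
    params = []
    dims = [in_dim, *features]
    keys = jax.random.split(key, len(features))
    for layer_key, d_in, d_out in zip(keys, dims[:-1], dims[1:]):
        # Scaled Gaussian kernel and zero bias (an arbitrary choice for this sketch).
        kernel = jax.random.normal(layer_key, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((kernel, jnp.zeros(d_out)))
    return params


def apply_mlp(params: List[Layer], x: jax.Array) -> jax.Array:
    # Hidden layers: affine map followed by ReLU.
    for kernel, bias in params[:-1]:
        x = jax.nn.relu(x @ kernel + bias)
    # Output layer: a single affine map, applied once after the loop.
    kernel, bias = params[-1]
    return x @ kernel + bias


params = init_mlp(jax.random.PRNGKey(0), in_dim=10, features=(4, 2, 1))
n_params = sum(kernel.size + bias.size for kernel, bias in params)
out = apply_mlp(params, jnp.ones((1, 10)))
print(n_params, out.shape)  # 57 parameters, output shape (1, 1)
```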