Bring in your own objects
When constructing a probabilistic model, you can bring your own model, prior distribution and output calibrator. The examples below show how to do each.
Bring in your own model
As an example, we show how to construct an arbitrary Convolutional Neural Network (CNN) model.
[1]:
import flax.linen as nn


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
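To get a feel for what the flatten step hands to the first dense layer, the short sketch below traces the spatial size through the two pooling layers. It assumes a 64 x 64 input (the size used for the forward pass later in this notebook) and relies on Flax's defaults: `nn.Conv` pads with `'SAME'` at unit stride, so it keeps the spatial size, while `nn.avg_pool` uses no padding. The helper `pooled_size` is a hypothetical name introduced only for this illustration.

```python
def pooled_size(size: int, window: int = 2, stride: int = 2) -> int:
    """Output length along one axis of a pooling layer without padding."""
    return (size - window) // stride + 1


# Assumed input resolution; the convolutions leave it unchanged ('SAME' padding).
height = width = 64

# Two 2x2 average-pooling layers with stride 2 each halve the resolution.
for _ in range(2):
    height, width = pooled_size(height), pooled_size(width)

# The second convolution outputs 64 channels, so flattening yields this many features.
flattened_size = height * width * 64
print(height, width, flattened_size)
```

Under these assumptions, the first dense layer receives one vector of that length per example.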
We now set it as a model in a probabilistic classifier.
[2]:
from fortuna.prob_model import ProbClassifier
prob_model = ProbClassifier(model=CNN())
WARNING:root:No module named 'transformer' is installed. If you are not working with models from the `transformers` library ignore this warning, otherwise install the optional 'transformers' dependency of Fortuna using poetry. You can do so by entering: `poetry install --extras 'transformers'`.
Done. Let’s check that it works by initializing its parameters and doing a forward pass.
[3]:
from jax import random
import jax.numpy as jnp
x = jnp.zeros((1, 64, 64, 10))
variables = prob_model.model.init(random.PRNGKey(0), x)
prob_model.model.apply(variables, x)
[3]:
Array([[-2.3025851, -2.3025851, -2.3025851, -2.3025851, -2.3025851,
-2.3025851, -2.3025851, -2.3025851, -2.3025851, -2.3025851]], dtype=float32)
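Every entry of this output is the same because the input is all zeros and Flax initializes biases to zero by default, so the ten logits reaching `log_softmax` are all zero. The sketch below reproduces that last step in isolation with plain JAX; it is a sanity check of the arithmetic, not part of Fortuna's API.

```python
import jax
import jax.numpy as jnp

# Ten identical logits, as produced by the network on an all-zero input.
logits = jnp.zeros((1, 10))

# log_softmax of identical logits is the log of a uniform distribution.
log_probs = jax.nn.log_softmax(logits)
expected = -jnp.log(10.0)

# Exponentiating recovers probabilities that sum to one.
probs_sum = jnp.exp(log_probs).sum(axis=-1)
print(log_probs, expected, probs_sum)
```

Each entry equals log(1/10), which matches the value printed above.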
Bring in your own prior distribution
As an example, we show how to construct a multi-dimensional uniform prior distribution.
[4]:
from typing import Optional

import jax
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from fortuna.prob_model.prior import Prior
from fortuna.typing import Params
from fortuna.utils.random import generate_rng_like_tree


class Uniform(Prior):
    def log_joint_prob(self, params: Params) -> float:
        # The log-density is 0 if every parameter lies in [0, 1], and -inf otherwise.
        flat_params = ravel_pytree(params)[0]
        in_support = jnp.all((flat_params >= 0) & (flat_params <= 1))
        return jnp.where(in_support, 0.0, -jnp.inf)

    def sample(self, params_like: Params, rng: Optional[jax.Array] = None) -> Params:
        if rng is None:
            rng = self.rng.get()
        # Random keys arranged in the same tree structure as params_like.
        keys = generate_rng_like_tree(rng, params_like)
        return tree_map(
            lambda l, k: random.uniform(k, l.shape, l.dtype),
            params_like,
            keys,
        )
The code below tests the uniform prior we just created. To call sample, we set prior.rng to a RandomNumberGenerator object, which manages and updates random number generator keys starting from a random seed. The probabilistic model normally does this for you, so you rarely need to handle it yourself. It is needed here only because we are testing a derived class of Prior in isolation.
[5]:
from fortuna.utils.random import RandomNumberGenerator
prior = Uniform()
prior.rng = RandomNumberGenerator(seed=0)
params_in = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([0.5, 1.0]))
params_out = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([3.0, 1.0]))
print(f"log-prob(params_in): {prior.log_joint_prob(params_in)}")
print(f"log-prob(params_out): {prior.log_joint_prob(params_out)}")
print(f"sample: {prior.sample(params_in)}")
log-prob(params_in): 0.0
log-prob(params_out): -inf
sample: {'a': Array([0.1542927], dtype=float32), 'b': Array([[0.10147738]], dtype=float32), 'c': Array([0.5286629 , 0.78815365], dtype=float32)}
To use your uniform prior in Fortuna, set it as the prior parameter of your ProbClassifier or ProbRegressor.
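If you want to check the logic of such a prior without Fortuna installed, the following sketch reimplements the two methods as plain JAX functions. The names `uniform_log_prob` and `sample_like` are hypothetical, and splitting one key per leaf is only an assumption about what a helper like `generate_rng_like_tree` provides, not a description of its implementation.

```python
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_flatten, tree_unflatten


def uniform_log_prob(params):
    """Log-density of a uniform prior on [0, 1] for every parameter."""
    flat, _ = ravel_pytree(params)
    inside = jnp.all((flat >= 0.0) & (flat <= 1.0))
    return jnp.where(inside, 0.0, -jnp.inf)


def sample_like(rng, params_like):
    """Draw a uniform sample with the same tree structure as params_like."""
    leaves, treedef = tree_flatten(params_like)
    # Assumption: one independent key per leaf of the parameter tree.
    keys = random.split(rng, len(leaves))
    samples = [
        random.uniform(key, leaf.shape, leaf.dtype)
        for key, leaf in zip(keys, leaves)
    ]
    return tree_unflatten(treedef, samples)


params = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([0.5, 1.0]))
sample = sample_like(random.PRNGKey(0), params)

# Samples lie in [0, 1), so they always fall inside the prior's support.
log_p_sample = uniform_log_prob(sample)
log_p_outside = uniform_log_prob(dict(a=jnp.array([3.0])))
```

A useful property to check for any custom prior is that its own samples never receive a log-probability of minus infinity.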
Bring in your own output calibrator
As an example, we show how to construct an MLP output calibrator. Note that an output calibrator is just a Flax model, so you could also use the MLP pre-built in Fortuna. Here, however, we implement one from scratch for educational purposes.
[6]:
from typing import Tuple

import flax.linen as nn


class MLP(nn.Module):
    # Layer widths; the last entry is the output dimension.
    features: Tuple[int, ...]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x
You can now set your MLP as the output calibrator of a probabilistic model or of a calibration model. Here we do it for a calibration regressor.
[7]:
from fortuna.output_calib_model import OutputCalibRegressor
calib_model = OutputCalibRegressor(output_calibrator=MLP(features=(4, 2, 1)))
Done. Let’s check that it works by initializing its parameters and doing a forward pass.
[8]:
from jax import random
import jax.numpy as jnp
x = jnp.ones((1, 10))
variables = calib_model.output_calibrator.init(random.PRNGKey(0), x)
calib_model.output_calibrator.apply(variables, x)
[8]:
Array([[-0.00347538]], dtype=float32)
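To make explicit what the calibrator computes, here is a minimal plain-JAX sketch of the same layer layout: ReLU after every layer except the last. The functions `init_mlp` and `apply_mlp` are hypothetical, and the scaled-normal weight initialization is a simplifying assumption rather than Flax's default initializer, so the numbers will differ from the output above.

```python
import jax
import jax.numpy as jnp
from jax import random


def init_mlp(rng, in_dim, features):
    """Create one (weights, bias) pair per layer width in `features`."""
    params = []
    for out_dim in features:
        rng, key = random.split(rng)
        # Assumed initialization: normal weights scaled by 1/sqrt(fan_in), zero bias.
        weights = random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim)
        params.append((weights, jnp.zeros(out_dim)))
        in_dim = out_dim
    return params


def apply_mlp(params, x):
    """Apply dense layers with ReLU on all but the final layer."""
    for weights, bias in params[:-1]:
        x = jax.nn.relu(x @ weights + bias)
    weights, bias = params[-1]
    return x @ weights + bias


params = init_mlp(random.PRNGKey(0), in_dim=10, features=(4, 2, 1))
out = apply_mlp(params, jnp.ones((1, 10)))
print(out.shape)
```

With `features=(4, 2, 1)` and a ten-dimensional input, the result has one value per example, matching the shape of the Flax output above.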