Bring in your own objects
When constructing a probabilistic model, you can supply your own model, prior distribution, and output calibrator. The examples below cover each in turn.
Bring in your own model
As an example, we show how to construct an arbitrary Convolutional Neural Network (CNN) model.
[1]:
import flax.linen as nn


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
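To see what the flatten step hands to the first dense layer, the spatial sizes can be traced by hand. The sketch below assumes the Flax defaults ("SAME" padding for nn.Conv, "VALID" padding for nn.avg_pool) and the 64x64 input with 10 channels used in the forward pass of this example. The helpers conv_same and pool_valid are hypothetical names for illustration, not part of Flax or Fortuna.

```python
def conv_same(size: int) -> int:
    # With "SAME" padding and stride 1, a convolution keeps the spatial size.
    return size


def pool_valid(size: int, window: int = 2, stride: int = 2) -> int:
    # With "VALID" padding, pooling yields floor((size - window) / stride) + 1.
    return (size - window) // stride + 1


height = width = 64  # assumed input resolution
channels = 10        # assumed number of input channels

# Two Conv -> ReLU -> avg_pool stages with 32 and 64 output features.
for features in (32, 64):
    height, width = conv_same(height), conv_same(width)
    height, width = pool_valid(height), pool_valid(width)
    channels = features

flat_features = height * width * channels
print(height, width, channels, flat_features)
```

Under these assumptions each pooling stage halves the resolution, so the reshape produces one vector of height * width * channels entries per example before Dense(features=256).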
We now set it as a model in a probabilistic classifier.
[2]:
from fortuna.prob_model import ProbClassifier
prob_model = ProbClassifier(model=CNN())
Done. Let’s check that it works by initializing its parameters and doing a forward pass.
[3]:
from jax import random
import jax.numpy as jnp
x = jnp.zeros((1, 64, 64, 10))
variables = prob_model.model.init(random.PRNGKey(0), x)
prob_model.model.apply(variables, x)
[3]:
Array([[-2.3025851, -2.3025851, -2.3025851, -2.3025851, -2.3025851,
-2.3025851, -2.3025851, -2.3025851, -2.3025851, -2.3025851]], dtype=float32)
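Every entry of this output is the same value, log(1/10). With an all-zero input and biases initialized to zero (the Flax default for nn.Conv and nn.Dense), every logit is zero, so the log-softmax describes a uniform distribution over the 10 classes. A minimal standard-library sketch of that arithmetic follows; the log_softmax helper here is a hypothetical stand-in for nn.log_softmax, not the Flax implementation.

```python
import math


def log_softmax(logits):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]


# All-zero logits, as produced by a zero input and zero biases.
log_probs = log_softmax([0.0] * 10)
print(log_probs[0])                          # approximately -2.302585
print(sum(math.exp(v) for v in log_probs))   # probabilities sum to about 1
```

A non-zero input would break this symmetry, so this check confirms that the model runs end to end rather than that it has learned anything.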
Bring in your own prior distribution
As an example, we show how to construct a multi-dimensional uniform prior distribution.
[4]:
from typing import Optional

from fortuna.prob_model.prior import Prior
from fortuna.typing import Params
from fortuna.utils.random import generate_rng_like_tree
from jax import random
from jax._src.prng import PRNGKeyArray  # private JAX module; its location may change between versions
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
import jax.numpy as jnp


class Uniform(Prior):
    def log_joint_prob(self, params: Params) -> float:
        # Flatten all parameters into one vector and check that each lies in [0, 1].
        flat_params = ravel_pytree(params)[0]
        v = jnp.mean((flat_params <= 1) & (flat_params >= 0))
        # The log-density is 0 inside the unit hypercube and -inf outside.
        return jnp.where(v == 1.0, jnp.array(0), -jnp.inf)

    def sample(self, params_like: Params, rng: Optional[PRNGKeyArray] = None) -> Params:
        if rng is None:
            rng = self.rng.get()
        # One independent key per leaf of the parameter tree.
        keys = generate_rng_like_tree(rng, params_like)
        return tree_map(
            lambda l, k: random.uniform(k, l.shape, l.dtype),
            params_like,
            keys,
        )
In the code below, we test the uniform prior we just created. In order to call sample, we set prior.rng to a RandomNumberGenerator object, which handles and updates random number generators starting from a random seed. The probabilistic model usually does this for you, so you rarely need to set it yourself. Here, however, we are testing a derived class of Prior in isolation, so we must set it by hand.
[5]:
from fortuna.utils.random import RandomNumberGenerator
prior = Uniform()
prior.rng = RandomNumberGenerator(seed=0)
params_in = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([0.5, 1.0]))
params_out = dict(a=jnp.array([1.0]), b=jnp.array([[0.0]]), c=jnp.array([3.0, 1.0]))
print(f"log-prob(params_in): {prior.log_joint_prob(params_in)}")
print(f"log-prob(params_out): {prior.log_joint_prob(params_out)}")
print(f"sample: {prior.sample(params_in)}")
log-prob(params_in): 0.0
log-prob(params_out): -inf
sample: {'a': Array([0.1542927], dtype=float32), 'b': Array([[0.10147738]], dtype=float32), 'c': Array([0.5286629 , 0.78815365], dtype=float32)}
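The two log-probabilities follow from the density of a uniform distribution on the unit hypercube: it equals 1 inside the support, so its logarithm is 0, and it equals 0 outside, so its logarithm is -inf. The sketch below reproduces that check with plain NumPy on a flat dictionary of arrays. It is an illustration of the logic only; the function name uniform_log_joint_prob is hypothetical, and the Fortuna class above relies on ravel_pytree to handle arbitrary parameter trees.

```python
import numpy as np


def uniform_log_joint_prob(params: dict) -> float:
    # Concatenate every array into a single flat vector.
    flat = np.concatenate([np.ravel(v) for v in params.values()])
    # Inside the unit hypercube the density is 1, so the log-density is 0.
    inside = bool(np.all((flat >= 0.0) & (flat <= 1.0)))
    return 0.0 if inside else -np.inf


params_in = {"a": np.array([1.0]), "b": np.array([[0.0]]), "c": np.array([0.5, 1.0])}
params_out = {"a": np.array([1.0]), "b": np.array([[0.0]]), "c": np.array([3.0, 1.0])}

print(uniform_log_joint_prob(params_in))
print(uniform_log_joint_prob(params_out))
```

A single entry outside [0, 1], such as the 3.0 in params_out, is enough to send the joint log-probability to -inf.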
To use your uniform prior in Fortuna, just set it as the prior parameter of your ProbClassifier or ProbRegressor.
Bring in your own output calibrator
As an example, we show how to construct an MLP output calibrator. Note that an output calibrator is just a Flax model, so you could also use the MLP pre-built in Fortuna. Here, however, we implement one from scratch for educational purposes.
[6]:
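import flax.linen as nn
from typing import Tuple


class MLP(nn.Module):
    features: Tuple[int, ...]  # layer widths; the last entry is the output size

    @nn.compact
    def __call__(self, x):
        # Hidden layers use a ReLU activation; the output layer is linear.
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x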
You can now set your MLP as the output calibrator of a probabilistic model, or a calibration model. We do it here for a calibration regressor.
[7]:
from fortuna.output_calib_model import OutputCalibRegressor
calib_model = OutputCalibRegressor(output_calibrator=MLP(features=(4, 2, 1)))
Done. Let’s check that it works by initializing its parameters and doing a forward pass.
[8]:
from jax import random
import jax.numpy as jnp
x = jnp.ones((1, 10))
variables = calib_model.output_calibrator.init(random.PRNGKey(0), x)
calib_model.output_calibrator.apply(variables, x)
[8]:
Array([[-0.00347538]], dtype=float32)
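The output has shape (1, 1) because the last entry of features fixes the output width, while the batch dimension passes through unchanged. The NumPy sketch below traces the same layer structure (10 -> 4 -> 2 -> 1, with ReLU on the hidden layers only). It is an illustration under stated assumptions: mlp_forward is a hypothetical helper, and the weights are drawn from a standard normal rather than Flax's default initializer, so the numeric value will not match the output above.

```python
import numpy as np


def mlp_forward(x, features, rng):
    # Apply one dense layer per entry in `features`.
    for i, feat in enumerate(features):
        kernel = rng.normal(size=(x.shape[-1], feat))  # assumed initializer
        bias = np.zeros(feat)
        x = x @ kernel + bias
        # ReLU on every layer except the last, which stays linear.
        if i < len(features) - 1:
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)
x = np.ones((1, 10))
out = mlp_forward(x, (4, 2, 1), rng)
print(out.shape)
```

Changing the final entry of features changes the output width accordingly, for example (4, 2, 3) would yield an output of shape (1, 3).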