import importlib
import logging
from typing import (
Dict,
Optional,
Type,
)
import flax.linen as nn
from fortuna.data.loader import DataLoader
from fortuna.likelihood.classification import ClassificationLikelihood
from fortuna.model.model_manager.classification import (
ClassificationModelManager,
SNGPClassificationModelManager,
)
from fortuna.model.model_manager.name_to_model_manager import (
ClassificationModelManagers,
)
from fortuna.model_editor.base import ModelEditor
from fortuna.output_calibrator.classification import ClassificationTemperatureScaler
from fortuna.output_calibrator.output_calib_manager.base import OutputCalibManager
from fortuna.prob_model.base import ProbModel
from fortuna.prob_model.calib_config.base import CalibConfig
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.joint.base import Joint
from fortuna.prob_model.posterior.base import PosteriorApproximator
from fortuna.prob_model.posterior.posterior_approximations import (
PosteriorApproximations,
)
from fortuna.prob_model.posterior.swag.swag_approximator import (
SWAGPosteriorApproximator,
)
from fortuna.prob_model.predictive.classification import ClassificationPredictive
from fortuna.prob_model.prior import IsotropicGaussianPrior
from fortuna.prob_model.prior.base import Prior
from fortuna.prob_output_layer.classification import (
ClassificationMaskedProbOutputLayer,
ClassificationProbOutputLayer,
)
from fortuna.typing import Status
from fortuna.utils.data import (
get_input_shape,
get_inputs_from_shape,
)
class ProbClassifier(ProbModel):
def __init__(
self,
model: nn.Module,
prior: Prior = IsotropicGaussianPrior(),
posterior_approximator: PosteriorApproximator = SWAGPosteriorApproximator(),
output_calibrator: Optional[nn.Module] = ClassificationTemperatureScaler(),
model_editor: Optional[ModelEditor] = None,
seed: int = 0,
):
r"""
A probabilistic classifier class.
Parameters
----------
model : nn.Module
A model describing the deterministic relation between inputs and outputs. The outputs must correspond to
the logits of a softmax probability vector. The output dimension must be the same as the number of classes.
Let :math:`x` be input variables and :math:`w` the random model parameters. Then the model is described by
a function :math:`f(w, x)`, where each component of :math:`f` corresponds to one of the classes.
prior : Prior
A prior distribution object. The default is an isotropic standard Gaussian. Let :math:`w` be the random
model parameters. Then the prior is defined by a distribution :math:`p(w)`.
posterior_approximator : PosteriorApproximator
A posterior approximation method. The default method is SWAG.
output_calibrator : Optional[nn.Module]
An output calibrator object. The default is temperature scaling for classification, which rescales the
logits with a scalar temperature parameter. Given outputs :math:`o` of the model manager, the output
calibrator is described by a function :math:`g(\phi, o)`, where :math:`\phi` are deterministic
calibration parameters.
model_editor : Optional[ModelEditor]
A model editor object. It takes the forward pass and transforms the outputs.
seed : int
A random seed.
Attributes
----------
model : nn.Module
See `model` in `Parameters`.
model_manager : ClassificationModelManager
This object orchestrates the model's forward pass.
output_calibrator : nn.Module
See `output_calibrator` in `Parameters`.
prob_output_layer : ClassificationProbOutputLayer
This object characterizes the distribution of the target variable given the calibrated outputs. It is
defined by :math:`p(y|o)=\text{Categorical}(y|p=\text{softmax}(o))`,
where :math:`o` denotes the calibrated outputs and :math:`y` denotes a target variable.
likelihood : ClassificationLikelihood
The likelihood function. This is defined by
:math:`p(y|w, \phi, x) = \text{Categorical}(y|p=\text{softmax}(g(\phi, f(w, x))))`.
prior : Prior
See `prior` in `Parameters`.
joint : Joint
This object describes the joint distribution of the target variables and the random parameters
given the input variables and the calibration parameters, that is :math:`p(y, w|x, \phi)`.
posterior_approximator : PosteriorApproximator
See `posterior_approximator` in `Parameters`.
posterior : Posterior
This is the posterior approximation of the random parameters given the training data and the
calibration parameters, that is :math:`p(w|\mathcal{D}, \phi)`, where :math:`\mathcal{D}` denotes the
training data set and :math:`\phi` the calibration parameters.
predictive : ClassificationPredictive
This denotes the predictive distribution, that is :math:`p(y|\phi, x, \mathcal{D})`. Its statistics are
approximated via a Monte Carlo approach by sampling from the posterior approximation.
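Examples
--------
A minimal usage sketch with the default prior, posterior approximator and output calibrator.
The toy Flax model below is illustrative only, and ``train_data_loader`` is assumed to be an
existing :class:`~fortuna.data.loader.DataLoader` over two-class data::

    import flax.linen as nn
    from fortuna.prob_model import ProbClassifier

    class TwoClassMLP(nn.Module):
        # toy model producing one logit per class
        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(16)(x))
            return nn.Dense(2)(x)

    prob_model = ProbClassifier(model=TwoClassMLP())
    status = prob_model.train(train_data_loader=train_data_loader)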
"""
self.model = model
self.prior = prior
self.output_calibrator = output_calibrator
self.output_calib_manager = OutputCalibManager(
output_calibrator=output_calibrator
)
self.prob_output_layer = self._get_prob_output_layer(model)
model_manager_cls = getattr(
ClassificationModelManagers, str(posterior_approximator)
).value
self.model_manager = self._get_model_manager(
model, model_editor, model_manager_cls, posterior_approximator
)
self.likelihood = ClassificationLikelihood(
self.model_manager, self.prob_output_layer, self.output_calib_manager
)
self.joint = Joint(self.prior, self.likelihood)
self.posterior = getattr(
PosteriorApproximations, str(posterior_approximator)
).value(joint=self.joint, posterior_approximator=posterior_approximator)
self.predictive = ClassificationPredictive(self.posterior)
super().__init__(seed=seed)
def _get_prob_output_layer(self, model: nn.Module) -> ClassificationProbOutputLayer:
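# Hugging Face masked-LM models need a masked probabilistic output layer; if `transformers`
# is not installed, or the model is not a masked-LM model, fall back to the standard layer.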
try:
# import modules if available
transformers_flax_auto_module = importlib.import_module(
"transformers.models.auto.modeling_flax_auto"
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = list(
getattr(
transformers_flax_auto_module,
"FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES",
).values()
)
if model.__class__.__name__ in FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES:
prob_output_layer = ClassificationMaskedProbOutputLayer()
else:
prob_output_layer = ClassificationProbOutputLayer()
except ModuleNotFoundError:
prob_output_layer = ClassificationProbOutputLayer()
return prob_output_layer
def _get_model_manager(
self,
model: nn.Module,
model_editor: ModelEditor,
model_manager_cls: Type,
posterior_approximator: PosteriorApproximator,
) -> ClassificationModelManager:
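# If `transformers` is available and the model is a `FlaxPreTrainedModel`, dispatch to a
# Hugging Face-aware model manager; otherwise fall back to the given `model_manager_cls`.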
try:
# import modules if available
transformers_module = importlib.import_module("transformers")
fortuna_transformers_classification_module = importlib.import_module(
"fortuna.model.model_manager.transformers.classification"
)
# import relevant classes
FlaxPreTrainedModel = getattr(transformers_module, "FlaxPreTrainedModel")
SNGPHuggingFaceClassificationModelManager = getattr(
fortuna_transformers_classification_module,
"SNGPHuggingFaceClassificationModelManager",
)
HuggingFaceClassificationModelManager = getattr(
fortuna_transformers_classification_module,
"HuggingFaceClassificationModelManager",
)
# load model manager
if (
isinstance(model, FlaxPreTrainedModel)
and model_manager_cls == SNGPClassificationModelManager
):
model_manager = SNGPHuggingFaceClassificationModelManager(
model=model,
model_editor=model_editor,
**posterior_approximator.posterior_method_kwargs,
)
elif isinstance(model, FlaxPreTrainedModel):
model_manager = HuggingFaceClassificationModelManager(
model, model_editor=model_editor
)
else:
model_manager = model_manager_cls(
model=model,
model_editor=model_editor,
**posterior_approximator.posterior_method_kwargs,
)
except ModuleNotFoundError as e:
logging.warning(
"No module named 'transformer' is installed. "
"If you are not working with models from the `transformers` library ignore this warning, otherwise "
"install the optional 'transformers' dependency of Fortuna using poetry. You can do so by entering: "
"`poetry install --extras 'transformers'`."
)
model_manager = model_manager_cls(
model=model,
model_editor=model_editor,
**posterior_approximator.posterior_method_kwargs,
)
return model_manager
def _check_output_dim(self, data_loader: DataLoader):
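# Check that the model's output dimension matches the number of distinct classes in the data.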
if data_loader.size == 0:
raise ValueError(
"""`data_loader` is either empty or incorrectly constructed."""
)
output_dim = data_loader.num_unique_labels
for x, y in data_loader:
input_shape = get_input_shape(x)
break
s = self.joint.init(input_shape)
inputs = get_inputs_from_shape(input_shape)
outputs = self.model_manager.apply(
params=s.params, inputs=inputs, mutable=s.mutable
)
model_output_dim = (
outputs[0].shape[-1]
if isinstance(outputs, (list, tuple))
else outputs.shape[-1]
)
if model_output_dim != output_dim:
raise ValueError(
f"""The outputs dimension of `model` must correspond to the number of different classes
in the target variables of `_data_loader`. However, {model_output_dim} and {output_dim} were found,
respectively."""
)
def train(
self,
train_data_loader: DataLoader,
val_data_loader: Optional[DataLoader] = None,
calib_data_loader: Optional[DataLoader] = None,
fit_config: FitConfig = FitConfig(),
calib_config: CalibConfig = CalibConfig(),
**fit_kwargs,
) -> Dict[str, Status]:
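"""
Train the probabilistic classifier. This fits the posterior approximation on the training
data and, if a calibration data loader is provided, calibrates the output calibrator.
Parameters
----------
train_data_loader : DataLoader
A training data loader.
val_data_loader : Optional[DataLoader]
An optional validation data loader.
calib_data_loader : Optional[DataLoader]
An optional calibration data loader.
fit_config : FitConfig
An object to configure the posterior fitting.
calib_config : CalibConfig
An object to configure the calibration.
Returns
-------
Dict[str, Status]
Status objects describing the training and, if performed, the calibration.
"""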
self._check_output_dim(train_data_loader)
return super().train(
train_data_loader,
val_data_loader,
calib_data_loader,
fit_config,
calib_config,
**fit_kwargs,
)
def calibrate(
self,
calib_data_loader: DataLoader,
val_data_loader: Optional[DataLoader] = None,
calib_config: CalibConfig = CalibConfig(),
) -> Status:
"""
Calibrate the probabilistic classifier.
Parameters
----------
calib_data_loader : DataLoader
A calibration data loader.
val_data_loader : Optional[DataLoader]
An optional validation data loader.
calib_config : CalibConfig
An object to configure the calibration.
Returns
-------
Status
A calibration status object. It provides information about the calibration.
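Examples
--------
A minimal sketch, assuming ``prob_model`` is a trained :class:`ProbClassifier` and
``calib_data_loader`` is an existing :class:`~fortuna.data.loader.DataLoader`::

    status = prob_model.calibrate(calib_data_loader=calib_data_loader)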
"""
self._check_output_dim(calib_data_loader)
if val_data_loader is not None:
self._check_output_dim(val_data_loader)
return super()._calibrate(
uncertainty_fn=(
calib_config.monitor.uncertainty_fn
if calib_config.monitor.uncertainty_fn is not None
else self.prob_output_layer.mean
),
calib_data_loader=calib_data_loader,
val_data_loader=val_data_loader,
calib_config=calib_config,
)