Output calibrator
Output calibration adjusts the outputs of a trained model so that its predictive uncertainty is better calibrated. We explicitly support temperature scaling output calibrators for both classification and regression.
Alternatively, you can bring your own output calibrator by subclassing Flax's Module (flax.linen.Module).
- class fortuna.output_calibrator.classification.ClassificationTemperatureScaler(parent=<flax.linen.module._Sentinel object>, name=None)
Classification temperature scaling. It scales the logits by a scalar temperature parameter. Let \(o\) be the output logits and \(\phi\) be a scalar parameter. Then the scaling can be written as \(g(\phi, o) = \exp(-\phi) o\), i.e. the logits are divided by the temperature \(T = \exp(\phi) > 0\).
- name: Optional[str] = None
- parent: Union[Type[Module], Type[Scope], Type[_Sentinel], None] = None
- scope: Optional[Scope] = None
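The classification formula \(g(\phi, o) = \exp(-\phi) o\) can be checked numerically. In this NumPy sketch (the helper names are hypothetical, not Fortuna functions), \(\phi > 0\) softens the softmax probabilities, \(\phi < 0\) sharpens them, and the predicted class never changes, because multiplying the logits by the positive factor \(\exp(-\phi)\) preserves their order.

```python
import numpy as np


def classification_temperature_scale(phi: float, logits: np.ndarray) -> np.ndarray:
    """g(phi, o) = exp(-phi) * o: divide the logits by the temperature exp(phi)."""
    return np.exp(-phi) * logits


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([2.0, 0.5, -1.0])
for phi in (-0.5, 0.0, 0.5):
    probs = softmax(classification_temperature_scale(phi, logits))
    # Larger phi -> higher temperature -> flatter probabilities; the argmax stays put.
    print(f"phi={phi:+.1f}  max prob={probs.max():.3f}  argmax={probs.argmax()}")
```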
- class fortuna.output_calibrator.regression.RegressionTemperatureScaler(parent=<flax.linen.module._Sentinel object>, name=None)
Regression temperature scaling. It multiplies the variance by a scalar temperature parameter. Let \(v\) be the variance outputs and \(\phi\) be a scalar parameter. Then the scaling can be written as \(g(\phi, v) = \exp(\phi) v\).
- name: Optional[str] = None
- parent: Union[Type[Module], Type[Scope], Type[_Sentinel], None] = None
- scope: Optional[Scope] = None
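The regression formula \(g(\phi, v) = \exp(\phi) v\) can be read off directly in NumPy. In this sketch (the helper name is hypothetical), \(\phi = \ln 2\) doubles every predicted variance. The comments note two consequences that follow from the formula: standard deviations scale by \(\exp(\phi / 2)\), and if a model outputs log-variances (an assumption, not something stated above), the same scaling amounts to adding \(\phi\).

```python
import numpy as np


def regression_temperature_scale(phi: float, variance: np.ndarray) -> np.ndarray:
    """g(phi, v) = exp(phi) * v: rescale the predicted variances."""
    return np.exp(phi) * variance


variance = np.array([0.25, 1.0, 4.0])
phi = np.log(2.0)  # exp(phi) = 2, so every variance is doubled
scaled = regression_temperature_scale(phi, variance)

# Standard deviations are scaled by exp(phi / 2).
std_scaled = np.exp(phi / 2) * np.sqrt(variance)

# Assuming a model that outputs log-variances, the same scaling is an additive shift.
log_var_scaled = np.log(variance) + phi

print(scaled, std_scaled, log_var_scaled)
```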
- class fortuna.output_calib_model.state.OutputCalibState(step, apply_fn, params, tx, opt_state, encoded_name=Array([79, 117, 116, 112, 117, 116, 67, 97, 108, 105, 98, 83, 116, 97, 116, 101], dtype=int32), frozen_params=None, dynamic_scale=None, mutable=None)
OutputCalibState(step: int, apply_fn: Callable, params: 'CalibParams', tx: optax._src.base.GradientTransformation, opt_state: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], encoded_name: 'jnp.ndarray' = Array([79, 117, 116, 112, 117, 116, 67, 97, 108, 105, 98, 83, 116, 97, 116, 101], dtype=int32), frozen_params: 'Optional[Params]' = None, dynamic_scale: 'Optional[dynamic_scale.DynamicScale]' = None, mutable: 'Optional[CalibMutable]' = None)
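The default encoded_name in the signature is the ASCII encoding of the class name. A plain-Python check:

```python
encoded_name = [79, 117, 116, 112, 117, 116, 67, 97, 108, 105, 98, 83, 116, 97, 116, 101]

# Map each integer code point back to its character.
decoded = "".join(chr(code) for code in encoded_name)
print(decoded)  # OutputCalibState

# The reverse direction recovers the same integers.
assert [ord(ch) for ch in decoded] == encoded_name
```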
- classmethod init(params, mutable=None, optimizer=None, **kwargs)
Initialize an output calibration state.
- Parameters:
params (CalibParams) – The calibration parameters.
mutable (Optional[CalibMutable]) – The calibration mutable objects.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer associated with the calibration state.
- Returns:
A calibration state.
- Return type:
Any
- classmethod init_from_dict(d, optimizer=None, **kwargs)
Initialize a calibration state from a dictionary.
- Parameters:
d (Union[Dict, FrozenDict]) – A dictionary whose keys are the calibrators and whose values are their initializations.
optimizer (Optional[OptaxOptimizer]) – An Optax optimizer to assign to the calibration state.
- Returns:
A calibration state.
- Return type: