|
| 1 | +from typing import Sequence, Any, Mapping, Callable |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from ...utils.dict_utils import dicts_to_arrays |
| 6 | + |
| 7 | + |
def calibration_error(
    targets: Mapping[str, np.ndarray] | np.ndarray,
    references: Mapping[str, np.ndarray] | np.ndarray,
    resolution: int = 20,
    aggregation: Callable | None = np.median,
    min_quantile: float = 0.005,
    max_quantile: float = 0.995,
    variable_names: Sequence[str] | None = None,
) -> Mapping[str, Any]:
    """Computes an aggregate score for the marginal calibration error over an ensemble of approximate
    posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation
    between an alpha-CI and the relative number of inliers from ``references`` over multiple alphas in
    (0, 1).

    Parameters
    ----------
    targets : np.ndarray of shape (num_datasets, num_draws, num_variables)
        The random draws from the approximate posteriors over ``num_datasets``
    references : np.ndarray of shape (num_datasets, num_variables)
        The corresponding ground-truth values sampled from the prior
    resolution : int, optional, default: 20
        The number of credibility intervals (CIs) to consider
    aggregation : callable or None, optional, default: np.median
        The function used to aggregate the marginal calibration errors.
        If ``None`` provided, the per-alpha calibration errors will be returned.
    min_quantile : float in (0, 1), optional, default: 0.005
        The minimum posterior quantile to consider.
    max_quantile : float in (0, 1), optional, default: 0.995
        The maximum posterior quantile to consider.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to select from the available variables.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The aggregated calibration error per variable
        - "metric_name" : str
            The name of the metric ("Calibration Error").
        - "variable_names" : str
            The (inferred) variable names.
    """

    samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)

    # Credibility levels alpha and the corresponding central-CI quantile bounds:
    # for each alpha, the central interval spans [(1 - alpha) / 2, 1 - (1 - alpha) / 2].
    alphas = np.linspace(start=min_quantile, stop=max_quantile, num=resolution)
    lowers = (1 - alphas) / 2
    uppers = 1 - lowers

    # Posterior quantiles for each alpha, dataset, and variable.
    # Shape: (2, resolution, num_datasets, num_variables)
    quantiles = np.quantile(samples["targets"], [lowers, uppers], axis=1)
    lower_bounds, upper_bounds = quantiles[0], quantiles[1]

    # A reference value is an inlier for a given alpha if it falls inside that alpha-CI
    lower_mask = lower_bounds <= samples["references"][None, ...]
    upper_mask = upper_bounds >= samples["references"][None, ...]
    inlier_id = np.logical_and(lower_mask, upper_mask)

    # Empirical coverage (relative number of inliers) per alpha and variable.
    # Shape: (resolution, num_variables)
    alpha_pred = np.mean(inlier_id, axis=1)

    # Absolute deviation between empirical and nominal coverage
    absolute_errors = np.abs(alpha_pred - alphas[:, None])

    # Bug fix: honor the documented contract — with aggregation=None the original code
    # raised TypeError ('NoneType' object is not callable) instead of returning the
    # per-alpha errors.
    if aggregation is None:
        error = absolute_errors
    else:
        error = aggregation(absolute_errors, axis=0)

    return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names}
0 commit comments