Skip to content

Commit 7cc4969

Browse files
Merging dev with workflow (#277)
* WIP workflow * Make constructor more general * Fix docs * Ensure working metrics * Add working compute_diagnostics * Major changes to utils, diagnostics, and workflows * Get rid of long training output * add as_time_series transform * Final cleanup --------- Co-authored-by: Paul-Christian Bürkner <paul.buerkner@gmail.com>
1 parent 3f82ce1 commit 7cc4969

28 files changed

+1391
-411
lines changed

bayesflow/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
distributions,
88
networks,
99
simulators,
10+
workflows,
1011
utils,
1112
)
1213

14+
from .workflows import BasicWorkflow
1315
from .approximators import ContinuousApproximator
1416
from .adapters import Adapter
1517
from .datasets import OfflineDataset, OnlineDataset, DiskDataset

bayesflow/adapters/adapter.py

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .transforms import (
1111
AsSet,
12+
AsTimeSeries,
1213
Broadcast,
1314
Concatenate,
1415
Constrain,
@@ -112,6 +113,14 @@ def as_set(self, keys: str | Sequence[str]):
112113
self.transforms.append(transform)
113114
return self
114115

116+
def as_time_series(self, keys: str | Sequence[str]):
117+
if isinstance(keys, str):
118+
keys = [keys]
119+
120+
transform = MapTransform({key: AsTimeSeries() for key in keys})
121+
self.transforms.append(transform)
122+
return self
123+
115124
def broadcast(
116125
self, keys: str | Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1
117126
):

bayesflow/adapters/transforms/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .as_set import AsSet
2+
from .as_time_series import AsTimeSeries
23
from .broadcast import Broadcast
34
from .concatenate import Concatenate
45
from .constrain import Constrain

bayesflow/adapters/transforms/as_set.py

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ class AsSet(ElementwiseTransform):
1111
This is useful, for example, in a linear regression context where we can index
1212
the observations in arbitrary order and always get the same regression line.
1313
14+
Currently, all this transform does is to ensure that the variable
15+
arrays are at least 3D. The 2rd dimension is treated as the
16+
set dimension and the 3rd dimension as the data dimension.
17+
In the future, the transform will have more advanced behavior
18+
to better ensure the correct treatment of sets.
19+
1420
Useage:
1521
1622
adapter = (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import numpy as np
2+
3+
from .elementwise_transform import ElementwiseTransform
4+
5+
6+
class AsTimeSeries(ElementwiseTransform):
7+
"""
8+
The `.as_time_series` transform can be used to indicate that
9+
variables shall be treated as time series.
10+
11+
Currently, all this transformation does is to ensure that the variable
12+
arrays are at least 3D. The 2rd dimension is treated as the
13+
time series dimension and the 3rd dimension as the data dimension.
14+
In the future, the transform will have more advanced behavior
15+
to better ensure the correct treatment of time series data.
16+
17+
Useage:
18+
19+
adapter = (
20+
bf.Adapter()
21+
.as_time_series(["x", "y"])
22+
)
23+
"""
24+
25+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
26+
return np.atleast_3d(data)
27+
28+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
29+
if data.shape[2] == 1:
30+
return np.squeeze(data, axis=2)
31+
32+
return data

bayesflow/diagnostics/__init__.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from .plots import calibration_ecdf
2-
from .plots import calibration_histogram
3-
from .plots import loss
4-
from .plots import mc_calibration
5-
from .plots import mc_confusion_matrix
6-
from .plots import mmd_hypothesis_test
7-
from .plots import pairs_posterior
8-
from .plots import pairs_prior
9-
from .plots import pairs_samples
10-
from .plots import recovery
11-
from .plots import z_score_contraction
1+
from .metrics import root_mean_squared_error, calibration_error, posterior_contraction
2+
3+
from .plots import (
4+
calibration_ecdf,
5+
calibration_histogram,
6+
loss,
7+
mc_calibration,
8+
mc_confusion_matrix,
9+
mmd_hypothesis_test,
10+
pairs_posterior,
11+
pairs_samples,
12+
recovery,
13+
z_score_contraction,
14+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .calibration_error import calibration_error
2+
from .posterior_contraction import posterior_contraction
3+
from .root_mean_squared_error import root_mean_squared_error
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Sequence, Any, Mapping, Callable
2+
3+
import numpy as np
4+
5+
from ...utils.dict_utils import dicts_to_arrays
6+
7+
8+
def calibration_error(
9+
targets: Mapping[str, np.ndarray] | np.ndarray,
10+
references: Mapping[str, np.ndarray] | np.ndarray,
11+
resolution: int = 20,
12+
aggregation: Callable = np.median,
13+
min_quantile: float = 0.005,
14+
max_quantile: float = 0.995,
15+
variable_names: Sequence[str] = None,
16+
) -> Mapping[str, Any]:
17+
"""Computes an aggregate score for the marginal calibration error over an ensemble of approximate
18+
posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation
19+
between an alpha-CI and the relative number of inliers from ``prior_samples`` over multiple alphas in
20+
(0, 1).
21+
22+
Parameters
23+
----------
24+
targets : np.ndarray of shape (num_datasets, num_draws, num_variables)
25+
The random draws from the approximate posteriors over ``num_datasets``
26+
references : np.ndarray of shape (num_datasets, num_variables)
27+
The corresponding ground-truth values sampled from the prior
28+
resolution : int, optional, default: 20
29+
The number of credibility intervals (CIs) to consider
30+
aggregation : callable or None, optional, default: np.median
31+
The function used to aggregate the marginal calibration errors.
32+
If ``None`` provided, the per-alpha calibration errors will be returned.
33+
min_quantile : float in (0, 1), optional, default: 0.005
34+
The minimum posterior quantile to consider.
35+
max_quantile : float in (0, 1), optional, default: 0.995
36+
The maximum posterior quantile to consider.
37+
variable_names : Sequence[str], optional (default = None)
38+
Optional variable names to select from the available variables.
39+
40+
Returns
41+
-------
42+
result : dict
43+
Dictionary containing:
44+
- "values" : float or np.ndarray
45+
The aggregated calibration error per variable
46+
- "metric_name" : str
47+
The name of the metric ("Calibration Error").
48+
- "variable_names" : str
49+
The (inferred) variable names.
50+
"""
51+
52+
samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
53+
54+
# Define alpha values and the corresponding quantile bounds
55+
alphas = np.linspace(start=min_quantile, stop=max_quantile, num=resolution)
56+
regions = 1 - alphas
57+
lowers = regions / 2
58+
uppers = 1 - lowers
59+
60+
# Compute quantiles for each alpha, for each dataset and parameter
61+
quantiles = np.quantile(samples["targets"], [lowers, uppers], axis=1)
62+
63+
# Shape: (2, resolution, num_datasets, num_params)
64+
lower_bounds, upper_bounds = quantiles[0], quantiles[1]
65+
66+
# Compute masks for inliers
67+
lower_mask = lower_bounds <= samples["references"][None, ...]
68+
upper_mask = upper_bounds >= samples["references"][None, ...]
69+
70+
# Logical AND to identify inliers for each alpha
71+
inlier_id = np.logical_and(lower_mask, upper_mask)
72+
73+
# Compute the relative number of inliers for each alpha
74+
alpha_pred = np.mean(inlier_id, axis=1)
75+
76+
# Calculate absolute error between predicted inliers and alpha
77+
absolute_errors = np.abs(alpha_pred - alphas[:, None])
78+
79+
# Aggregate errors across alpha
80+
error = aggregation(absolute_errors, axis=0)
81+
82+
return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Sequence, Any, Mapping, Callable
2+
3+
import numpy as np
4+
5+
from ...utils.dict_utils import dicts_to_arrays
6+
7+
8+
def posterior_contraction(
9+
targets: Mapping[str, np.ndarray] | np.ndarray,
10+
references: Mapping[str, np.ndarray] | np.ndarray,
11+
aggregation: Callable = np.median,
12+
variable_names: Sequence[str] = None,
13+
) -> Mapping[str, Any]:
14+
"""Computes the posterior contraction (PC) from prior to posterior for the given samples.
15+
16+
Parameters
17+
----------
18+
targets : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
19+
Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
20+
for each data set from `num_datasets`.
21+
references : np.ndarray of shape (num_datasets, num_variables)
22+
Prior samples, comprising `num_datasets` ground truths.
23+
aggregation : callable, optional (default = np.median)
24+
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
25+
variable_names : Sequence[str], optional (default = None)
26+
Optional variable names to select from the available variables.
27+
28+
Returns
29+
-------
30+
result : dict
31+
Dictionary containing:
32+
- "values" : float or np.ndarray
33+
The aggregated posterior contraction per variable
34+
- "metric_name" : str
35+
The name of the metric ("Posterior Contraction").
36+
- "variable_names" : str
37+
The (inferred) variable names.
38+
39+
Notes
40+
-----
41+
Posterior contraction measures the reduction in uncertainty from the prior to the posterior.
42+
Values close to 1 indicate strong contraction (high reduction in uncertainty), while values close to 0
43+
indicate low contraction.
44+
"""
45+
46+
samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
47+
48+
post_vars = samples["targets"].var(axis=1, ddof=1)
49+
prior_vars = samples["references"].var(axis=0, keepdims=True, ddof=1)
50+
contraction = 1 - (post_vars / prior_vars)
51+
contraction = aggregation(contraction, axis=0)
52+
return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": samples["variable_names"]}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import Sequence, Any, Mapping, Callable
2+
3+
import numpy as np
4+
5+
from ...utils.dict_utils import dicts_to_arrays
6+
7+
8+
def root_mean_squared_error(
9+
targets: Mapping[str, np.ndarray] | np.ndarray,
10+
references: Mapping[str, np.ndarray] | np.ndarray,
11+
normalize: bool = True,
12+
aggregation: Callable = np.median,
13+
variable_names: Sequence[str] = None,
14+
) -> Mapping[str, Any]:
15+
"""Computes the (Normalized) Root Mean Squared Error (RMSE/NRMSE) for the given posterior and prior samples.
16+
17+
Parameters
18+
----------
19+
targets : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
20+
Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
21+
for each data set from `num_datasets`.
22+
references : np.ndarray of shape (num_datasets, num_variables)
23+
Prior samples, comprising `num_datasets` ground truths.
24+
normalize : bool, optional (default = True)
25+
Whether to normalize the RMSE using the range of the prior samples.
26+
aggregation : callable, optional (default = np.median)
27+
Function to aggregate the RMSE across draws. Typically `np.mean` or `np.median`.
28+
variable_names : Sequence[str], optional (default = None)
29+
Optional variable names to select from the available variables.
30+
31+
Notes
32+
-----
33+
Aggregation is performed after computing the RMSE for each posterior draw, instead of first aggregating
34+
the posterior draws and then computing the RMSE between aggregates and ground truths.
35+
36+
Returns
37+
-------
38+
result : dict
39+
Dictionary containing:
40+
- "values" : np.ndarray
41+
The aggregated (N)RMSE for each variable.
42+
- "metric_name" : str
43+
The name of the metric ("RMSE" or "NRMSE").
44+
- "variable_names" : str
45+
The (inferred) variable names.
46+
"""
47+
48+
samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
49+
50+
rmse = np.sqrt(np.mean((samples["targets"] - samples["references"][:, None, :]) ** 2, axis=0))
51+
52+
if normalize:
53+
rmse /= (samples["references"].max(axis=0) - samples["references"].min(axis=0))[None, :]
54+
metric_name = "NRMSE"
55+
else:
56+
metric_name = "RMSE"
57+
58+
rmse = aggregation(rmse, axis=0)
59+
return {"values": rmse, "metric_name": metric_name, "variable_names": samples["variable_names"]}

bayesflow/diagnostics/plots/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from .mc_confusion_matrix import mc_confusion_matrix
66
from .mmd_hypothesis_test import mmd_hypothesis_test
77
from .pairs_posterior import pairs_posterior
8-
from .pairs_prior import pairs_prior
98
from .pairs_samples import pairs_samples
109
from .recovery import recovery
1110
from .z_score_contraction import z_score_contraction

0 commit comments

Comments
 (0)