Skip to content

Commit afc1af1

Browse files
authored
Merge pull request #460 from bayesflow-org/dev
v2.0.3
2 parents e1f178c + 52bdb58 commit afc1af1

File tree

112 files changed

+1302
-339
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+1302
-339
lines changed

README.md

+56
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@ It provides users and researchers with:
1515
BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference
1616
fueled by continuous progress in generative AI and Bayesian inference.
1717

18+
> [!IMPORTANT]
19+
> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time.
20+
> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards.
21+
> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update
22+
> for less costly models.
23+
24+
## Important Note for Existing Users
25+
26+
You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library.
27+
While it shares the same overall goals with the 1.x versions, the API is not compatible.
28+
29+
> [!CAUTION]
30+
> A few features, most notably hierarchical models, have not been ported to BayesFlow 2.0+
31+
> yet. We are working on those features and plan to add them soon. You can find the complete
32+
> list in the [FAQ](#faq) below.
33+
34+
The [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb) guide
35+
highlights how concepts and classes relate between the two versions.
36+
1837
## Conceptual Overview
1938

2039
<div align="center">
@@ -216,11 +235,48 @@ while the old version was based on TensorFlow.
216235

217236
-------------
218237

238+
**Question:**
239+
Should I switch to BayesFlow 2.0+ now? Are there features that are still missing?
240+
241+
**Answer:**
242+
In general, we recommend to switch, as the new version is easier to use and will continue
243+
to receive improvements and new features. However, a few features are still missing, so you
244+
might want to wait until everything you need has been ported to BayesFlow 2.0+.
245+
246+
Depending on your needs, you might not want to upgrade yet if one of the following applies:
247+
248+
- You have an ongoing project that uses BayesFlow 1.x, and you do not want to allocate
249+
time for migrating it to the new API.
250+
- You have already trained models in BayesFlow 1.x, that you do not want to re-train
251+
with the new version. Loading models from version 1.x in version 2.0+ is not supported.
252+
- You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge,
253+
this applies to:
254+
* Two-level/Hierarchical models (planned for version 2.1): `TwoLevelGenerativeModel`, `TwoLevelPrior`.
255+
* Sensitivity analysis (partially discontinued): functionality from the `bayesflow.sensitivity` module. This is still
256+
possible, but we do no longer offer a special module for it. We plan to add a tutorial on this, see [#455](https://github.com/bayesflow-org/bayesflow/issues/455).
257+
* MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options
258+
to enable the use of BayesFlow in an MCMC setting.
259+
* Networks: `EvidentialNetwork`.
260+
* Model misspecification detection: MMD test in the summary space (see #384).
261+
262+
If you encounter any functionality that is missing and not listed here, please let us
263+
know by opening an issue.
264+
265+
-------------
266+
219267
**Question:**
220268
I still need the old BayesFlow for some of my projects. How can I install it?
221269

222270
**Answer:**
223271
You can find and install the old Bayesflow version via the `stable-legacy` branch on GitHub.
272+
The corresponding [documentation](https://bayesflow.org/stable-legacy/index.html) can be
273+
accessed by selecting the "stable-legacy" entry in the version picker of the documentation.
274+
275+
You can also install the latest version of BayesFlow v1.x from PyPI using
276+
277+
```
278+
pip install "bayesflow<2.0"
279+
```
224280

225281
-------------
226282

bayesflow/__init__.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,11 @@ def setup():
5050
"in contexts where you need gradients (e.g. custom training loops)."
5151
)
5252

53+
# dynamically add __version__ attribute
54+
from importlib.metadata import version
5355

54-
# dynamically add version dunder variable
55-
try:
56-
from importlib.metadata import version, PackageNotFoundError
56+
globals()["__version__"] = version("bayesflow")
5757

58-
__version__ = version(__name__)
59-
except PackageNotFoundError:
60-
__version__ = "2.0.0"
61-
finally:
62-
del version
63-
del PackageNotFoundError
6458

6559
# call and clean up namespace
6660
setup()

bayesflow/adapters/adapter.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .transforms.filter_transform import Predicate
3030

3131

32-
@serializable
32+
@serializable("bayesflow.adapters")
3333
class Adapter(MutableSequence[Transform]):
3434
"""
3535
Defines an adapter to apply various transforms to data.
@@ -79,7 +79,9 @@ def get_config(self) -> dict:
7979

8080
return serialize(config)
8181

82-
def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
82+
def forward(
83+
self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
84+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
8385
"""Apply the transforms in the forward direction.
8486
8587
Parameters
@@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
8890
The data to be transformed.
8991
stage : str, one of ["training", "validation", "inference"]
9092
The stage the function is called in.
93+
log_det_jac: bool, optional
94+
Whether to return the log determinant of the Jacobian of the transforms.
9195
**kwargs : dict
9296
Additional keyword arguments passed to each transform.
9397
9498
Returns
9599
-------
96-
dict
97-
The transformed data.
100+
dict | tuple[dict, dict]
101+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
98102
"""
99103
data = data.copy()
104+
if not log_det_jac:
105+
for transform in self.transforms:
106+
data = transform(data, stage=stage, **kwargs)
107+
return data
100108

109+
log_det_jac = {}
101110
for transform in self.transforms:
102-
data = transform(data, stage=stage, **kwargs)
111+
transformed_data = transform(data, stage=stage, **kwargs)
112+
log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
113+
data = transformed_data
103114

104-
return data
115+
return data, log_det_jac
105116

106-
def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
117+
def inverse(
118+
self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
119+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
107120
"""Apply the transforms in the inverse direction.
108121
109122
Parameters
@@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
112125
The data to be transformed.
113126
stage : str, one of ["training", "validation", "inference"]
114127
The stage the function is called in.
128+
log_det_jac: bool, optional
129+
Whether to return the log determinant of the Jacobian of the transforms.
115130
**kwargs : dict
116131
Additional keyword arguments passed to each transform.
117132
118133
Returns
119134
-------
120-
dict
121-
The transformed data.
135+
dict | tuple[dict, dict]
136+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
122137
"""
123138
data = data.copy()
139+
if not log_det_jac:
140+
for transform in reversed(self.transforms):
141+
data = transform(data, stage=stage, inverse=True, **kwargs)
142+
return data
124143

144+
log_det_jac = {}
125145
for transform in reversed(self.transforms):
126146
data = transform(data, stage=stage, inverse=True, **kwargs)
147+
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)
127148

128-
return data
149+
return data, log_det_jac
129150

130151
def __call__(
131152
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
132-
) -> dict[str, np.ndarray]:
153+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
133154
"""Apply the transforms in the given direction.
134155
135156
Parameters
@@ -145,8 +166,8 @@ def __call__(
145166
146167
Returns
147168
-------
148-
dict
149-
The transformed data.
169+
dict | tuple[dict, dict]
170+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
150171
"""
151172
if inverse:
152173
return self.inverse(data, stage=stage, **kwargs)

bayesflow/adapters/transforms/as_set.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class AsSet(ElementwiseTransform):
1010
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
1111

bayesflow/adapters/transforms/as_time_series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class AsTimeSeries(ElementwiseTransform):
1010
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
1111

bayesflow/adapters/transforms/broadcast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .transform import Transform
77

88

9-
@serializable
9+
@serializable("bayesflow.adapters")
1010
class Broadcast(Transform):
1111
"""
1212
Broadcasts arrays or scalars to the shape of a given other array.

bayesflow/adapters/transforms/concatenate.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .transform import Transform
88

99

10-
@serializable
10+
@serializable("bayesflow.adapters")
1111
class Concatenate(Transform):
1212
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.
1313
@@ -115,3 +115,37 @@ def extra_repr(self) -> str:
115115
result += f", axis={self.axis}"
116116

117117
return result
118+
119+
def log_det_jac(
120+
self,
121+
data: dict[str, np.ndarray],
122+
log_det_jac: dict[str, np.ndarray],
123+
*,
124+
strict: bool = False,
125+
inverse: bool = False,
126+
**kwargs,
127+
) -> dict[str, np.ndarray]:
128+
# copy to avoid side effects
129+
log_det_jac = log_det_jac.copy()
130+
131+
if inverse:
132+
if log_det_jac.get(self.into) is not None:
133+
raise ValueError(
134+
"Cannot obtain an inverse Jacobian of concatenation. "
135+
"Transform your variables before you concatenate."
136+
)
137+
138+
return log_det_jac
139+
140+
required_keys = set(self.keys)
141+
available_keys = set(log_det_jac.keys())
142+
common_keys = available_keys & required_keys
143+
144+
if len(common_keys) == 0:
145+
return log_det_jac
146+
147+
parts = [log_det_jac.pop(key) for key in common_keys]
148+
149+
log_det_jac[self.into] = sum(parts)
150+
151+
return log_det_jac

bayesflow/adapters/transforms/constrain.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .elementwise_transform import ElementwiseTransform
1212

1313

14-
@serializable
14+
@serializable("bayesflow.adapters")
1515
class Constrain(ElementwiseTransform):
1616
"""
1717
Constrains neural network predictions of a data variable to specified bounds.
@@ -87,6 +87,11 @@ def constrain(x):
8787

8888
def unconstrain(x):
8989
return inverse_sigmoid((x - lower) / (upper - lower))
90+
91+
def ldj(x):
92+
x = (x - lower) / (upper - lower)
93+
return -np.log(x) - np.log1p(-x) - np.log(upper - lower)
94+
9095
case str() as name:
9196
raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.")
9297
case other:
@@ -101,13 +106,22 @@ def constrain(x):
101106

102107
def unconstrain(x):
103108
return inverse_softplus(x - lower)
109+
110+
def ldj(x):
111+
x = x - lower
112+
return x - np.log(np.exp(x) - 1)
113+
104114
case "exp" | "log":
105115

106116
def constrain(x):
107117
return np.exp(x) + lower
108118

109119
def unconstrain(x):
110120
return np.log(x - lower)
121+
122+
def ldj(x):
123+
return -np.log(x - lower)
124+
111125
case str() as name:
112126
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
113127
case other:
@@ -122,13 +136,21 @@ def constrain(x):
122136

123137
def unconstrain(x):
124138
return -inverse_softplus(-(x - upper))
139+
140+
def ldj(x):
141+
x = -(x - upper)
142+
return x - np.log(np.exp(x) - 1)
143+
125144
case "exp" | "log":
126145

127146
def constrain(x):
128147
return -np.exp(-x) + upper
129148

130149
def unconstrain(x):
131150
return -np.log(-x + upper)
151+
152+
def ldj(x):
153+
return -np.log(-x + upper)
132154
case str() as name:
133155
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
134156
case other:
@@ -142,6 +164,7 @@ def unconstrain(x):
142164

143165
self.constrain = constrain
144166
self.unconstrain = unconstrain
167+
self.ldj = ldj
145168

146169
# do this last to avoid serialization issues
147170
match inclusive:
@@ -178,3 +201,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
178201
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
179202
# inverse means network space -> data space, so constrain the data
180203
return self.constrain(data)
204+
205+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
206+
ldj = self.ldj(data)
207+
if inverse:
208+
ldj = -ldj
209+
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/convert_dtype.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class ConvertDType(ElementwiseTransform):
1010
"""
1111
Default transform used to convert all floats from float64 to float32 to be in line with keras framework.

bayesflow/adapters/transforms/drop.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .transform import Transform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class Drop(Transform):
1010
"""
1111
Transform to drop variables from further calculation.
@@ -46,3 +46,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
4646

4747
def extra_repr(self) -> str:
4848
return "[" + ", ".join(map(repr, self.keys)) + "]"
49+
50+
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
51+
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

0 commit comments

Comments
 (0)