Skip to content

Commit eb5c446

Browse files
authored
Revert clamp changes from aeee627 (destabilized training) (#252)
1 parent aeee627 commit eb5c446

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

bayesflow/networks/coupling_flow/transforms/affine_transform.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
import math
2-
31
import keras.ops as ops
42
from keras.saving import register_keras_serializable as serializable
53

64
from bayesflow.types import Tensor
5+
from bayesflow.utils.keras_utils import shifted_softplus
76
from .transform import Transform
87

98

109
@serializable(package="networks.coupling_flow")
1110
class AffineTransform(Transform):
12-
def __init__(self, clamp: float | None = 1.9, **kwargs):
11+
def __init__(self, clamp: bool = True, **kwargs):
1312
super().__init__(**kwargs)
1413
self.clamp = clamp
1514

@@ -25,12 +24,12 @@ def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]:
2524
def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tensor]:
2625
scale = parameters["scale"]
2726

28-
# soft clamp
29-
if self.clamp is not None:
30-
(2.0 * self.clamp / math.pi) * ops.arctan(scale / self.clamp)
31-
3227
# constrain to positive values
33-
scale = ops.exp(scale)
28+
scale = shifted_softplus(scale)
29+
30+
# soft clamp
31+
if self.clamp:
32+
scale = ops.arcsinh(scale)
3433

3534
parameters["scale"] = scale
3635
return parameters

0 commit comments

Comments
 (0)