
Commit 2f0fdcc

run ruff linter + formatter
1 parent: f604862
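
For reference, a commit like this is typically produced by ruff's two entry points. A minimal sketch of the invocation (the project's actual ruff configuration is not part of this commit, so treat the rule selection as an assumption):

    # hypothetical invocation from the repository root
    ruff check --fix .   # lint; autofixes issues such as f-strings without placeholders
    ruff format .        # reformat; wraps long signatures/calls, normalizes docstring spacing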

94 files changed (+328 -364 lines). Large commits hide some file diffs by default; only a subset is shown below.

.gitignore (+1 -1)

@@ -32,4 +32,4 @@ docs/
 .tox

 # MacOS
-.DS_Store
+.DS_Store

(The removed and added lines read identically; the change is whitespace-only, most likely a fixed newline at end of file.)

CITATION.cff (+1 -1)

@@ -69,4 +69,4 @@ preferred-citation:
   type: article
   url: "https://joss.theoj.org/papers/10.21105/joss.05702"
   volume: 8
-  title: "BayesFlow: Amortized Bayesian Workflows With Neural Networks"
+  title: "BayesFlow: Amortized Bayesian Workflows With Neural Networks"

(Another whitespace-only change; the visible text is identical.)

bayesflow/__init__.py (-1)

@@ -1,4 +1,3 @@
-
 from . import (
     approximators,
     configurators,

bayesflow/approximators/__init__.py (-1)

@@ -1,2 +1 @@
-
 from .approximator import Approximator

bayesflow/approximators/approximator.py (+5 -4)

@@ -1,4 +1,3 @@
-
 import keras
 from keras.saving import register_keras_serializable

@@ -20,7 +19,7 @@
 @register_keras_serializable(package="bayesflow.amortizers")
 class Approximator(BaseApproximator):
     def __init__(self, **kwargs):
-        """ The main workhorse for learning amortized neural approximators for distributions arising
+        """The main workhorse for learning amortized neural approximators for distributions arising
         in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal
         likelihoods).

@@ -64,14 +63,16 @@ def __init__(self, **kwargs):
         if "configurator" not in kwargs:
             # try to set up a default configurator
             if "inference_variables" not in kwargs:
-                raise ValueError(f"You must specify either a configurator or arguments for the default configurator.")
+                raise ValueError("You must specify either a configurator or arguments for the default configurator.")

             inference_variables = kwargs.pop("inference_variables")
             inference_conditions = kwargs.pop("inference_conditions", None)
             summary_variables = kwargs.pop("summary_variables", None)
             summary_conditions = kwargs.pop("summary_conditions", None)

-            kwargs["configurator"] = Configurator(inference_variables, inference_conditions, summary_variables, summary_conditions)
+            kwargs["configurator"] = Configurator(
+                inference_variables, inference_conditions, summary_variables, summary_conditions
+            )

         kwargs.setdefault("summary_network", None)
         super().__init__(**kwargs)
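
The last hunk above is the default-configurator fallback: when no configurator kwarg is given, one is assembled from the variable-name kwargs. A minimal usage sketch under that reading (the network object and variable names are placeholders, not taken from this commit):

    # hypothetical: `inference_net` stands in for a real InferenceNetwork instance
    approx = Approximator(
        inference_network=inference_net,
        inference_variables=["theta"],  # required if no configurator is passed
        inference_conditions=["x"],     # optional; pulled out via kwargs.pop(..., None)
    )

    # equivalent explicit form, matching the wrapped call in the diff
    approx = Approximator(
        inference_network=inference_net,
        configurator=Configurator(["theta"], ["x"], None, None),
    )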

bayesflow/approximators/base_approximator.py (+22 -11)

@@ -1,4 +1,3 @@
-
 import keras
 from keras.saving import (
     deserialize_keras_object,

@@ -15,15 +14,23 @@

 @register_keras_serializable(package="bayesflow.approximators")
 class BaseApproximator(keras.Model):
-    def __init__(self, inference_network: InferenceNetwork, summary_network: SummaryNetwork, configurator: BaseConfigurator, **kwargs):
+    def __init__(
+        self,
+        inference_network: InferenceNetwork,
+        summary_network: SummaryNetwork,
+        configurator: BaseConfigurator,
+        **kwargs,
+    ):
         super().__init__(**keras_kwargs(kwargs))
         self.inference_network = inference_network
         self.summary_network = summary_network
         self.configurator = configurator

     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "BaseApproximator":
-        config["inference_network"] = deserialize_keras_object(config["inference_network"], custom_objects=custom_objects)
+        config["inference_network"] = deserialize_keras_object(
+            config["inference_network"], custom_objects=custom_objects
+        )
         config["summary_network"] = deserialize_keras_object(config["summary_network"], custom_objects=custom_objects)
         config["configurator"] = deserialize_keras_object(config["configurator"], custom_objects=custom_objects)

@@ -63,10 +70,12 @@ def evaluate(self, *args, **kwargs):

         if val_logs is None:
             # https://github.com/keras-team/keras/issues/19835
-            warnings.warn(f"Found no validation logs due to a bug in keras. "
-                          f"Applying workaround, but incorrect loss values may be logged. "
-                          f"If possible, increase the size of your dataset, "
-                          f"or lower the number of validation steps used.")
+            warnings.warn(
+                "Found no validation logs due to a bug in keras. "
+                "Applying workaround, but incorrect loss values may be logged. "
+                "If possible, increase the size of your dataset, "
+                "or lower the number of validation steps used."
+            )

             val_logs = {}

@@ -103,16 +112,18 @@ def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> d
         return metrics | summary_metrics | inference_metrics

     def compute_loss(self, *args, **kwargs):
-        raise RuntimeError(f"Use compute_metrics()['loss'] instead.")
+        raise RuntimeError("Use compute_metrics()['loss'] instead.")

     def fit(self, *args, **kwargs):
         if not self.built:
             try:
                 dataset = kwargs.get("x") or args[0]
                 self.build_from_data(dataset[0])
             except Exception:
-                raise RuntimeError(f"Could not automatically build the approximator. Please pass a dataset as the "
-                                   f"first argument to `approximator.fit()` or manually call `approximator.build()` "
-                                   f"with a dictionary specifying your data shapes.")
+                raise RuntimeError(
+                    "Could not automatically build the approximator. Please pass a dataset as the "
+                    "first argument to `approximator.fit()` or manually call `approximator.build()` "
+                    "with a dictionary specifying your data shapes."
+                )

         return super().fit(*args, **kwargs)
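
The rewrapped RuntimeError above names two recovery paths when automatic building fails. A sketch of both, with shapes purely illustrative (the exact dict layout expected by build_from_data() is an assumption here):

    import numpy as np

    # path 1: pass the dataset as the first argument; fit() builds from dataset[0]
    approximator.fit(dataset, epochs=10)

    # path 2 (hypothetical shapes): build manually from one representative batch
    batch = {"theta": np.zeros((64, 2)), "x": np.zeros((64, 100, 1))}
    approximator.build_from_data(batch)
    approximator.fit(dataset, epochs=10)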

bayesflow/approximators/jax_approximator.py (+16 -6)

@@ -1,4 +1,3 @@
-
 import jax
 import keras

@@ -13,7 +12,14 @@ def train_step(self, *args, **kwargs):
     def test_step(self, *args, **kwargs):
         return self.stateless_test_step(*args, **kwargs)

-    def stateless_compute_metrics(self, trainable_variables: any, non_trainable_variables: any, metrics_variables: any, data: dict[str, Tensor], stage: str = "training") -> (Tensor, tuple):
+    def stateless_compute_metrics(
+        self,
+        trainable_variables: any,
+        non_trainable_variables: any,
+        metrics_variables: any,
+        data: dict[str, Tensor],
+        stage: str = "training",
+    ) -> (Tensor, tuple):
         """
         Things we do for jax:
         1. Accept trainable variables as the first argument

@@ -47,11 +53,13 @@ def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[s

         grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)

-        (loss, aux), grads = grad_fn(trainable_variables, non_trainable_variables, metrics_variables, data, stage="training")
+        (loss, aux), grads = grad_fn(
+            trainable_variables, non_trainable_variables, metrics_variables, data, stage="training"
+        )
         metrics, non_trainable_variables, metrics_variables = aux

-        trainable_variables, optimizer_variables = (
-            self.optimizer.stateless_apply(optimizer_variables, grads, trainable_variables)
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
+            optimizer_variables, grads, trainable_variables
         )

         metrics_variables = self._update_loss(loss, metrics_variables)

@@ -62,7 +70,9 @@ def stateless_test_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[s
     def stateless_test_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
         trainable_variables, non_trainable_variables, metrics_variables = state

-        loss, aux = self.stateless_compute_metrics(trainable_variables, non_trainable_variables, metrics_variables, data, stage="validation")
+        loss, aux = self.stateless_compute_metrics(
+            trainable_variables, non_trainable_variables, metrics_variables, data, stage="validation"
+        )
         metrics, non_trainable_variables, metrics_variables = aux

         metrics_variables = self._update_loss(loss, metrics_variables)
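
The calls rewrapped above follow Keras 3's stateless JAX contract: all variables are passed explicitly so jax.value_and_grad can treat the loss as a pure function, with has_aux=True threading the extra state through. A self-contained sketch of that same pattern, independent of BayesFlow (all names illustrative):

    import jax
    import jax.numpy as jnp

    def compute_loss(params, x, y):
        # return (loss, aux): has_aux=True passes aux through alongside the gradient
        pred = params["w"] * x + params["b"]
        loss = jnp.mean((pred - y) ** 2)
        return loss, {"pred_mean": jnp.mean(pred)}

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

    params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
    x = jnp.linspace(0.0, 1.0, 8)
    y = 2.0 * x + 1.0
    (loss, aux), grads = grad_fn(params, x, y)
    # plain SGD step in place of optimizer.stateless_apply
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)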
bayesflow/approximators/numpy_approximator.py (+1 -4)

@@ -1,11 +1,8 @@
-
-import numpy as np
-
 from bayesflow.types import Tensor

 from .base_approximator import BaseApproximator


 class NumpyApproximator(BaseApproximator):
     def train_step(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
-        raise NotImplementedError(f"Keras currently has no support for numpy training.")
+        raise NotImplementedError("Keras currently has no support for numpy training.")

(The file header was missing from the rendered page; the name is inferred from the NumpyApproximator class and the sibling *_approximator.py modules.)

bayesflow/approximators/tensorflow_approximator.py (-1)

@@ -1,4 +1,3 @@
-
 import tensorflow as tf

 from bayesflow.types import Tensor

bayesflow/approximators/torch_approximator.py (-1)

@@ -1,4 +1,3 @@
-
 import torch

 from .base_approximator import BaseApproximator

bayesflow/configurators/__init__.py (-1)

@@ -1,3 +1,2 @@
-
 from .base_configurator import BaseConfigurator
 from .configurator import Configurator

bayesflow/configurators/base_configurator.py (-1)

@@ -1,4 +1,3 @@
-
 from keras.saving import register_keras_serializable

 from bayesflow.types import Tensor

bayesflow/configurators/configurator.py (+2 -3)

@@ -1,4 +1,3 @@
-
 from keras.saving import register_keras_serializable

 from bayesflow.types import Tensor

@@ -14,7 +13,7 @@ def __init__(
         inference_variables: list[str],
         inference_conditions: list[str] = None,
         summary_variables: list[str] = None,
-        summary_conditions: list[str] = None
+        summary_conditions: list[str] = None,
     ):
         self.inference_variables = inference_variables
         self.inference_conditions = inference_conditions or []

@@ -30,7 +29,7 @@ def get_config(self) -> dict:
             "inference_variables": self.inference_variables,
             "inference_conditions": self.inference_conditions,
             "summary_variables": self.summary_variables,
-            "summary_conditions": self.summary_conditions
+            "summary_conditions": self.summary_conditions,
         }

     def configure_inference_variables(self, data: dict[str, Tensor]) -> Tensor | None:

bayesflow/datasets/__init__.py (-2)

@@ -1,5 +1,3 @@
-
 from .offline_dataset import OfflineDataset
 from .online_dataset import OnlineDataset
 from .rounds_dataset import RoundsDataset
-

bayesflow/datasets/offline_dataset.py (+3 -3)

@@ -1,4 +1,3 @@
-
 import keras
 import math

@@ -7,6 +6,7 @@ class OfflineDataset(keras.utils.PyDataset):
     """
     A dataset that is pre-simulated and stored in memory.
     """
+
     def __init__(self, data: dict, batch_size: int, **kwargs):
         super().__init__(**kwargs)
         self.batch_size = batch_size

@@ -17,7 +17,7 @@ def __init__(self, data: dict, batch_size: int, **kwargs):
         self.shuffle()

     def __getitem__(self, item: int) -> (dict, dict):
-        """ Get a batch of pre-simulated data """
+        """Get a batch of pre-simulated data"""
         item = slice(item * self.batch_size, (item + 1) * self.batch_size)
         item = self.indices[item]

@@ -30,5 +30,5 @@ def on_epoch_end(self) -> None:
         self.shuffle()

     def shuffle(self) -> None:
-        """ Shuffle the dataset in-place. """
+        """Shuffle the dataset in-place."""
         self.indices = keras.random.shuffle(self.indices)
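
The __getitem__ above batches by slicing a shuffled index vector rather than shuffling the data itself. A toy NumPy illustration of the same indexing logic (values are illustrative):

    import numpy as np

    data = {"x": np.arange(10)}
    batch_size, item = 4, 1                  # request the second batch
    indices = np.random.permutation(10)      # stands in for keras.random.shuffle
    batch_idx = indices[item * batch_size : (item + 1) * batch_size]
    batch = {key: value[batch_idx] for key, value in data.items()}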

bayesflow/datasets/online_dataset.py (+2 -1)

@@ -1,4 +1,3 @@
-
 import keras

 from bayesflow.simulators.simulator import Simulator

@@ -8,12 +7,14 @@ class OnlineDataset(keras.utils.PyDataset):
     """
     A dataset that is generated on-the-fly.
     """
+
     def __init__(self, simulator: Simulator, batch_size: int, **kwargs):
         super().__init__(**kwargs)

         if kwargs.get("use_multiprocessing"):
             # keras workaround: https://github.com/keras-team/keras/issues/19346
             import multiprocessing as mp
+
             mp.set_start_method("spawn", force=True)

         self.simulator = simulator
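
The spawn workaround above only runs when multiprocessing is requested through the PyDataset kwargs. A hypothetical construction that would hit that branch (`simulator` is a placeholder, not from this commit):

    dataset = OnlineDataset(
        simulator=simulator,       # placeholder; must implement sample((batch_size,))
        batch_size=64,
        workers=4,                 # standard keras.utils.PyDataset kwargs
        use_multiprocessing=True,  # triggers mp.set_start_method("spawn", force=True)
    )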

bayesflow/datasets/rounds_dataset.py (+4 -3)

@@ -1,4 +1,3 @@
-
 import keras

 from bayesflow.simulators.simulator import Simulator

@@ -8,12 +7,14 @@ class RoundsDataset(keras.utils.PyDataset):
     """
     A dataset that is generated on-the-fly at the beginning of every n-th epoch.
     """
+
     def __init__(self, simulator: Simulator, batch_size: int, batches_per_epoch: int, epochs_per_round: int, **kwargs):
         super().__init__(**kwargs)

         if kwargs.get("use_multiprocessing"):
             # keras workaround: https://github.com/keras-team/keras/issues/19346
             import multiprocessing as mp
+
             mp.set_start_method("spawn", force=True)

         self.simulator = simulator

@@ -27,7 +28,7 @@ def __init__(self, simulator: Simulator, batch_size: int, batches_per_epoch: int
         self.regenerate()

     def __getitem__(self, item: int) -> (dict, dict):
-        """ Get a batch of pre-simulated data """
+        """Get a batch of pre-simulated data"""
         return self.data[item]

     @property

@@ -41,5 +42,5 @@ def on_epoch_end(self) -> None:
         self.regenerate()

     def regenerate(self) -> None:
-        """ Sample new batches of data from the joint distribution unconditionally """
+        """Sample new batches of data from the joint distribution unconditionally"""
         self.data = [self.simulator.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)]

bayesflow/distributions/__init__.py (-1)

@@ -1,3 +1,2 @@
-
 from .distribution import Distribution
 from .diagonal_normal import DiagonalNormal

bayesflow/distributions/diagonal_normal.py (+1 -1)

@@ -1,4 +1,3 @@
-
 import math

 import keras

@@ -16,6 +15,7 @@ class DiagonalNormal(Distribution):
     - ``_log_unnormalized_prob`` method is used as a loss function
     - ``log_prob`` is used for density computation
     """
+
     def __init__(self, mean: float | Tensor = 0.0, std: float | Tensor = 1.0, **kwargs):
         super().__init__(**kwargs)
         self.mean = mean

bayesflow/distributions/distribution.py (-1)

@@ -1,4 +1,3 @@
-
 import keras

 from bayesflow.types import Shape, Tensor

bayesflow/networks/__init__.py (+2 -4)

@@ -1,11 +1,9 @@
-
 from .coupling_flow import CouplingFlow
 from .deep_set import DeepSet
 from .flow_matching import FlowMatching
+from .inference_network import InferenceNetwork
 from .mlp import MLP
 from .resnet import ResNet
 from .lstnet import LSTNet
+from .summary_network import SummaryNetwork
 from .transformers import SetTransformer
-
-from .inference_network import InferenceNetwork
-from .summary_network import SummaryNetwork
bayesflow/networks/coupling_flow/__init__.py (-1)

@@ -1,2 +1 @@
-
 from .coupling_flow import CouplingFlow

(The file header was missing from the rendered page; the name is inferred from the import and the following coupling_flow/actnorm.py section.)

bayesflow/networks/coupling_flow/actnorm.py (+5 -5)

@@ -1,4 +1,3 @@
-
 from keras import ops
 from keras.saving import register_keras_serializable

@@ -14,21 +13,22 @@ class ActNorm(InvertibleLayer):
     Activation Normalization is learned invertible normalization, using
     a Scale (s) and Bias (b) vector::

-        y = s * x + b (forward)
-        x = (y - b) / s (inverse)
+        y = s * x + b (forward)
+        x = (y - b) / s (inverse)

     References
     ----------

-    .. [1] Kingma, D. P., & Dhariwal, P. (2018).
-       Glow: Generative flow with invertible 1x1 convolutions.
+    .. [1] Kingma, D. P., & Dhariwal, P. (2018).
+       Glow: Generative flow with invertible 1x1 convolutions.
        Advances in Neural Information Processing Systems, 31.

     .. [2] Salimans, Tim, and Durk P. Kingma. (2016).
        Weight normalization: A simple reparameterization to accelerate
        training of deep neural networks.
        Advances in Neural Information Processing Systems, 29.
     """
+
     def __init__(self, **kwargs):
         super().__init__(**keras_kwargs(kwargs))
         self.scale = None

(The -/+ pairs in this docstring hunk appear to differ only in whitespace, which the page extraction does not preserve; the rendered "b(forward)" was most likely a collapsed "b (forward)".)
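
The docstring hunks above touch only whitespace; the transform itself is unchanged. For reference, a standalone sketch of the affine forward/inverse pair the docstring describes (Glow-style data-dependent initialization omitted; values illustrative):

    import numpy as np

    s = np.array([2.0, 0.5])                 # learned scale
    b = np.array([1.0, -1.0])                # learned bias
    x = np.array([0.3, -0.7])

    y = s * x + b                            # forward
    x_rec = (y - b) / s                      # inverse
    assert np.allclose(x, x_rec)

    log_det_jac = np.sum(np.log(np.abs(s)))  # log|det J| of the elementwise affine map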
