Commit 1dfb490

Initial commit
1 parent ba20cd0 commit 1dfb490

29 files changed: +4805 -0

README.md

+25

@@ -0,0 +1,25 @@

# Probabilistic Downscaling of Climate Variables

Project with Colloquium (MA8114) at TUM: Probabilistic Downscaling of Climate Variables Using Denoising Diffusion Probabilistic Models

Supervisor: Prof. Dr. Rüdiger Westermann (Chair of Computer Graphics and Visualization)\
Advisor: Kevin Höhlein (Chair of Computer Graphics and Visualization)

---

Downscaling comprises methods for inferring high-resolution information from low-resolution climate variables. We approach this problem as an image super-resolution task and employ a Denoising Diffusion Probabilistic Model (DDPM) to generate finer-scale variables conditioned on coarse-scale information. Experiments are conducted on the WeatherBench dataset, analysing temperature at 2 m above the surface.
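The conditioning mechanism implied by the shipped configuration (`"conditional": true`, `in_channel` equal to the number of variables * 2) is channel-wise concatenation of the interpolated coarse-scale field with the noisy fine-scale sample. A minimal PyTorch sketch with shapes taken from the config (`t2m` only, height 128, batch size 16); the tensor names and the 32x32 input grid are illustrative, not the repository's:

```python
import torch
import torch.nn.functional as F

n_vars, height = 1, 128                        # "variables": ["t2m"], "height": 128
lr = torch.randn(16, n_vars, 32, 32)           # coarse-scale field (hypothetical 32x32 grid)
x_t = torch.randn(16, n_vars, height, height)  # noisy fine-scale sample at diffusion step t

# Interpolate the coarse field to the target grid and stack it with the noisy
# sample along the channel axis, giving the U-Net 2 * n_vars input channels.
lr_up = F.interpolate(lr, size=(height, height), mode="bicubic", align_corners=False)
unet_input = torch.cat([lr_up, x_t], dim=1)    # shape: (16, 2, 128, 128)
```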
![](results/reverse_diffusion_steps.jpg?raw=true)

---

## References

- Liangwei Jiang (2021). Image-Super-Resolution-via-Iterative-Refinement. [[Source code](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement#readme)]
- Song et al. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. [[Source code](https://github.com/yang-song/score_sde_pytorch)]
- Stephan Rasp, Peter D. Dueben, Sebastian Scher, Jonathan A. Weyn, Soukayna Mouatadid, and Nils Thuerey (2020). WeatherBench: A benchmark dataset for data-driven weather forecasting. [[arXiv](https://arxiv.org/abs/2002.00469)]

config.py

+207
@@ -0,0 +1,207 @@

"""Defines configuration parameters for the whole model and dataset."""
import argparse
import json
import os
from collections import OrderedDict
from datetime import datetime


def get_current_datetime() -> str:
    """Converts the current datetime to a string.

    Returns:
        String version of the current datetime of the form %y%m%d_%H%M%S.
    """
    return datetime.now().strftime("%y%m%d_%H%M%S")


def mkdirs(paths) -> None:
    """Creates the directories represented by the paths argument.

    Args:
        paths: Either a list of paths or a single path.
    """
    if isinstance(paths, str):
        os.makedirs(paths, exist_ok=True)
    else:
        for path in paths:
            os.makedirs(path, exist_ok=True)


class Config:
    """Configuration class.

    Attributes:
        args: Command line arguments.
        root: Path to the configuration json file.
        gpu_ids: A list of GPU IDs.
        params: A dictionary containing configuration parameters stored in a json file.
        name: Name of the experiment.
        phase: Either train or val.
        distributed: Whether the computation will be distributed among multiple GPUs or not.
        log: Path to logs.
        tb_logger: Tensorboard logging directory.
        results: Validation results directory.
        checkpoint: Model checkpoints directory.
        resume_state: The path from which to load the network.
        dataset_name: The name of the dataset.
        dataroot: The path to the dataset.
        batch_size: Batch size.
        num_workers: The number of processes for multi-process data loading.
        use_shuffle: Whether to shuffle the training data or not.
        train_min_date: Minimum date starting from which to read the data for training.
        train_max_date: Maximum date until which to read the data for training.
        val_min_date: Minimum date starting from which to read the data for validation.
        val_max_date: Maximum date until which to read the data for validation.
        train_subset_min_date: Minimum date starting from which to read the data for model evaluation on a train subset.
        train_subset_max_date: Maximum date until which to read the data for model evaluation on a train subset.
        variables: A list of WeatherBench variables.
        finetune_norm: Whether to fine-tune or train from scratch.
        in_channel: The number of channels of the U-Net input tensor.
        out_channel: The number of channels of the U-Net output tensor.
        inner_channel: Timestep embedding dimension.
        norm_groups: The number of groups for group normalization.
        channel_multiplier: A tuple specifying the scaling factors of channels.
        attn_res: A tuple of spatial dimensions indicating in which resolutions to use the self-attention layer.
        res_blocks: The number of residual blocks.
        dropout: Dropout probability.
        init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations.
        train_schedule: Defines the type of beta schedule for training.
        train_n_timestep: Number of diffusion timesteps for training.
        train_linear_start: Minimum value of the linear schedule for training.
        train_linear_end: Maximum value of the linear schedule for training.
        val_schedule: Defines the type of beta schedule for validation.
        val_n_timestep: Number of diffusion timesteps for validation.
        val_linear_start: Minimum value of the linear schedule for validation.
        val_linear_end: Maximum value of the linear schedule for validation.
        test_schedule: Defines the type of beta schedule for inference.
        test_n_timestep: Number of diffusion timesteps for inference.
        test_linear_start: Minimum value of the linear schedule for inference.
        test_linear_end: Maximum value of the linear schedule for inference.
        conditional: Whether to condition on the INTERPOLATED image or not.
        diffusion_loss: Either 'l1' or 'l2'.
        n_iter: Number of iterations to train.
        val_freq: Validation frequency.
        save_checkpoint_freq: Model checkpoint frequency.
        print_freq: The frequency of displaying training information.
        n_val_vis: Number of data points to visualize.
        val_vis_freq: Validation data points visualization frequency.
        sample_size: Number of SR images to generate to calculate metrics.
        optimizer_type: The name of the optimization algorithm. Supported values are 'adam' and 'adamw'.
        amsgrad: Whether to use the AMSGrad variant of the optimizer.
        lr: The learning rate.
        experiments_root: The path to the experiment.
        transformation: The name of the data transformation.
        tranform_monthly: Whether to apply the transformation monthly or on the whole dataset.
        height: U-Net input tensor height value.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.root = self.args.config
        self.gpu_ids = self.args.gpu_ids
        self.params = {}
        self.experiments_root = None
        self.__parse_configs()
        self.name = self.params["name"]
        self.phase = self.params["phase"]
        self.gpu_ids = self.params["gpu_ids"]
        self.distributed = self.params["distributed"]
        self.log = self.params["path"]["log"]
        self.tb_logger = self.params["path"]["tb_logger"]
        self.results = self.params["path"]["results"]
        self.checkpoint = self.params["path"]["checkpoint"]
        self.resume_state = self.params["path"]["resume_state"]
        self.dataset_name = self.params["data"]["name"]
        self.dataroot = self.params["data"]["dataroot"]
        self.batch_size = self.params["data"]["batch_size"]
        self.num_workers = self.params["data"]["num_workers"]
        self.use_shuffle = self.params["data"]["use_shuffle"]
        self.train_min_date = self.params["data"]["train_min_date"]
        self.train_max_date = self.params["data"]["train_max_date"]
        self.train_subset_min_date = self.params["data"]["train_subset_min_date"]
        self.train_subset_max_date = self.params["data"]["train_subset_max_date"]
        self.tranform_monthly = self.params["data"]["apply_tranform_monthly"]
        self.transformation = self.params["data"]["transformation"]
        self.val_min_date = self.params["data"]["val_min_date"]
        self.val_max_date = self.params["data"]["val_max_date"]
        self.variables = self.params["data"]["variables"]
        self.height = self.params["data"]["height"]
        self.finetune_norm = self.params["model"]["finetune_norm"]
        self.in_channel = self.params["model"]["unet"]["in_channel"]
        self.out_channel = self.params["model"]["unet"]["out_channel"]
        self.inner_channel = self.params["model"]["unet"]["inner_channel"]
        self.norm_groups = self.params["model"]["unet"]["norm_groups"]
        self.channel_multiplier = self.params["model"]["unet"]["channel_multiplier"]
        self.attn_res = self.params["model"]["unet"]["attn_res"]
        self.res_blocks = self.params["model"]["unet"]["res_blocks"]
        self.dropout = self.params["model"]["unet"]["dropout"]
        self.init_method = self.params["model"]["unet"]["init_method"]
        self.train_schedule = self.params["model"]["beta_schedule"]["train"]["schedule"]
        self.train_n_timestep = self.params["model"]["beta_schedule"]["train"]["n_timestep"]
        self.train_linear_start = self.params["model"]["beta_schedule"]["train"]["linear_start"]
        self.train_linear_end = self.params["model"]["beta_schedule"]["train"]["linear_end"]
        self.val_schedule = self.params["model"]["beta_schedule"]["val"]["schedule"]
        self.val_n_timestep = self.params["model"]["beta_schedule"]["val"]["n_timestep"]
        self.val_linear_start = self.params["model"]["beta_schedule"]["val"]["linear_start"]
        self.val_linear_end = self.params["model"]["beta_schedule"]["val"]["linear_end"]
        self.test_schedule = self.params["model"]["beta_schedule"]["test"]["schedule"]
        self.test_n_timestep = self.params["model"]["beta_schedule"]["test"]["n_timestep"]
        self.test_linear_start = self.params["model"]["beta_schedule"]["test"]["linear_start"]
        self.test_linear_end = self.params["model"]["beta_schedule"]["test"]["linear_end"]
        self.conditional = self.params["model"]["diffusion"]["conditional"]
        self.diffusion_loss = self.params["model"]["diffusion"]["loss"]
        self.n_iter = self.params["training"]["epoch_n_iter"]
        self.val_freq = self.params["training"]["val_freq"]
        self.save_checkpoint_freq = self.params["training"]["save_checkpoint_freq"]
        self.print_freq = self.params["training"]["print_freq"]
        self.n_val_vis = self.params["training"]["n_val_vis"]
        self.val_vis_freq = self.params["training"]["val_vis_freq"]
        self.sample_size = self.params["training"]["sample_size"]
        self.optimizer_type = self.params["training"]["optimizer"]["type"]
        self.amsgrad = self.params["training"]["optimizer"]["amsgrad"]
        self.lr = self.params["training"]["optimizer"]["lr"]

    def __parse_configs(self):
        """Reads the configuration json file and stores it in the params attribute."""
        json_str = ""
        with open(self.root, "r") as f:
            for line in f:
                # Drop '//' comments so the file parses as plain json.
                json_str = f"{json_str}{line.split('//')[0]}\n"

        self.params = json.loads(json_str, object_pairs_hook=OrderedDict)

        if not self.params["path"]["resume_state"]:
            self.experiments_root = os.path.join("experiments", f"{self.params['name']}_{get_current_datetime()}")
        else:
            self.experiments_root = "/".join(self.params["path"]["resume_state"].split("/")[:-2])

        for key, path in self.params["path"].items():
            if not key.startswith("resume"):
                self.params["path"][key] = os.path.join(self.experiments_root, path)
                mkdirs(self.params["path"][key])

        if self.gpu_ids:
            self.params["gpu_ids"] = [int(gpu_id) for gpu_id in self.gpu_ids.split(",")]
            gpu_list = self.gpu_ids
        else:
            gpu_list = ",".join(str(x) for x in self.params["gpu_ids"])

        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
        # A device string longer than one character (e.g. "0,1") implies multiple GPUs.
        self.params["distributed"] = len(gpu_list) > 1

    def __getattr__(self, item):
        """Returns None when an attribute doesn't exist.

        Args:
            item: Attribute to retrieve.

        Returns:
            None
        """
        return None

    def get_hyperparameters_as_dict(self):
        """Returns a dictionary containing the parsed configuration json file."""
        return self.params
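A minimal usage sketch for the module above, assuming a command-line parser that provides `--config` and `--gpu_ids` (the flag names and default are assumptions; the actual parser is defined elsewhere in this commit):

```python
import argparse

from config import Config  # assumes config.py is on the import path

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.json")  # hypothetical default
parser.add_argument("--gpu_ids", type=str, default=None)

cfg = Config(parser.parse_args())
print(cfg.batch_size)     # parsed from the "data" section of the json file
print(cfg.no_such_field)  # None: __getattr__ swallows missing attributes
```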
+81

@@ -0,0 +1,81 @@

{
    "name": "t2m_GlobalSS_Monthly",
    "phase": "train",
    "gpu_ids": [0],
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": null
    },
    "data": {
        "name": "WeatherBench",
        "dataroot": "/mnt/data/papikyan/WeatherBench/numpy",
        "batch_size": 16,
        "num_workers": 4,
        "use_shuffle": true,
        "train_min_date": "1979-01-01-00",
        "train_max_date": "2016-01-01-00",
        "train_subset_min_date": "2014-01-01-00",
        "train_subset_max_date": "2016-01-01-00",
        "transformation": "GlobalStandardScaling",
        "apply_tranform_monthly": true,
        "val_min_date": "2016-01-01-00",
        "val_max_date": "2018-01-01-00",
        "variables": ["t2m"],
        "height": 128
    },
    "model": {
        "finetune_norm": false,
        "unet": {
            "in_channel": 2, // Should equal the number of variables * 2. Used only in networks.py, line 121.
            "out_channel": 1, // Should equal the number of variables.
            "inner_channel": 64,
            "norm_groups": 32,
            "channel_multiplier": [1, 2, 4, 8],
            "attn_res": [16], // Possible values are 128, 64, 32, 16, depending on channel_multiplier.
            "res_blocks": 1,
            "dropout": 0.7,
            "init_method": "kaiming"
        },
        "beta_schedule": {
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 100,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "test": {
                "schedule": "cosine",
                "n_timestep": 1000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "conditional": true,
            "loss": "l2"
        }
    },
    "training": {
        "epoch_n_iter": 20000,
        "val_freq": 2000,
        "save_checkpoint_freq": 2000,
        "print_freq": 100,
        "n_val_vis": 1,
        "val_vis_freq": 500,
        "sample_size": 5,
        "optimizer": {
            "type": "adam", // Possible types are ['adam', 'adamw'].
            "amsgrad": false,
            "lr": 5e-5
        }
    }
}
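The `beta_schedule` blocks above select a cosine noise schedule with 2000 training, 100 validation and 1000 test timesteps (the `linear_start`/`linear_end` values apply only when a linear schedule is chosen). As a reference for what such a schedule typically computes, here is a sketch of the standard cosine construction from Nichol & Dhariwal (2021); this is an assumption about, not a copy of, the repository's diffusion code:

```python
import numpy as np

def cosine_beta_schedule(n_timestep: int, s: float = 0.008) -> np.ndarray:
    """Cosine schedule as in Nichol & Dhariwal (2021); a sketch, not this repo's code."""
    steps = np.arange(n_timestep + 1, dtype=np.float64)
    # The cumulative product of alphas follows a squared-cosine curve in t.
    alphas_cumprod = np.cos(((steps / n_timestep) + s) / (1 + s) * np.pi / 2) ** 2
    alphas_cumprod /= alphas_cumprod[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(betas, 0.0, 0.999)  # clip to keep each step's beta well-behaved

betas = cosine_beta_schedule(2000)  # "n_timestep": 2000 for training
```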
