
Commit fa73d25

Enhance readme for ddp cases in ldm tutorials (#1857)
Add amp argument in maisi diffusion training notebook as a workaround for #1858.

### Checks
- [x] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2dac69c commit fa73d25

File tree

4 files changed: +35 −18 lines

- generation/2d_ldm/README.md
- generation/3d_ldm/README.md
- generation/maisi/maisi_diff_unet_training_tutorial.ipynb
- generation/maisi/scripts/diff_model_train.py

generation/2d_ldm/README.md

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,7 @@ torchrun \
     --master_addr=localhost --master_port=1234 \
     train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
 ```
+Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.
 
 <p align="center">
 <img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
@@ -88,6 +89,8 @@ torchrun \
     --master_addr=localhost --master_port=1234 \
     train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
 ```
+Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.
+
 <p align="center">
 <img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
 &nbsp; &nbsp; &nbsp; &nbsp;
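If a multi-GPU run does exceed available memory, the `batch_size` adjustment mentioned in the added note can also be scripted before launching `torchrun`. The snippet below is only a sketch: it assumes `config_train_32g.json` exposes a top-level `batch_size` key, which may not match the actual config layout.

```python
# Minimal sketch: lower the per-GPU batch size before launching multi-GPU training.
# Assumes config_train_32g.json exposes a top-level "batch_size" key; adjust the
# key path if the config nests it differently.
import json

config_path = "./config/config_train_32g.json"

with open(config_path) as f:
    config = json.load(f)

# Halve the batch size but keep at least 1.
config["batch_size"] = max(1, config.get("batch_size", 1) // 2)

with open(config_path, "w") as f:
    json.dump(config, f, indent=4)
```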

generation/3d_ldm/README.md

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,7 @@ torchrun \
     --master_addr=localhost --master_port=1234 \
     train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
 ```
+Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.
 
 <p align="center">
 <img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
@@ -87,6 +88,7 @@ torchrun \
     --master_addr=localhost --master_port=1234 \
     train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
 ```
+Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.
 <p align="center">
 <img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
 &nbsp; &nbsp; &nbsp; &nbsp;

generation/maisi/maisi_diff_unet_training_tutorial.ipynb

Lines changed: 3 additions & 1 deletion
@@ -429,7 +429,9 @@
     "\n",
     "After all latent features have been created, we will initiate the multi-GPU script to train the latent diffusion model.\n",
     "\n",
-    "The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1."
+    "The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1.\n",
+    "\n",
+    "Please be aware that using the H100 GPU may occasionally result in random segmentation faults. To avoid this issue, you can disable AMP by setting the `--no_amp` flag."
    ]
   },
   {
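As context for the new `--no_amp` note, a rough Python equivalent of launching the training with AMP disabled (based on the updated `diff_model_train` signature in the script diff below) might look like the following sketch. The environment and model config paths are placeholders, and the import assumes the working directory is `generation/maisi`.

```python
# Rough equivalent of running diff_model_train.py with --no_amp, assuming the
# working directory is generation/maisi so that the scripts package is importable.
from scripts.diff_model_train import diff_model_train

diff_model_train(
    env_config_path="./configs/environment_maisi_diff_model.json",  # placeholder path
    model_config_path="./configs/config_maisi_diff_model.json",  # placeholder path
    model_def_path="./configs/config_maisi.json",  # argparse default in the script
    num_gpus=1,
    amp=False,  # same effect as the --no_amp command-line flag
)
```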

generation/maisi/scripts/diff_model_train.py

Lines changed: 27 additions & 17 deletions
@@ -24,7 +24,7 @@
 from torch.nn.parallel import DistributedDataParallel
 
 import monai
-from monai.data import ThreadDataLoader, partition_dataset
+from monai.data import DataLoader, partition_dataset
 from monai.transforms import Compose
 from monai.utils import first
 
@@ -50,7 +50,7 @@ def load_filenames(data_list_path: str) -> list:
 
 def prepare_data(
     train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
-) -> ThreadDataLoader:
+) -> DataLoader:
     """
     Prepare training data.
 
@@ -62,7 +62,7 @@ def prepare_data(
         batch_size (int): Mini-batch size.
 
     Returns:
-        ThreadDataLoader: Data loader for training.
+        DataLoader: Data loader for training.
     """
 
     def _load_data_from_file(file_path, key):
@@ -90,7 +90,7 @@ def _load_data_from_file(file_path, key):
         data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
     )
 
-    return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
+    return DataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
 
 
 def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
@@ -124,14 +124,12 @@ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
     return unet
 
 
-def calculate_scale_factor(
-    train_loader: ThreadDataLoader, device: torch.device, logger: logging.Logger
-) -> torch.Tensor:
+def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor:
     """
     Calculate the scaling factor for the dataset.
 
     Args:
-        train_loader (ThreadDataLoader): Data loader for training.
+        train_loader (DataLoader): Data loader for training.
         device (torch.device): Device to use for calculation.
         logger (logging.Logger): Logger for logging information.
 
@@ -181,7 +179,7 @@ def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> torch.optim.lr_scheduler.PolynomialLR:
 def train_one_epoch(
     epoch: int,
     unet: torch.nn.Module,
-    train_loader: ThreadDataLoader,
+    train_loader: DataLoader,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler.PolynomialLR,
     loss_pt: torch.nn.L1Loss,
@@ -193,14 +191,15 @@
     device: torch.device,
     logger: logging.Logger,
     local_rank: int,
+    amp: bool = True,
 ) -> torch.Tensor:
     """
     Train the model for one epoch.
 
     Args:
         epoch (int): Current epoch number.
         unet (torch.nn.Module): UNet model.
-        train_loader (ThreadDataLoader): Data loader for training.
+        train_loader (DataLoader): Data loader for training.
         optimizer (torch.optim.Optimizer): Optimizer.
         lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler.
         loss_pt (torch.nn.L1Loss): Loss function.
@@ -212,6 +211,7 @@ def train_one_epoch(
         device (torch.device): Device to use for training.
         logger (logging.Logger): Logger for logging information.
         local_rank (int): Local rank for distributed training.
+        amp (bool): Use automatic mixed precision training.
 
     Returns:
         torch.Tensor: Training loss for the epoch.
@@ -237,7 +237,7 @@ def train_one_epoch(
 
         optimizer.zero_grad(set_to_none=True)
 
-        with autocast("cuda", enabled=True):
+        with autocast("cuda", enabled=amp):
             noise = torch.randn(
                 (num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
             )
@@ -256,9 +256,13 @@ def train_one_epoch(
 
             loss = loss_pt(noise_pred.float(), noise.float())
 
-        scaler.scale(loss).backward()
-        scaler.step(optimizer)
-        scaler.update()
+        if amp:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
 
         lr_scheduler.step()
 
@@ -312,14 +316,18 @@ def save_checkpoint(
     )
 
 
-def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
+def diff_model_train(
+    env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
+) -> None:
     """
     Main function to train a diffusion model.
 
     Args:
         env_config_path (str): Path to the environment configuration file.
         model_config_path (str): Path to the model configuration file.
         model_def_path (str): Path to the model definition file.
+        num_gpus (int): Number of GPUs to use for training.
+        amp (bool): Use automatic mixed precision training.
     """
     args = load_config(env_config_path, model_config_path, model_def_path)
     local_rank, world_size, device = initialize_distributed(num_gpus)
@@ -357,7 +365,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
     )[local_rank]
 
     train_loader = prepare_data(
-        train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
+        train_files, device, args.diffusion_unet_train["cache_rate"], batch_size=args.diffusion_unet_train["batch_size"]
    )
 
     unet = load_unet(args, device, logger)
@@ -392,6 +400,7 @@ def diff_model_train(
             device,
             logger,
             local_rank,
+            amp=amp,
         )
 
         loss_torch = loss_torch.tolist()
@@ -431,6 +440,7 @@ def diff_model_train(
         "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
     )
     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
+    parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")
 
     args = parser.parse_args()
-    diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
+    diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)
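Taken together, the script changes implement a conventional AMP on/off switch: the forward pass and loss run under `autocast(..., enabled=amp)`, and the backward/optimizer step goes through `GradScaler` only when AMP is enabled; otherwise a plain FP32 step is used (the workaround for the H100 issue noted in the notebook). A self-contained sketch of that pattern with a toy model (not the MAISI UNet) is shown below; it assumes a recent PyTorch (2.3+) where `GradScaler` lives in `torch.amp`.

```python
# Self-contained sketch of the AMP toggle pattern introduced above. The model,
# data, and loss are toy stand-ins; only the amp switch mirrors the real script.
import torch
from torch.amp import GradScaler, autocast


def train_step(model, batch, target, optimizer, loss_fn, scaler, amp=True):
    optimizer.zero_grad(set_to_none=True)

    # Forward pass and loss under autocast; a no-op when amp is False.
    with autocast("cuda", enabled=amp):
        pred = model(batch)
        loss = loss_fn(pred.float(), target.float())

    if amp:
        # Mixed-precision path: scale the loss, step, then update the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # Plain FP32 path, used when AMP is disabled (e.g. via --no_amp).
        loss.backward()
        optimizer.step()

    return loss.detach()


if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    model = torch.nn.Linear(16, 16).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = GradScaler("cuda", enabled=use_cuda)
    batch = torch.randn(4, 16, device=device)
    target = torch.randn(4, 16, device=device)
    loss = train_step(model, batch, target, optimizer, torch.nn.L1Loss(), scaler, amp=use_cuda)
    print(f"loss: {loss.item():.4f}")
```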

0 commit comments