From 6168e1a82a0a7ae4279e4e27570832c99feb67f3 Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Thu, 9 Jul 2020 14:11:49 +0300
Subject: [PATCH 1/6] Add NVIDIA apex support and checkpointing memory
 optimization (https://pytorch.org/docs/stable/checkpoint.html)

---
 models/cycle_gan_model.py | 40 +++++++++++++++++++++++++++++++++-------
 models/pix2pix_model.py   | 30 +++++++++++++++++++++++++++---
 options/base_options.py   |  7 ++++++-
 3 files changed, 66 insertions(+), 11 deletions(-)

diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 15bb72d8ddc..f536503c8f8 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -3,7 +3,12 @@
 from util.image_pool import ImagePool
 from .base_model import BaseModel
 from . import networks
+from torch.utils.checkpoint import checkpoint
 
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex if you want to use a non-default --opt_level for safe mixed precision")
 
 class CycleGANModel(BaseModel):
     """
@@ -96,6 +101,10 @@ def __init__(self, opt):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+        if opt.apex:
+            [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize(
+                [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3)
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
 
         Parameters:
             input (dict): include the data itself and its metadata information.
@@ -112,11 +121,17 @@ def set_input(self, input):
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
-        self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
+        if not self.opt.checkpointing:
+            self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
+        else:
+            self.rec_A = checkpoint(self.netG_B, self.fake_B)
         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
-        self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
+        if not self.opt.checkpointing:
+            self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
+        else:
+            self.rec_B = checkpoint(self.netG_A, self.fake_A)
 
-    def backward_D_basic(self, netD, real, fake):
+    def backward_D_basic(self, netD, real, fake, loss_id):
         """Calculate GAN loss for the discriminator
 
         Parameters:
@@ -135,18 +150,23 @@ def backward_D_basic(self, netD, real, fake):
         loss_D_fake = self.criterionGAN(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
-        loss_D.backward()
+        if self.opt.apex:
+            with amp.scale_loss(loss_D, self.optimizer_D, loss_id=loss_id) as loss_D_scaled:
+                loss_D_scaled.backward()
+        else:
+            loss_D.backward()
+
         return loss_D
 
     def backward_D_A(self):
         """Calculate GAN loss for discriminator D_A"""
         fake_B = self.fake_B_pool.query(self.fake_B)
-        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, loss_id=0)
 
     def backward_D_B(self):
         """Calculate GAN loss for discriminator D_B"""
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A, loss_id=1)
 
     def backward_G(self):
         """Calculate the loss for generators G_A and G_B"""
@@ -175,7 +195,13 @@ def backward_G(self):
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-        self.loss_G.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=2) as loss_G_scaled:
+                loss_G_scaled.backward()
+        else:
+            self.loss_G.backward()
+
     def optimize_parameters(self):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""

diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 939eb887ee3..55692479918 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -1,7 +1,14 @@
 import torch
+from torch.utils.checkpoint import checkpoint
+
 from .base_model import BaseModel
 from . import networks
 
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex if you want to use a non-default --opt_level for safe mixed precision")
+
 
 class Pix2PixModel(BaseModel):
     """
@@ -70,6 +77,10 @@ def __init__(self, opt):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+        if opt.apex:
+            [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize(
+                [self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2)
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
 
         Parameters:
             input (dict): include the data itself and its metadata information.
@@ -85,7 +96,10 @@ def set_input(self, input):
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        self.fake_B = self.netG(self.real_A)  # G(A)
+        if not self.opt.checkpointing:
+            self.fake_B = self.netG(self.real_A)  # G(A)
+        else:
+            self.fake_B = checkpoint(self.netG, self.real_A)
 
     def backward_D(self):
         """Calculate GAN loss for the discriminator"""
@@ -99,7 +113,12 @@ def backward_D(self):
         self.loss_D_real = self.criterionGAN(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        self.loss_D.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled:
+                loss_D_scaled.backward()
+        else:
+            self.loss_D.backward()
 
     def backward_G(self):
         """Calculate GAN and L1 loss for the generator"""
@@ -111,7 +130,12 @@ def backward_G(self):
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-        self.loss_G.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=1) as loss_G_scaled:
+                loss_G_scaled.backward()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
         self.forward()                   # compute fake images: G(A)

diff --git a/options/base_options.py b/options/base_options.py
index afb5d0852d1..dbed281be55 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -54,6 +54,11 @@ def initialize(self, parser):
         parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
         parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
         parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+        parser.add_argument('--checkpointing', default=False, type=bool,
+                            help='if set, apply gradient checkpointing; saves memory at the cost of slower training')
+        parser.add_argument('--opt_level', default='O0', help='amp opt_level; the default "O0" means pure fp32 training')
+
         self.initialized = True
         return parser
@@ -114,7 +119,7 @@ def parse(self):
         """Parse our options, create checkpoints directory suffix, and set up gpu device."""
         opt = self.gather_options()
         opt.isTrain = self.isTrain   # train or test
-
+        opt.apex = opt.opt_level != "O0"
         # process opt.suffix
         if opt.suffix:
             suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
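The amp pattern that patch 1 introduces is easier to see in isolation. Below is a minimal sketch, assuming apex is installed and a CUDA device is available; the toy modules and names are illustrative, not the repo's networks. amp.initialize accepts lists of models and optimizers, and num_losses reserves an independent dynamic loss scale per backward pass, which is why the patch passes a distinct loss_id for D_A, D_B, and the combined generator loss.

    import torch
    from apex import amp  # assumes NVIDIA apex is installed

    netG = torch.nn.Linear(8, 8).cuda()  # toy stand-ins for the real networks
    netD = torch.nn.Linear(8, 1).cuda()
    opt_G = torch.optim.Adam(netG.parameters(), lr=2e-4)
    opt_D = torch.optim.Adam(netD.parameters(), lr=2e-4)

    # One call can patch several models and optimizers at once; num_losses=2
    # gives the G and D backward passes their own dynamic loss scales.
    [netG, netD], [opt_G, opt_D] = amp.initialize(
        [netG, netD], [opt_G, opt_D], opt_level="O1", num_losses=2)

    loss_G = netD(netG(torch.randn(4, 8, device="cuda"))).mean()
    with amp.scale_loss(loss_G, opt_G, loss_id=0) as scaled_loss:
        scaled_loss.backward()  # gradients are unscaled again before opt_G.step()
    opt_G.step()

Separate loss ids matter because the generator and discriminator losses can overflow at different times; a single shared scale would be forced down whenever any one of them misbehaves.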
From 653b5781080f09c6824b90e661bf4c99485b7331 Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Thu, 9 Jul 2020 15:15:27 +0300
Subject: [PATCH 2/6] Add NVIDIA apex support and checkpointing memory
 optimization (https://pytorch.org/docs/stable/checkpoint.html)

Fix data_parallel order
---
 models/base_model.py      | 8 ++++++++
 models/cycle_gan_model.py | 3 +++
 models/networks.py        | 1 -
 models/pix2pix_model.py   | 2 ++
 models/template_model.py  | 2 ++
 options/base_options.py   | 2 +-
 6 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/models/base_model.py b/models/base_model.py
index 9cfb761897e..ba7ce376fb9 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -227,3 +227,11 @@ def set_requires_grad(self, nets, requires_grad=False):
             if net is not None:
                 for param in net.parameters():
                     param.requires_grad = requires_grad
+
+    def make_data_parallel(self):
+        """Make models data parallel"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net = torch.nn.DataParallel(net, self.gpu_ids)  # multi-GPUs
+                setattr(self, 'net' + name, net)

diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index f536503c8f8..69b21059adb 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -105,6 +105,9 @@ def __init__(self, opt):
             [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize(
                 [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3)
 
+        # networks need to be wrapped in DataParallel after amp.initialize
+        self.make_data_parallel()
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.

diff --git a/models/networks.py b/models/networks.py
index b3a10c99c20..1bd134e5186 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -111,7 +111,6 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
     if len(gpu_ids) > 0:
         assert(torch.cuda.is_available())
         net.to(gpu_ids[0])
-        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
     init_weights(net, init_type, init_gain=init_gain)
     return net

diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 55692479918..f0fdf9c31e1 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -81,6 +81,8 @@ def __init__(self, opt):
             [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize(
                 [self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2)
 
+        self.make_data_parallel()
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.

diff --git a/models/template_model.py b/models/template_model.py
index 68cdaf6a9a2..45d3659ca4c 100644
--- a/models/template_model.py
+++ b/models/template_model.py
@@ -67,6 +67,8 @@ def __init__(self, opt):
             self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizers = [self.optimizer]
 
+        # networks need to be wrapped in DataParallel after amp.initialize
+        self.make_data_parallel()
         # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
 
     def set_input(self, input):

diff --git a/options/base_options.py b/options/base_options.py
index dbed281be55..7f0b91ab4b2 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -55,7 +55,7 @@ def initialize(self, parser):
         parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
         parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
 
-        parser.add_argument('--checkpointing', default=False, type=bool,
+        parser.add_argument('--checkpointing', action='store_true',
                             help='if set, apply gradient checkpointing; saves memory at the cost of slower training')
         parser.add_argument('--opt_level', default='O0', help='amp opt_level; the default "O0" means pure fp32 training')
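The ordering fix in patch 2 follows Apex's documented rule that amp.initialize must see the bare modules before any parallel wrapper is applied (the Apex docs state this for DistributedDataParallel; this patch applies the same order to torch.nn.DataParallel). A minimal sketch of the intended sequence, with a placeholder module and the assumption that apex and at least two GPUs are available:

    import torch
    from apex import amp  # assumes apex is installed

    net = torch.nn.Linear(8, 8).cuda()  # toy stand-in for a generator/discriminator
    opt = torch.optim.Adam(net.parameters())

    net, opt = amp.initialize(net, opt, opt_level="O1")  # 1) patch/cast the bare module first
    net = torch.nn.DataParallel(net, device_ids=[0, 1])  # 2) only then add the parallel wrapper

Wrapping first would hand amp the DataParallel shell rather than the module whose parameters it needs to cast, which is why init_net() no longer wraps networks and make_data_parallel() runs after amp.initialize instead.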
From 48ef29feaca5a181f0bcf0b10b2eb49f58057fd3 Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Fri, 10 Jul 2020 18:09:28 +0300
Subject: [PATCH 3/6] Add NVIDIA apex support and checkpointing memory
 optimization (https://pytorch.org/docs/stable/checkpoint.html)

Disable checkpointing for pix2pix
---
 models/cycle_gan_model.py | 4 ++--
 models/pix2pix_model.py   | 5 +----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 69b21059adb..8db18f2f3ff 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -124,12 +124,12 @@ def set_input(self, input):
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
-        if not self.opt.checkpointing:
+        if not self.opt.checkpointing or not self.isTrain:
             self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
         else:
             self.rec_A = checkpoint(self.netG_B, self.fake_B)
         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
-        if not self.opt.checkpointing:
+        if not self.opt.checkpointing or not self.isTrain:
             self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
         else:
             self.rec_B = checkpoint(self.netG_A, self.fake_A)

diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index f0fdf9c31e1..70b18a4de8c 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -98,10 +98,7 @@ def set_input(self, input):
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        if not self.opt.checkpointing:
-            self.fake_B = self.netG(self.real_A)  # G(A)
-        else:
-            self.fake_B = checkpoint(self.netG, self.real_A)
+        self.fake_B = self.netG(self.real_A)  # G(A)
 
     def backward_D(self):
         """Calculate GAN loss for the discriminator"""
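The new `not self.isTrain` guard makes sense once you recall what torch.utils.checkpoint actually does: it discards the wrapped forward's intermediate activations and reruns that forward during backward, so it only pays off when a backward pass will happen, and it expects at least one input that requires grad. A self-contained sketch (toy module of my own, not the repo's generator) showing the recomputation:

    import torch
    from torch.utils.checkpoint import checkpoint

    calls = {"forward": 0}

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 8)

        def forward(self, x):
            calls["forward"] += 1  # count how often the forward body runs
            return self.fc(x)

    net = Tiny()
    x = torch.randn(4, 8, requires_grad=True)  # checkpoint warns if nothing requires grad
    checkpoint(net, x).sum().backward()
    print(calls["forward"])  # 2 -- the forward ran a second time during backward

At test time there is no backward pass, so checkpointing would only add the no-grad warning and extra bookkeeping; skipping it costs nothing.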
From 60961635b21138dc36c23d8095bb866a74f53cb8 Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Fri, 10 Jul 2020 18:18:18 +0300
Subject: [PATCH 4/6] Add NVIDIA apex support and checkpointing memory
 optimization (https://pytorch.org/docs/stable/checkpoint.html)

Minor fix
---
 models/test_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/models/test_model.py b/models/test_model.py
index fe15f40176e..a9ab50064db 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -48,6 +48,7 @@ def __init__(self, opt):
         # assigns the model to self.netG_[suffix] so that it can be loaded
         # please see <BaseModel.load_networks>
         setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.
+        self.make_data_parallel()
 
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
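TestModel never touches the amp or checkpointing branches, but because init_net() no longer wraps networks, multi-GPU inference also needs this explicit make_data_parallel() call. What that wrap amounts to per network, in a hypothetical miniature (placeholder module and gpu_ids, not the repo's loading code):

    import torch

    net = torch.nn.Linear(8, 8)  # stands in for a loaded netG
    gpu_ids = [0] if torch.cuda.is_available() else []
    if gpu_ids:
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # what make_data_parallel does per net

    with torch.no_grad():  # test-time forward: no amp, no checkpointing
        out = net(torch.randn(4, 8).to(gpu_ids[0] if gpu_ids else "cpu"))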
From d2a66809f8bf38229eec8965952e375fa5bef44d Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Sat, 11 Jul 2020 00:40:28 +0300
Subject: [PATCH 5/6] Add NVIDIA apex support and checkpointing memory
 optimization (https://pytorch.org/docs/stable/checkpoint.html)

Refactor configs
---
 models/cycle_gan_model.py | 4 ++--
 options/base_options.py   | 7 +------
 options/train_options.py  | 9 +++++++++
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 8db18f2f3ff..63b980d915d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -124,12 +124,12 @@ def set_input(self, input):
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
-        if not self.opt.checkpointing or not self.isTrain:
+        if not self.isTrain or not self.opt.checkpointing:
             self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
         else:
             self.rec_A = checkpoint(self.netG_B, self.fake_B)
         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
-        if not self.opt.checkpointing or not self.isTrain:
+        if not self.isTrain or not self.opt.checkpointing:
             self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
         else:
             self.rec_B = checkpoint(self.netG_A, self.fake_A)

diff --git a/options/base_options.py b/options/base_options.py
index 7f0b91ab4b2..afb5d0852d1 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -54,11 +54,6 @@ def initialize(self, parser):
         parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
         parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
         parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
-
-        parser.add_argument('--checkpointing', action='store_true',
-                            help='if set, apply gradient checkpointing; saves memory at the cost of slower training')
-        parser.add_argument('--opt_level', default='O0', help='amp opt_level; the default "O0" means pure fp32 training')
-
         self.initialized = True
         return parser
@@ -119,7 +114,7 @@ def parse(self):
         """Parse our options, create checkpoints directory suffix, and set up gpu device."""
         opt = self.gather_options()
         opt.isTrain = self.isTrain   # train or test
-        opt.apex = opt.opt_level != "O0"
+
         # process opt.suffix
         if opt.suffix:
             suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''

diff --git a/options/train_options.py b/options/train_options.py
index c8d5d2a92a9..575cd3cbbf5 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -36,5 +36,14 @@ def initialize(self, parser):
         parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
         parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
 
+        # training optimizations
+        parser.add_argument('--checkpointing', action='store_true',
+                            help='if set, apply gradient checkpointing; saves memory at the cost of slower training')
+        parser.add_argument('--opt_level', default='O0', help='amp opt_level; the default "O0" means pure fp32 training')
         self.isTrain = True
         return parser
+
+    def parse(self):
+        opt = BaseOptions.parse(self)
+        opt.apex = opt.opt_level != "O0"
+        return opt
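Two details of this refactor are worth spelling out. The flags move into TrainOptions because neither apex nor checkpointing is meaningful at test time, and opt.apex is now derived in the new parse() override rather than in BaseOptions.parse(). Separately, the earlier switch away from `default=False, type=bool` (made in patch 2) dodges a classic argparse trap: bool() on any non-empty string is True. A quick demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--broken', type=bool, default=False)    # the old, buggy style
    parser.add_argument('--checkpointing', action='store_true')  # the fixed style
    parser.add_argument('--opt_level', default='O0')

    opt = parser.parse_args(['--broken', 'False', '--opt_level', 'O1'])
    print(opt.broken)         # True -- bool('False') is truthy!
    print(opt.checkpointing)  # False -- off unless the bare flag is passed
    opt.apex = opt.opt_level != 'O0'
    print(opt.apex)           # True -- any level other than "O0" enables amp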
From bfa902ab49ef855aca42b8e3fe88e56dcb86e461 Mon Sep 17 00:00:00 2001
From: seovchinnikov
Date: Thu, 16 Jul 2020 16:04:21 +0300
Subject: [PATCH 6/6] Fix CPU version

---
 models/base_model.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/models/base_model.py b/models/base_model.py
index ba7ce376fb9..a387ec27173 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -230,6 +230,8 @@ def set_requires_grad(self, nets, requires_grad=False):
 
     def make_data_parallel(self):
         """Make models data parallel"""
+        if len(self.gpu_ids) == 0:
+            return
         for name in self.model_names:
             if isinstance(name, str):
                 net = getattr(self, 'net' + name)
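The final patch guards the CPU-only path: with an empty gpu_ids list there is nothing to parallelize, and constructing torch.nn.DataParallel with an empty device list can fail outright on some setups. The guard in miniature, with a free function standing in for the method:

    import torch

    def make_data_parallel(net, gpu_ids):
        """Sketch of BaseModel.make_data_parallel after patch 6."""
        if len(gpu_ids) == 0:  # CPU run: leave the module unwrapped
            return net
        return torch.nn.DataParallel(net, gpu_ids)

    net = make_data_parallel(torch.nn.Linear(8, 8), gpu_ids=[])
    print(type(net).__name__)  # Linear -- unchanged, so CPU training and testing still work

With all six patches applied, both features stay opt-in: a plain run behaves exactly as before, while something like `python train.py --dataroot ./datasets/horse2zebra --model cycle_gan --checkpointing --opt_level O1` (dataset path illustrative) enables gradient checkpointing and mixed precision together.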