diff --git a/models/base_model.py b/models/base_model.py
index 9cfb761897e..a387ec27173 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -227,3 +227,13 @@ def set_requires_grad(self, nets, requires_grad=False):
             if net is not None:
                 for param in net.parameters():
                     param.requires_grad = requires_grad
+
+    def make_data_parallel(self):
+        """Make models data parallel"""
+        if len(self.gpu_ids) == 0:
+            return
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net = torch.nn.DataParallel(net, self.gpu_ids)  # multi-GPUs
+                setattr(self, 'net' + name, net)
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 15bb72d8ddc..63b980d915d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -3,7 +3,12 @@
 from util.image_pool import ImagePool
 from .base_model import BaseModel
 from . import networks
+from torch.utils.checkpoint import checkpoint
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level")
 
 
 class CycleGANModel(BaseModel):
     """
@@ -96,6 +101,13 @@ def __init__(self, opt):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+            if opt.apex:
+                [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize(
+                    [self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3)
+
+        # need to be wrapped after amp.initialize
+        self.make_data_parallel()
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
 
@@ -112,11 +124,17 @@
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
-        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
+        if not self.isTrain or not self.opt.checkpointing:
+            self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
+        else:
+            self.rec_A = checkpoint(self.netG_B, self.fake_B)
         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
-        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
+        if not self.isTrain or not self.opt.checkpointing:
+            self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
+        else:
+            self.rec_B = checkpoint(self.netG_A, self.fake_A)
 
-    def backward_D_basic(self, netD, real, fake):
+    def backward_D_basic(self, netD, real, fake, loss_id):
         """Calculate GAN loss for the discriminator
 
         Parameters:
@@ -135,18 +153,23 @@
         loss_D_fake = self.criterionGAN(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
-        loss_D.backward()
+        if self.opt.apex:
+            with amp.scale_loss(loss_D, self.optimizer_D, loss_id=loss_id) as loss_D_scaled:
+                loss_D_scaled.backward()
+        else:
+            loss_D.backward()
+
         return loss_D
 
     def backward_D_A(self):
         """Calculate GAN loss for discriminator D_A"""
         fake_B = self.fake_B_pool.query(self.fake_B)
-        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, loss_id=0)
 
     def backward_D_B(self):
         """Calculate GAN loss for discriminator D_B"""
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A, loss_id=1)
 
     def backward_G(self):
         """Calculate the loss for generators G_A and G_B"""
@@ -175,7 +198,13 @@ def backward_G(self):
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-        self.loss_G.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=2) as loss_G_scaled:
+                loss_G_scaled.backward()
+        else:
+            self.loss_G.backward()
+
 
     def optimize_parameters(self):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
diff --git a/models/networks.py b/models/networks.py
index b3a10c99c20..1bd134e5186 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -111,7 +111,6 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
     if len(gpu_ids) > 0:
         assert(torch.cuda.is_available())
         net.to(gpu_ids[0])
-        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
     init_weights(net, init_type, init_gain=init_gain)
     return net
 
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 939eb887ee3..70b18a4de8c 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -1,7 +1,14 @@
 import torch
+from torch.utils.checkpoint import checkpoint
+
 from .base_model import BaseModel
 from . import networks
 
+try:
+    from apex import amp
+except ImportError:
+    print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level")
+
 
 class Pix2PixModel(BaseModel):
     """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
@@ -70,6 +77,12 @@ def __init__(self, opt):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+            if opt.apex:
+                [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize(
+                    [self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2)
+
+        self.make_data_parallel()
+
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
 
@@ -99,7 +112,12 @@ def backward_D(self):
         self.loss_D_real = self.criterionGAN(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        self.loss_D.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled:
+                loss_D_scaled.backward()
+        else:
+            self.loss_D.backward()
 
     def backward_G(self):
         """Calculate GAN and L1 loss for the generator"""
@@ -111,7 +129,12 @@ def backward_G(self):
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-        self.loss_G.backward()
+
+        if self.opt.apex:
+            with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=1) as loss_G_scaled:
+                loss_G_scaled.backward()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
         self.forward()                   # compute fake images: G(A)
diff --git a/models/template_model.py b/models/template_model.py
index 68cdaf6a9a2..45d3659ca4c 100644
--- a/models/template_model.py
+++ b/models/template_model.py
@@ -67,6 +67,8 @@ def __init__(self, opt):
             self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizers = [self.optimizer]
 
+        # need to be wrapped after amp.initialize
+        self.make_data_parallel()
         # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
 
     def set_input(self, input):
diff --git a/models/test_model.py b/models/test_model.py
index fe15f40176e..a9ab50064db 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -48,6 +48,7 @@ def __init__(self, opt):
         # assigns the model to self.netG_[suffix] so that it can be loaded
         # please see <BaseModel.load_networks>
         setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.
+        self.make_data_parallel()
 
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
diff --git a/options/train_options.py b/options/train_options.py
index c8d5d2a92a9..575cd3cbbf5 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -36,5 +36,14 @@ def initialize(self, parser):
         parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
         parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
+        # training optimizations
+        parser.add_argument('--checkpointing', action='store_true',
+                            help='if true, it applies gradient checkpointing, saves memory but it makes the training slower')
+        parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training')
 
         self.isTrain = True
         return parser
+
+    def parse(self):
+        opt = BaseOptions.parse(self)
+        opt.apex = opt.opt_level != "O0"
+        return opt
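With these changes applied, both mixed precision and gradient checkpointing are opt-in from the training command line. A usage sketch follows; the dataset path and experiment name are only illustrative, --dataroot, --name and --model already exist in the repository, while --opt_level and --checkpointing are the flags added above:

    python train.py --dataroot ./datasets/horse2zebra --name horse2zebra_cyclegan --model cycle_gan --opt_level O1 --checkpointing

Any --opt_level other than the default "O0" makes TrainOptions.parse set opt.apex, so NVIDIA Apex must be installed in that case; --checkpointing only changes the training-time forward pass of the CycleGAN generators.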