Skip to content

Commit daf375b

Browse files
authored
Add files via upload
1 parent abd974c commit daf375b

38 files changed

+642
-0
lines changed
1.82 KB
Binary file not shown.

__pycache__/parameter.cpython-36.pyc

1.69 KB
Binary file not shown.
4.38 KB
Binary file not shown.

__pycache__/spectral.cpython-36.pyc

2.44 KB
Binary file not shown.

__pycache__/trainer.cpython-36.pyc

5.24 KB
Binary file not shown.

__pycache__/utils.cpython-36.pyc

978 Bytes
Binary file not shown.

data_loader.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
import torchvision.datasets as dsets
3+
from torchvision import transforms
4+
5+
6+
class Data_Loader():
    """Builds a torch DataLoader for the LSUN-church or CelebA datasets.

    Args:
        train: training-mode flag (stored; not otherwise used here).
        dataset: 'lsun' or 'celeb' — selects which dataset to load.
        image_path: root directory of the image data.
        image_size: square side length images are resized to.
        batch_size: samples per batch.
        shuf: whether the DataLoader shuffles each epoch.
    """

    def __init__(self, train, dataset, image_path, image_size, batch_size, shuf=True):
        self.dataset = dataset
        self.path = image_path
        self.imsize = image_size
        self.batch = batch_size
        self.shuf = shuf
        self.train = train

    def transform(self, resize, totensor, normalize, centercrop):
        """Compose the requested torchvision transforms, in a fixed order."""
        options = []
        if centercrop:
            # 160px center crop — presumably tuned for CelebA's 178x218
            # frames; TODO confirm against the dataset actually used.
            options.append(transforms.CenterCrop(160))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            # Map [0, 1] tensors to [-1, 1], matching a tanh generator output.
            options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return transforms.Compose(options)

    def load_lsun(self, classes='church_outdoor_train'):
        """Return the LSUN dataset restricted to the given class split."""
        # Renamed from `transforms` so the torchvision module is not shadowed.
        transform = self.transform(True, True, True, False)
        return dsets.LSUN(self.path, classes=[classes], transform=transform)

    def load_celeb(self):
        """Return the CelebA images under <path>/CelebA as an ImageFolder."""
        transform = self.transform(True, True, True, True)
        return dsets.ImageFolder(self.path + '/CelebA', transform=transform)

    def loader(self):
        """Build and return the DataLoader for the configured dataset.

        Raises:
            ValueError: if `dataset` is neither 'lsun' nor 'celeb'.
                (The original left `dataset` unbound and died with an
                opaque NameError.)
        """
        if self.dataset == 'lsun':
            dataset = self.load_lsun()
        elif self.dataset == 'celeb':
            dataset = self.load_celeb()
        else:
            raise ValueError(
                "Unknown dataset '{}': expected 'lsun' or 'celeb'".format(self.dataset))

        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=self.batch,
                                           shuffle=self.shuf,
                                           num_workers=2,
                                           drop_last=True)
51+

download.sh

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/bin/bash
# Download and unpack a dataset archive into ./data/.
# Usage: ./download.sh [CelebA|LSUN]
FILE=$1

# Quote "$FILE" so an empty/missing argument does not break the [ ] test,
# and use POSIX '=' instead of the bash-only '=='.
if [ "$FILE" = "CelebA" ]; then
    URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0
    ZIP_FILE=./data/CelebA.zip
elif [ "$FILE" = "LSUN" ]; then
    URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0
    ZIP_FILE=./data/church_outdoor_train_lmdb.zip
else
    echo "Available datasets are: CelebA and LSUN"
    exit 1
fi

mkdir -p ./data/
# -N (timestamping) is ignored/conflicting when -O is given; use -O alone
# so the archive lands at the expected path.
wget "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d ./data/

if [ "$FILE" = "CelebA" ]; then
    # The archive extracts as CelebA_nocrop; normalize to ./data/CelebA,
    # the path data_loader.py expects.
    mv ./data/CelebA_nocrop ./data/CelebA
fi

rm "$ZIP_FILE"

image/attn_gf1.png

28.4 KB
Loading

image/attn_gf2.png

97.7 KB
Loading

image/main_model.PNG

107 KB
Loading

image/sagan_attn.png

385 KB
Loading

image/sagan_celeb.png

2 MB
Loading

image/sagan_lsun.png

1.9 MB
Loading

image/unnamed

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

main.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
from parameter import *
3+
from trainer import Trainer
4+
# from tester import Tester
5+
from data_loader import Data_Loader
6+
#from torch.backends import cudnn
7+
from utils import make_folder
8+
9+
def main(config):
    """Entry point: build the data loader, create output dirs, then train.

    Args:
        config: argparse.Namespace produced by get_parameters().

    Raises:
        NotImplementedError: for the 'qgan' model and for test mode, whose
            implementations (`qgan_trainer`, `Tester`) are not imported in
            this file — the original code died with an opaque NameError on
            those paths.
        ValueError: for an unrecognized config.model, which previously left
            `trainer` unbound.
    """
    # Data loader (shuffle only in training mode).
    data_loader = Data_Loader(config.train, config.dataset, config.image_path,
                              config.imsize, config.batch_size, shuf=config.train)

    # Create output directories if they do not exist.
    make_folder(config.model_save_path, config.version)
    make_folder(config.sample_path, config.version)
    make_folder(config.log_path, config.version)
    make_folder(config.attn_path, config.version)

    if config.train:
        if config.model == 'sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            raise NotImplementedError(
                "qgan training is unavailable: qgan_trainer is not imported")
        else:
            raise ValueError("Unknown model '{}'".format(config.model))
        trainer.train()
    else:
        # The Tester import is commented out at the top of the file.
        raise NotImplementedError(
            "test mode is unavailable: Tester is not imported")


if __name__ == '__main__':
    config = get_parameters()
    print(config)
    main(config)

parameter.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import argparse
2+
3+
def str2bool(v):
    """Parse a command-line string as a boolean (case-insensitive 'true').

    The original `v.lower() in ('true')` was a substring test against the
    string 'true' (the tuple comma was missing), so inputs like 'ru' or ''
    parsed as True.
    """
    return v.lower() == 'true'


def get_parameters():
    """Build the argparse parser for training and return the parsed Namespace."""

    parser = argparse.ArgumentParser()

    # Model hyper-parameters
    parser.add_argument('--model', type=str, default='sagan', choices=['sagan', 'qgan'])
    parser.add_argument('--adv_loss', type=str, default='wgan-gp', choices=['wgan-gp', 'hinge'])
    parser.add_argument('--imsize', type=int, default=32)
    parser.add_argument('--g_num', type=int, default=5)
    parser.add_argument('--z_dim', type=int, default=128)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--lambda_gp', type=float, default=10)
    parser.add_argument('--version', type=str, default='sagan_1')

    # Training setting
    parser.add_argument('--total_step', type=int, default=100,
                        help='how many times to update the generator')
    # d_iters is a discriminator-updates-per-step count: an int, not a float.
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--g_lr', type=float, default=0.0001)
    parser.add_argument('--d_lr', type=float, default=0.0004)
    parser.add_argument('--lr_decay', type=float, default=0.95)
    parser.add_argument('--beta1', type=float, default=0.0)
    parser.add_argument('--beta2', type=float, default=0.9)

    # using pretrained
    parser.add_argument('--pretrained_model', type=int, default=None)

    # Misc
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--parallel', type=str2bool, default=False)
    # Default was 'cifar', which is outside `choices` and unsupported by
    # Data_Loader.loader(); 'celeb' is the supported default.
    parser.add_argument('--dataset', type=str, default='celeb', choices=['lsun', 'celeb'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=False)

    # Path
    parser.add_argument('--image_path', type=str, default='./data')
    parser.add_argument('--log_path', type=str, default='./logs')
    parser.add_argument('--model_save_path', type=str, default='./models')
    parser.add_argument('--sample_path', type=str, default='./samples')
    parser.add_argument('--attn_path', type=str, default='./attn')

    # Step size
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=10)
    parser.add_argument('--model_save_step', type=float, default=1.0)

    return parser.parse_args()

sagan_models.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
from spectral import SpectralNorm
6+
import numpy as np
7+
8+
class Self_Attn(nn.Module):
    """SAGAN-style self-attention layer over spatial feature maps."""

    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation  # stored for reference; not applied in forward

        # 1x1 convolutions project the input into query/key/value spaces;
        # queries and keys are reduced to in_dim // 8 channels.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learned residual weight, initialized to 0 so the layer starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x W x H)
        returns :
            out : gamma * attention-weighted values + x (residual)
            attention : B x N x N attention map, N = W * H
        """
        batch, channels, w, h = x.size()
        n = w * h

        # (B x N x C') queries against (B x C' x N) keys -> B x N x N energies.
        queries = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)
        keys = self.key_conv(x).view(batch, -1, n)
        attention = self.softmax(torch.bmm(queries, keys))

        values = self.value_conv(x).view(batch, -1, n)
        weighted = torch.bmm(values, attention.permute(0, 2, 1))
        out = weighted.view(batch, channels, w, h)

        # Residual connection scaled by the learned gamma.
        out = self.gamma * out + x
        return out, attention
41+
42+
class Generator(nn.Module):
    """SAGAN generator: maps a z vector to a (3, imsize, imsize) image in [-1, 1].

    Fixes over the original:
      * `self.l4` exists only when imsize == 64, but forward() called it
        unconditionally, crashing for any other size (e.g. the default
        imsize=32 in parameter.py).
      * `curr_dim` was only updated past layer3 inside the imsize==64 branch,
        so `last` got the wrong in-channels for other sizes.
      * The self-attention channel counts were hard-coded (128/64), correct
        only for imsize=64 with conv_dim=64; they now track curr_dim.
    For imsize == 64 / conv_dim == 64 the architecture (layer shapes and
    parameter names) is unchanged.
    """

    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8 for imsize 64

        # 1x1 -> 4x4 projection of the latent vector.
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())
        curr_dim = conv_dim * mult

        # Each subsequent block doubles spatial size and halves channels.
        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())
        curr_dim = int(curr_dim / 2)

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())
        curr_dim = int(curr_dim / 2)

        # attn1 sits right after l3, so it sees l3's output channels.
        attn1_dim = curr_dim

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(attn1_dim, 'relu')
        self.attn2 = Self_Attn(curr_dim, 'relu')

    def forward(self, z):
        """Return (image, attn1_map, attn2_map) for a batch of latent vectors z."""
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        # l4 exists only for 64x64 output; skip it otherwise.
        if self.imsize == 64:
            out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out, p1, p2
102+
103+
104+
class Discriminator(nn.Module):
    """SAGAN discriminator: maps a (3, imsize, imsize) image to a realness score.

    Fixes over the original:
      * `self.l4` exists only when imsize == 64, but forward() called it
        unconditionally, crashing for any other size.
      * attn2's channel count was hard-coded to 512 (the imsize==64,
        conv_dim==64 value); it now tracks curr_dim, as does attn1.
    For imsize == 64 / conv_dim == 64 the architecture (layer shapes and
    parameter names) is unchanged.
    """

    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        # Each block halves spatial size and doubles channels.
        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))
        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        # attn1 sits right after l3, so it sees l3's output channels.
        attn1_dim = curr_dim

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        # Final 4x4 conv collapses the remaining spatial extent to one score.
        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(attn1_dim, 'relu')
        self.attn2 = Self_Attn(curr_dim, 'relu')

    def forward(self, x):
        """Return (score, attn1_map, attn2_map); score is squeezed to shape (B,)."""
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        # l4 exists only for 64x64 input; skip it otherwise.
        if self.imsize == 64:
            out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out.squeeze(), p1, p2

samples/sagan_celeb/100_fake.png

568 KB
Loading

samples/sagan_celeb/10_fake.png

626 KB
Loading

samples/sagan_celeb/200_fake.png

586 KB
Loading

samples/sagan_celeb/20_fake.png

639 KB
Loading

samples/sagan_celeb/300_fake.png

567 KB
Loading

samples/sagan_celeb/30_fake.png

647 KB
Loading

samples/sagan_celeb/400_fake.png

586 KB
Loading

samples/sagan_celeb/40_fake.png

648 KB
Loading

samples/sagan_celeb/500_fake.png

611 KB
Loading

samples/sagan_celeb/50_fake.png

641 KB
Loading

samples/sagan_celeb/600_fake.png

575 KB
Loading

samples/sagan_celeb/60_fake.png

641 KB
Loading

samples/sagan_celeb/700_fake.png

581 KB
Loading

samples/sagan_celeb/70_fake.png

618 KB
Loading

samples/sagan_celeb/800_fake.png

567 KB
Loading

samples/sagan_celeb/80_fake.png

616 KB
Loading

samples/sagan_celeb/90_fake.png

596 KB
Loading

0 commit comments

Comments
 (0)