Skip to content

Commit 391fed5

Browse files
authored
Add files via upload
1 parent 2d0e99f commit 391fed5

13 files changed

+2624
-0
lines changed

celeba_filenames_test.pickle

80.7 KB
Binary file not shown.

celeba_filenames_train.pickle

347 KB
Binary file not shown.

dataloader.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
2+
3+
def get_dataloader(args, dataset, is_train=True):
    """Build a ``DataLoader`` over ``dataset`` batched by ``args.batch_size``.

    Args:
        args: Object exposing a ``batch_size`` attribute.
        dataset: Dataset handed to both the sampler and the loader.
        is_train: When true the samples are drawn with ``RandomSampler``;
            otherwise they are visited in order with ``SequentialSampler``.

    Returns:
        The configured ``DataLoader``.
    """
    sampler_cls = RandomSampler if is_train else SequentialSampler
    return DataLoader(
        dataset=dataset,
        sampler=sampler_cls(dataset),
        batch_size=args.batch_size,
    )

fix_seed.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
import numpy as np
3+
import random
4+
5+
def seed_fix(int):
    """Seed PyTorch, NumPy and ``random`` with the same value.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        int: Seed value. The name shadows the builtin; it is kept unchanged
            so existing keyword callers keep working.
    """
    seed = int  # local alias so the body does not read like the builtin

    # PyTorch: CPU generator, current CUDA device, then all CUDA devices
    # (multi-GPU).
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # cuDNN
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NumPy
    np.random.seed(seed)

    # Python's built-in generator
    random.seed(seed)

generate_image.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import warnings
2+
warnings.filterwarnings(action="ignore")
3+
4+
import argparse
5+
import click
6+
import os
7+
import torch
8+
import torchvision
9+
import clip
10+
from fix_seed import seed_fix
11+
from pathlib import Path
12+
from network import Generator, Discriminator
13+
from train_utils import *
14+
15+
16+
@torch.no_grad()
def main():
    """Generate images for a text prompt from a saved checkpoint.

    Parses the command-line options, loads ``hyperparameter.pt`` from
    ``--checkpoint_path``, rebuilds the generator and the per-stage
    discriminators, restores them with ``load_checkpoint`` for
    ``--load_epoch``, encodes ``--prompt`` with CLIP and saves the last three
    generator outputs as ``result_64.png``, ``result_128.png`` and
    ``result_256.png`` in the current working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, required=True)
    parser.add_argument('--load_epoch', type=int, required=True)
    parser.add_argument('--checkpoint_path', type=Path, required=True)
    parser.add_argument('--show_hyp', action='store_true')
    # Validate with argparse's own `choices` rather than a click type: an
    # invalid value then yields a normal usage error instead of an uncaught
    # click exception, and the allowed values appear in --help.
    parser.add_argument('--clip_model', choices=['B/32', 'L/14', 'B/16'], default='B/32')
    args = parser.parse_args()

    # seed_fix(40)  # disabled in the original; enable for repeatable noise
    hyp = torch.load(os.path.join(args.checkpoint_path, "hyperparameter.pt"), map_location='cpu')
    if args.show_hyp:
        print_hyp(hyp)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    clip_model, _ = clip.load(f"ViT-{args.clip_model}", device=device)
    G = Generator(
        hyp['clip_embedding_dim'],
        hyp['projection_dim'],
        hyp['noise_dim'],
        hyp['g_in_chans'],
        hyp['g_out_chans'],
        hyp['num_stage'],
        device,
    ).to(device)
    # The discriminators are not used after loading; they are built only
    # because load_checkpoint is called with both G and D_lst.
    D_lst = [
        Discriminator(
            hyp['projection_dim'],
            hyp['g_out_chans'],
            hyp['d_in_chans'],
            hyp['d_out_chans'],
            hyp['clip_embedding_dim'],
            curr_stage,
            device,
        ).to(device)
        for curr_stage in range(hyp['num_stage'])
    ]
    load_checkpoint(G, D_lst, args.checkpoint_path, args.load_epoch)
    # NOTE(review): G is never switched to eval() here — confirm that is
    # intended for the normalization layers the Generator may contain.

    prompt = clip.tokenize([args.prompt]).to(device)
    txt_feature = clip_model.encode_text(prompt)
    # One noise vector per encoded prompt (a single prompt here).
    z = torch.randn(txt_feature.shape[0], hyp['noise_dim']).to(device)
    txt_feature = normalize(txt_feature.to(device)).type(torch.float32)

    fake_images, _, _ = G(txt_feature, z)
    # NOTE(review): indexing the last three outputs presumably requires a
    # checkpoint trained with num_stage >= 3, and the 64/128/256 in the file
    # names presumably are the stage resolutions — confirm against Generator.
    fake_image_64 = denormalize_image(fake_images[-3].detach().cpu())
    fake_image_128 = denormalize_image(fake_images[-2].detach().cpu())
    fake_image_256 = denormalize_image(fake_images[-1].detach().cpu())
    torchvision.utils.save_image(fake_image_64, "result_64.png")
    torchvision.utils.save_image(fake_image_128, "result_128.png")
    torchvision.utils.save_image(fake_image_256, "result_256.png")


if __name__ == "__main__":
    main()

main.ipynb

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"id": "-ksoUw_9uFDY"
8+
},
9+
"outputs": [],
10+
"source": [
11+
"#@title Setup (Do NOT modify)\n",
12+
"from google.colab import drive\n",
13+
"drive.mount('/content/drive')\n",
14+
"%cd /content/drive/MyDrive/final\n",
15+
"!pip install openai-clip"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": null,
21+
"metadata": {
22+
"id": "ZSaektzjZbKG"
23+
},
24+
"outputs": [],
25+
"source": [
26+
"#@title Data preprocessing (Train)\n",
27+
"!python preproc_datasets_celeba_zip_train.py --source=./multimodal_celeba_hq.zip --dest train_data_6cap.zip --emb_dim 512 --transform=center-crop --width=256 --height=256"
28+
]
29+
},
30+
{
31+
"cell_type": "code",
32+
"execution_count": null,
33+
"metadata": {
34+
"id": "Svp2gXHa4UDb"
35+
},
36+
"outputs": [],
37+
"source": [
38+
"#@title Data preprocessing (Test)\n",
39+
"!python preproc_datasets_celeba_zip_test.py --source=./multimodal_celeba_hq.zip --dest test_data_6cap.zip --emb_dim 512 --transform=center-crop --width=256 --height=256"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {
46+
"cellView": "form",
47+
"id": "pQY41tvBuJVO"
48+
},
49+
"outputs": [],
50+
"source": [
51+
"#@title Train\n",
52+
"\n",
53+
"train_data = \"sample_train.zip\"#@param {\"type\": \"string\"}\n",
54+
"batch_size = 24 #@param {\"type\": \"integer\"}\n",
55+
"num_epochs = 10 #@param {\"type\": \"integer\"}\n",
56+
"learning_rate = 0.0002 #@param {\"type\": \"number\"}\n",
57+
"report_interval = 50 #@param {\"type\": \"integer\"}\n",
58+
"noise_dim = 100 #@param {\"type\": \"integer\"}\n",
59+
"projection_dim = 128 #@param {\"type\": \"integer\"}\n",
60+
"clip_embedding_dim = 512 #@param {\"type\": \"integer\"}\n",
61+
"checkpoint_path = \"model_exp1\" #@param {\"type\": \"string\"}\n",
62+
"result_path = \"images_exp1\" #@param {\"type\": \"string\"}\n",
63+
"use_uncond_loss = True #@param {\"type\": \"boolean\"}\n",
64+
"use_contrastive_loss = True #@param {\"type\": \"boolean\"}\n",
65+
"num_stage = 3 #@param {\"type\": \"integer\"}\n",
66+
"resume_checkpoint_path = \"None\" #@param {\"type\": \"string\"}\n",
67+
"resume_epoch = -1 #@param {\"type\": \"integer\"}\n",
68+
"\n",
69+
"test_cmd = f'''python main.py \\\n",
70+
" --train_data \"{train_data}\" \\\n",
71+
" --batch_size {batch_size} \\\n",
72+
" --num_epochs {num_epochs} \\\n",
73+
" --learning_rate {learning_rate} \\\n",
74+
" --report_interval {report_interval} \\\n",
75+
" --noise_dim {noise_dim} \\\n",
76+
" --projection_dim {projection_dim} \\\n",
77+
" --clip_embedding_dim {clip_embedding_dim} \\\n",
78+
" --checkpoint_path \"{checkpoint_path}\" \\\n",
79+
" --result_path \"{result_path}\" \\\n",
80+
" --num_stage {num_stage} \\\n",
81+
" --resume_epoch {resume_epoch} \\\n",
82+
" '''\n",
83+
"if use_uncond_loss:\n",
84+
" test_cmd += \"--use_uncond_loss \"\n",
85+
"if use_contrastive_loss:\n",
86+
" test_cmd += \"--use_contrastive_loss \"\n",
87+
"if resume_checkpoint_path != \"None\":\n",
88+
" test_cmd += f'''--resume_checkpoint_path \"{resume_checkpoint_path}\"'''\n",
89+
"\n",
90+
"with open('./train_script.sh', 'w') as file:\n",
91+
" file.write(test_cmd)\n",
92+
"\n",
93+
"!bash train_script.sh"
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": null,
99+
"metadata": {
100+
"cellView": "form",
101+
"id": "S7ueWJp1t-Zi"
102+
},
103+
"outputs": [],
104+
"source": [
105+
"#@title Test (Generate image)\n",
106+
"\n",
107+
"prompt = \"The woman is young and has blond hair, and arched eyebrows.\"#@param {\"type\": \"string\"}\n",
108+
"load_epoch = 10 #@param {\"type\": \"integer\"}\n",
109+
"checkpoint_path = \"model_exp1\" #@param {\"type\": \"string\"}\n",
110+
"\n",
111+
"test_cmd = f'''python generate_image.py \\\n",
112+
" --prompt \"{prompt}\" \\\n",
113+
" --load_epoch {load_epoch} \\\n",
114+
" --checkpoint_path \"{checkpoint_path}\"\n",
115+
" '''\n",
116+
"\n",
117+
"with open('./test_script.sh', 'w') as file:\n",
118+
" file.write(test_cmd)\n",
119+
"\n",
120+
"!bash test_script.sh\n",
121+
"\n",
122+
"\n",
123+
"from IPython.display import Image\n",
124+
"import os\n",
125+
"img_64 = Image(os.path.join(\"result_64.png\"))\n",
126+
"display(img_64)\n",
127+
"img_128 = Image(os.path.join(\"result_128.png\"))\n",
128+
"display(img_128)\n",
129+
"img_256 = Image(os.path.join(\"result_256.png\"))\n",
130+
"display(img_256)"
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": null,
136+
"metadata": {
137+
"id": "tH5gJKdVyUuu"
138+
},
139+
"outputs": [],
140+
"source": []
141+
}
142+
],
143+
"metadata": {
144+
"accelerator": "GPU",
145+
"colab": {
146+
"gpuType": "T4",
147+
"machine_shape": "hm",
148+
"provenance": []
149+
},
150+
"kernelspec": {
151+
"display_name": "Python 3",
152+
"name": "python3"
153+
},
154+
"language_info": {
155+
"name": "python"
156+
}
157+
},
158+
"nbformat": 4,
159+
"nbformat_minor": 0
160+
}

main.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import warnings
2+
warnings.filterwarnings(action="ignore")
3+
4+
import argparse
5+
import click
6+
import os
7+
from fix_seed import seed_fix
8+
from train import train
9+
from pathlib import Path
10+
from typing import Union
11+
12+
def main():
    """Parse the training options, create the output directories and train.

    The checkpoint and result directories are created if missing, the random
    generators are seeded with 0 through ``seed_fix`` and the parsed
    namespace is handed to ``train``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--train_data', dict(default='sample_train.zip', type=Path, help="path of directory containing training dataset")),
        ('--batch_size', dict(type=int, default=64, help='Batch size')),
        ('--num_epochs', dict(type=int, default=50, help='Number of epochs')),
        ('--learning_rate', dict(type=float, default=2e-4, help='Learning rate')),
        ('--report_interval', dict(type=int, default=100, help='Report interval')),
        ('--noise_dim', dict(type=int, default=100, help='Input noise dimension to Generator')),
        ('--projection_dim', dict(type=int, default=128, help='Noise projection dimension')),
        ('--clip_embedding_dim', dict(type=int, default=512, help='CLIP embedding vector dimension')),
        ('--checkpoint_path', dict(type=Path, default='model_exp1', help='Checkpoint path')),
        ('--result_path', dict(type=Path, default='images_exp1', help='Generated image path')),
        ('--use_uncond_loss', dict(action="store_true")),
        ('--use_contrastive_loss', dict(action="store_true")),
        ('--num_stage', dict(type=int, default=1)),
        ('--resume_checkpoint_path', dict(default=None)),
        ('--resume_epoch', dict(type=int, default=-1)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Both output locations must exist before training starts writing.
    for out_dir in (args.checkpoint_path, args.result_path):
        os.makedirs(out_dir, exist_ok=True)

    seed_fix(0)
    train(args)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)