
Commit cb8392a: replicate demo

1 parent 6328174

File tree: 3 files changed, +213 -0 lines


README.md (+2)

```diff
@@ -15,6 +15,8 @@ A CVPR 2022 (ORAL) paper ([paper](https://openaccess.thecvf.com/content/CVPR2022
 
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1OTfwkklN-IEd4hFk4LnweOleyDtS4XTh/view?usp=sharing)
 
+Try the web demo and API here: [![Replicate](https://replicate.com/cjwbw/diffae/badge)](https://replicate.com/cjwbw/diffae)
+
 Note: Since we expect a lot of changes on the codebase, please fork the repo before using it.
 
 ### Prerequisites
```
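Once published, the hosted model can also be called programmatically. Below is a minimal sketch using the official `replicate` Python client, assuming `pip install replicate`, a `REPLICATE_API_TOKEN` in the environment, and that the latest version of `cjwbw/diffae` (the model ref from the badge above) exposes the input names defined in `predict.py` later in this commit; the file name `face.jpg` is illustrative:

```python
# Hypothetical usage sketch of the hosted demo via the replicate Python client.
# Input names mirror Predictor.predict() in predict.py below.
import replicate

outputs = replicate.run(
    "cjwbw/diffae",                       # model ref from the badge above
    input={
        "image": open("face.jpg", "rb"),  # any portrait photo (assumed path)
        "target_class": "Bangs",          # one of the CelebA attribute choices
        "manipulation_amplitude": 0.3,
        "T_step": 100,
        "T_inv": 200,
    },
)
# The predictor returns two images: the aligned original and the manipulated one.
for out in outputs:
    print(out)
```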

cog.yaml (new file, +29)

```yaml
build:
  cuda: "10.2"
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "numpy==1.21.5"
    - "cmake==3.23.3"
    - "ipython==7.21.0"
    - "opencv-python==4.5.4.58"
    - "pandas==1.1.5"
    - "lmdb==1.2.1"
    - "lpips==0.1.4"
    - "pytorch-fid==0.2.0"
    - "ftfy==6.1.1"
    - "scipy==1.5.4"
    - "torch==1.9.1"
    - "torchvision==0.10.1"
    - "tqdm==4.62.3"
    - "regex==2022.7.25"
    - "Pillow==9.2.0"
    - "pytorch_lightning==1.7.0"

  run:
    - pip install dlib

predict: "predict.py:Predictor"
```
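With this config in place, the predictor can be exercised locally through the Cog CLI before pushing to Replicate. A minimal smoke-test sketch (driven from Python for consistency with the other examples here), assuming `cog` is installed, a CUDA GPU is available, and the checkpoints and dlib landmarks file have been downloaded as described at the top of `predict.py`; `face.jpg` is an assumed input path:

```python
# Hypothetical local smoke test: `cog predict` builds the image from cog.yaml
# and runs Predictor.predict() with the given inputs.
import subprocess

subprocess.run(
    [
        "cog", "predict",
        "-i", "image=@face.jpg",            # any portrait photo (assumed path)
        "-i", "target_class=Bangs",
        "-i", "manipulation_amplitude=0.3",
        "-i", "T_step=100",
        "-i", "T_inv=200",
    ],
    check=True,  # raise if the container build or the prediction fails
)
```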

predict.py (new file, +182)

```python
# Pre-download the weights for the 256-resolution model to
# checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls, and fetch
# the dlib face-landmarks model:
#   wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
#   bunzip2 shape_predictor_68_face_landmarks.dat.bz2

import math
import os
import tempfile
from typing import List

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from templates import *
from templates_cls import *
from experiment_classifier import ClsModel
from align import LandmarksDetector, image_align
from cog import BasePredictor, Path, Input, BaseModel


class ModelOutput(BaseModel):
    image: Path


class Predictor(BasePredictor):
    def setup(self):
        self.aligned_dir = "aligned"
        os.makedirs(self.aligned_dir, exist_ok=True)
        self.device = "cuda:0"

        # Model initialization: the FFHQ-256 diffusion autoencoder.
        model_config = ffhq256_autoenc()
        self.model = LitModel(model_config)
        state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
        self.model.load_state_dict(state["state_dict"], strict=False)
        self.model.ema_model.eval()
        self.model.ema_model.to(self.device)

        # Classifier initialization: a linear attribute classifier on the
        # semantic latent; its weight rows serve as manipulation directions.
        classifier_config = ffhq256_autoenc_cls()
        classifier_config.pretrain = None  # a bit faster
        self.classifier = ClsModel(classifier_config)
        state_class = torch.load(
            "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
        )
        print("latent step:", state_class["global_step"])
        self.classifier.load_state_dict(state_class["state_dict"], strict=False)
        self.classifier.to(self.device)

        self.landmarks_detector = LandmarksDetector(
            "shape_predictor_68_face_landmarks.dat"
        )

    def predict(
        self,
        image: Path = Input(
            description="Input image for face manipulation. The image will be aligned "
            "and cropped; both the aligned and the manipulated images are returned.",
        ),
        target_class: str = Input(
            default="Bangs",
            choices=[
                "5_o_Clock_Shadow",
                "Arched_Eyebrows",
                "Attractive",
                "Bags_Under_Eyes",
                "Bald",
                "Bangs",
                "Big_Lips",
                "Big_Nose",
                "Black_Hair",
                "Blond_Hair",
                "Blurry",
                "Brown_Hair",
                "Bushy_Eyebrows",
                "Chubby",
                "Double_Chin",
                "Eyeglasses",
                "Goatee",
                "Gray_Hair",
                "Heavy_Makeup",
                "High_Cheekbones",
                "Male",
                "Mouth_Slightly_Open",
                "Mustache",
                "Narrow_Eyes",
                "Beard",
                "Oval_Face",
                "Pale_Skin",
                "Pointy_Nose",
                "Receding_Hairline",
                "Rosy_Cheeks",
                "Sideburns",
                "Smiling",
                "Straight_Hair",
                "Wavy_Hair",
                "Wearing_Earrings",
                "Wearing_Hat",
                "Wearing_Lipstick",
                "Wearing_Necklace",
                "Wearing_Necktie",
                "Young",
            ],
            description="Choose the manipulation direction.",
        ),
        manipulation_amplitude: float = Input(
            default=0.3,
            ge=-0.5,
            le=0.5,
            description="If set too high, the manipulation can dominate the original "
            "image information and produce artifacts.",
        ),
        T_step: int = Input(
            default=100,
            choices=[50, 100, 125, 200, 250, 500],
            description="Number of diffusion steps for generation.",
        ),
        T_inv: int = Input(
            default=200,
            choices=[50, 100, 125, 200, 250, 500],
            description="Number of diffusion steps for stochastic encoding.",
        ),
    ) -> List[ModelOutput]:

        img_size = 256
        print("Aligning image...")
        # Each detected face overwrites the same aligned.png, so only the
        # last detected face is kept.
        for i, face_landmarks in enumerate(
            self.landmarks_detector.get_landmarks(str(image)), start=1
        ):
            image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)

        data = ImageDataset(
            self.aligned_dir,
            image_size=img_size,
            exts=["jpg", "jpeg", "JPG", "png"],
            do_augment=False,
        )

        print("Encoding and manipulating the aligned image...")
        # Some CelebA attributes are stored with a "No_" prefix; flip the
        # amplitude's sign so the shift still moves toward the target class.
        cls_manipulation_amplitude = manipulation_amplitude
        interpreted_target_class = target_class
        if (
            target_class not in CelebAttrDataset.id_to_cls
            and f"No_{target_class}" in CelebAttrDataset.id_to_cls
        ):
            cls_manipulation_amplitude = -manipulation_amplitude
            interpreted_target_class = f"No_{target_class}"

        batch = data[0]["img"][None]

        # Semantic latent (z_sem) and stochastic latent (x_T) of the input.
        semantic_latent = self.model.encode(batch.to(self.device))
        stochastic_latent = self.model.encode_stochastic(
            batch.to(self.device), semantic_latent, T=T_inv
        )

        # The classifier's weight row for the target attribute gives the
        # manipulation direction in latent space.
        cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
        class_direction = self.classifier.classifier.weight[cls_id]
        normalized_class_direction = F.normalize(class_direction[None, :], dim=1)

        # Shift the normalized semantic latent along the unit direction,
        # scaled by sqrt(512) to match the latent dimensionality.
        normalized_semantic_latent = self.classifier.normalize(semantic_latent)
        normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
        normalized_manipulated_semantic_latent = (
            normalized_semantic_latent
            + normalized_manipulation_amp * normalized_class_direction
        )

        manipulated_semantic_latent = self.classifier.denormalize(
            normalized_manipulated_semantic_latent
        )

        # Render the manipulated image from the original stochastic latent
        # and the shifted semantic latent.
        manipulated_img = self.model.render(
            stochastic_latent, manipulated_semantic_latent, T=T_step
        )[0]
        original_img = data[0]["img"]

        model_output = []
        out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
        save_image(convert2rgb(original_img), str(out_path))
        model_output.append(ModelOutput(image=out_path))

        out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
        save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
        model_output.append(ModelOutput(image=out_path))
        return model_output


def convert2rgb(img, adjust_scale=True):
    # Map images from [-1, 1] back to [0, 1] for saving, when needed.
    convert_img = torch.tensor(img)
    if adjust_scale:
        convert_img = (convert_img + 1) / 2
    return convert_img.cpu()
```
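For readers skimming the diff: the heart of `predict()` is a single linear step in the normalized semantic-latent space, moving `z_sem` along the unit-normalized classifier weight row, scaled by the amplitude times sqrt(512). A standalone sketch of just that step (`shift_semantic_latent` is a hypothetical helper; 512 matches diffae's semantic latent width):

```python
import math
import torch
import torch.nn.functional as F

def shift_semantic_latent(
    z_norm: torch.Tensor,  # (1, 512) normalized semantic latent
    w: torch.Tensor,       # (512,) classifier weight row for the attribute
    amplitude: float,      # signed strength, e.g. 0.3
) -> torch.Tensor:
    # Unit attribute direction; sqrt(512) rescales the step to the typical
    # norm of a 512-dimensional normalized latent.
    direction = F.normalize(w[None, :], dim=1)
    return z_norm + amplitude * math.sqrt(512) * direction
```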
