Skip to content

Commit e62ffab

Browse files
author
GE ZHENG
committed
first commit
1 parent c5779ce commit e62ffab

File tree

9 files changed

+65
-0
lines changed

9 files changed

+65
-0
lines changed

inference.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import os
4+
from scipy import misc
5+
import argparse
6+
import sys
7+
8+
g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
9+
output_folder = "./test_output"
10+
11+
def rgba2rgb(img):
12+
return img[:,:,:3]*np.expand_dims(img[:,:,3],2)
13+
14+
def main(args):
15+
16+
if not os.path.exists(output_folder):
17+
os.mkdir(output_folder)
18+
19+
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_fraction)
20+
with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
21+
saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
22+
saver.restore(sess,tf.train.latest_checkpoint('./salience_model'))
23+
image_batch = tf.get_collection('image_batch')[0]
24+
pred_mattes = tf.get_collection('mask')[0]
25+
26+
if args.rgb_folder:
27+
rgb_pths = os.listdir(args.rgb_folder)
28+
for rgb_pth in rgb_pths:
29+
rgb = misc.imread(os.path.join(args.rgb_folder,rgb_pth))
30+
if rgb.shape[2]==4:
31+
rgb = rgba2rgb(rgb)
32+
origin_shape = rgb.shape
33+
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)
34+
35+
feed_dict = {image_batch:rgb}
36+
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
37+
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
38+
misc.imsave(os.path.join(output_folder,rgb_pth),final_alpha)
39+
40+
else:
41+
rgb = misc.imread(args.rgb)
42+
if rgb.shape[2]==4:
43+
rgb = rgba2rgb(rgb)
44+
origin_shape = rgb.shape[:2]
45+
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)
46+
47+
feed_dict = {image_batch:rgb}
48+
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
49+
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
50+
misc.imsave(os.path.join(output_folder,'alpha.png'),final_alpha)
51+
52+
def parse_arguments(argv):
53+
parser = argparse.ArgumentParser()
54+
55+
parser.add_argument('--rgb', type=str,
56+
help='input rgb',default = None)
57+
parser.add_argument('--rgb_folder', type=str,
58+
help='input rgb',default = None)
59+
parser.add_argument('--gpu_fraction', type=float,
60+
help='how much gpu is needed, usually 4G is enough',default = 1.0)
61+
return parser.parse_args(argv)
62+
63+
64+
if __name__ == '__main__':
65+
main(parse_arguments(sys.argv[1:]))

meta_graph/my-model.meta

956 KB
Binary file not shown.

test/animal1.jpg

11.6 KB
Loading

test/animal2.jpg

11.6 KB
Loading

test/animal3.jpg

10.8 KB
Loading

test/goat.jpg

8.41 KB
Loading

test/goat2.jpg

7.12 KB
Loading

test/goat3.jpg

8.32 KB
Loading

test/sofa.jpg

66.1 KB
Loading

0 commit comments

Comments
 (0)