Commit 298efcb

guyuchao committed "update codes"
1 parent 0cb2102 commit 298efcb


99 files changed: +1559, -0 lines

.idea/gan.iml

+11

.idea/misc.xml

+4

.idea/modules.xml

+8

.idea/workspace.xml

+745

Criterion.py

+9
@@ -0,0 +1,9 @@
import numpy as np
from sklearn.metrics import precision_recall_curve
import os


class Criterion:
    def precision_recall(self, dst, epoch, ytrue, ypred):
        # Compute the precision-recall curve for this epoch's predictions
        # and save the three arrays to dst as P<epoch>.npy, R<epoch>.npy,
        # and T<epoch>.npy.
        precision, recall, thresholds = precision_recall_curve(ytrue, ypred)
        np.save(os.path.join(dst, "P%s" % epoch), precision)
        np.save(os.path.join(dst, "R%s" % epoch), recall)
        np.save(os.path.join(dst, "T%s" % epoch), thresholds)

datasets.py

+113
@@ -0,0 +1,113 @@
import os
import os.path as osp
import numpy as np
from PIL import Image
from gycutils.gycaug import ColorAug, Add_Gaussion_noise, Random_horizontal_flip, Random_vertical_flip, Compose_imglabel, Random_crop
import collections
import torch
import torchvision
from transform import ReLabel, ToLabel, Scale, HorizontalFlip, VerticalFlip
from torch.utils import data
from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor


def default_loader(path):
    return Image.open(path)


class VOCDataSet(data.Dataset):
    def __init__(self, root, split="train", img_transform=None, label_transform=None, image_label_transform=None):
        self.root = root
        self.split = split
        self.files = collections.defaultdict(list)
        self.img_transform = img_transform
        self.label_transform = label_transform
        self.image_label_transform = image_label_transform
        self.h_flip = HorizontalFlip()
        self.v_flip = VerticalFlip()

        # Index every image/label pair under <root>/eyedata/<split>/.
        data_dir = osp.join(root, "eyedata", split)
        imgsets_dir = osp.join(data_dir, "img")
        for name in os.listdir(imgsets_dir):
            name = os.path.splitext(name)[0]
            img_file = osp.join(data_dir, "img/%s.tif" % name)
            label_file = osp.join(data_dir, "label/%s.gif" % name)
            self.files[split].append({
                "img": img_file,
                "label": label_file
            })

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        datafiles = self.files[self.split][index]

        img_file = datafiles["img"]
        img = Image.open(img_file).convert('RGB')

        # The label image holds categorical values, so any resizing must use
        # NEAREST rather than BILINEAR interpolation.
        label_file = datafiles["label"]
        label = Image.open(label_file).convert("P")

        # Geometric augmentations (crop/flip) are applied to the image and
        # its label jointly so the two stay aligned.
        if self.image_label_transform is not None:
            img, label = self.image_label_transform(img, label)

        # With no transform configured, the raw PIL images are returned.
        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            label = self.label_transform(label)

        return img, label


if __name__ == '__main__':
    input_transform = Compose([
        ColorAug(),
        Add_Gaussion_noise(prob=0.5),
        ToTensor(),
        Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    target_transform = Compose([
        ToLabel(),
        ReLabel(255, 1),
    ])
    img_label_transform = Compose_imglabel([
        Random_crop(512, 512),
        Random_horizontal_flip(0.5),
        Random_vertical_flip(0.5),
    ])
    dst = VOCDataSet("./", img_transform=input_transform, label_transform=target_transform, image_label_transform=img_label_transform)
    trainloader = data.DataLoader(dst, batch_size=1)

    # `batch` avoids shadowing the torch.utils.data module imported as `data`.
    for i, batch in enumerate(trainloader):
        imgs, labels = batch
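For a quick sanity check against the eyedata/ layout this commit ships (a DRIVE-style train/val split, listed below), the dataset can be instantiated without transforms and one sample inspected; the values in the comments are what the file listing implies, not output captured from the commit:

from datasets import VOCDataSet

# Layout indexed by VOCDataSet:
#   ./eyedata/train/img/21_training.tif ... 40_training.tif
#   ./eyedata/train/label/21_training.gif ... 40_training.gif
#   ./eyedata/val/img/01_test.tif ... 20_test.tif
#   ./eyedata/val/label/01_test.gif ... 20_test.gif
dst = VOCDataSet("./", split="train")   # no transforms: returns raw PIL images
img, label = dst[0]
print(len(dst))                          # 20 training pairs
print(img.size, label.size)              # full-resolution fundus image and vessel mask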

eyedata/train/img/21_training.tif

741 KB
Binary file not shown.

eyedata/train/img/22_training.tif

734 KB
Binary file not shown.

eyedata/train/img/23_training.tif

639 KB
Binary file not shown.

eyedata/train/img/24_training.tif

679 KB
Binary file not shown.

eyedata/train/img/25_training.tif

706 KB
Binary file not shown.

eyedata/train/img/26_training.tif

685 KB
Binary file not shown.

eyedata/train/img/27_training.tif

713 KB
Binary file not shown.

eyedata/train/img/28_training.tif

737 KB
Binary file not shown.

eyedata/train/img/29_training.tif

722 KB
Binary file not shown.

eyedata/train/img/30_training.tif

687 KB
Binary file not shown.

eyedata/train/img/31_training.tif

756 KB
Binary file not shown.

eyedata/train/img/32_training.tif

722 KB
Binary file not shown.

eyedata/train/img/33_training.tif

740 KB
Binary file not shown.

eyedata/train/img/34_training.tif

800 KB
Binary file not shown.

eyedata/train/img/35_training.tif

780 KB
Binary file not shown.

eyedata/train/img/36_training.tif

730 KB
Binary file not shown.

eyedata/train/img/37_training.tif

766 KB
Binary file not shown.

eyedata/train/img/38_training.tif

713 KB
Binary file not shown.

eyedata/train/img/39_training.tif

732 KB
Binary file not shown.

eyedata/train/img/40_training.tif

677 KB
Binary file not shown.

eyedata/train/label/21_training.gif

10 KB

eyedata/train/label/22_training.gif

12.1 KB

eyedata/train/label/23_training.gif

7.9 KB

eyedata/train/label/24_training.gif

13 KB

eyedata/train/label/25_training.gif

11.2 KB

eyedata/train/label/26_training.gif

10.1 KB

eyedata/train/label/27_training.gif

11.7 KB

eyedata/train/label/28_training.gif

12.1 KB

eyedata/train/label/29_training.gif

11.3 KB

eyedata/train/label/30_training.gif

10.3 KB

eyedata/train/label/31_training.gif

8.63 KB

eyedata/train/label/32_training.gif

10.7 KB

eyedata/train/label/33_training.gif

11 KB

eyedata/train/label/34_training.gif

11.1 KB

eyedata/train/label/35_training.gif

10.8 KB

eyedata/train/label/36_training.gif

12.8 KB

eyedata/train/label/37_training.gif

10.9 KB

eyedata/train/label/38_training.gif

11.2 KB

eyedata/train/label/39_training.gif

11.4 KB

eyedata/train/label/40_training.gif

10.4 KB

eyedata/val/img/01_test.tif

713 KB
Binary file not shown.

eyedata/val/img/02_test.tif

672 KB
Binary file not shown.

eyedata/val/img/03_test.tif

669 KB
Binary file not shown.

eyedata/val/img/04_test.tif

776 KB
Binary file not shown.

eyedata/val/img/05_test.tif

707 KB
Binary file not shown.

eyedata/val/img/06_test.tif

711 KB
Binary file not shown.

eyedata/val/img/07_test.tif

773 KB
Binary file not shown.

eyedata/val/img/08_test.tif

718 KB
Binary file not shown.

eyedata/val/img/09_test.tif

710 KB
Binary file not shown.

eyedata/val/img/10_test.tif

723 KB
Binary file not shown.

eyedata/val/img/11_test.tif

702 KB
Binary file not shown.

eyedata/val/img/12_test.tif

755 KB
Binary file not shown.

eyedata/val/img/13_test.tif

710 KB
Binary file not shown.

eyedata/val/img/14_test.tif

752 KB
Binary file not shown.

eyedata/val/img/15_test.tif

763 KB
Binary file not shown.

eyedata/val/img/16_test.tif

729 KB
Binary file not shown.

eyedata/val/img/17_test.tif

749 KB
Binary file not shown.

eyedata/val/img/18_test.tif

727 KB
Binary file not shown.

eyedata/val/img/19_test.tif

717 KB
Binary file not shown.

eyedata/val/img/20_test.tif

695 KB
Binary file not shown.

eyedata/val/label/01_test.gif

11.3 KB

eyedata/val/label/02_test.gif

11.8 KB

eyedata/val/label/03_test.gif

11.3 KB

eyedata/val/label/04_test.gif

10.8 KB

eyedata/val/label/05_test.gif

11.7 KB

eyedata/val/label/06_test.gif

12.2 KB

eyedata/val/label/07_test.gif

11.3 KB

eyedata/val/label/08_test.gif

11.1 KB

eyedata/val/label/09_test.gif

11.1 KB

eyedata/val/label/10_test.gif

11.6 KB

eyedata/val/label/11_test.gif

12.6 KB

eyedata/val/label/12_test.gif

11.1 KB

eyedata/val/label/13_test.gif

12.6 KB

eyedata/val/label/14_test.gif

10.8 KB

eyedata/val/label/15_test.gif

9.65 KB

eyedata/val/label/16_test.gif

11.4 KB

eyedata/val/label/17_test.gif

10.7 KB

eyedata/val/label/18_test.gif

9.97 KB

eyedata/val/label/19_test.gif

11.1 KB

eyedata/val/label/20_test.gif

9.47 KB

gan.py

+125
@@ -0,0 +1,125 @@
from torch import nn
from torchvision import models
import torch.nn.functional as F
import torch


class block(nn.Module):
    # Conv -> BatchNorm -> ReLU, the basic unit of both networks.
    def __init__(self, in_filters, n_filters):
        super(block, self).__init__()
        self.deconv1 = nn.Sequential(
            nn.Conv2d(in_filters, n_filters, 3, stride=1, padding=1),
            nn.BatchNorm2d(n_filters),
            nn.ReLU())

    def forward(self, x):
        x = self.deconv1(x)
        return x


class generator(nn.Module):
    # initializers
    def __init__(self, n_filters=32):
        super(generator, self).__init__()
        self.down1 = nn.Sequential(
            block(3, n_filters),
            block(n_filters, n_filters),
            nn.MaxPool2d((2, 2)))
        self.down2 = nn.Sequential(
            block(n_filters, 2 * n_filters),
            block(2 * n_filters, 2 * n_filters),
            nn.MaxPool2d((2, 2)))
        self.down3 = nn.Sequential(
            block(2 * n_filters, 4 * n_filters),
            block(4 * n_filters, 4 * n_filters),
            nn.MaxPool2d((2, 2)))
        self.down4 = nn.Sequential(
            block(4 * n_filters, 8 * n_filters),
            block(8 * n_filters, 8 * n_filters),
            nn.MaxPool2d((2, 2)))
        self.down5 = nn.Sequential(
            block(8 * n_filters, 16 * n_filters),
            block(16 * n_filters, 16 * n_filters))

        # Each up stage takes the concatenation of the previous feature map
        # and the matching encoder output (U-Net style skip connections),
        # hence the summed channel counts.
        self.up1 = nn.Sequential(
            block(16 * n_filters + 8 * n_filters, 8 * n_filters),
            block(8 * n_filters, 8 * n_filters))
        self.up2 = nn.Sequential(
            block(8 * n_filters + 4 * n_filters, 4 * n_filters),
            block(4 * n_filters, 4 * n_filters))
        self.up3 = nn.Sequential(
            block(4 * n_filters + 2 * n_filters, 2 * n_filters),
            block(2 * n_filters, 2 * n_filters))
        self.up4 = nn.Sequential(
            block(2 * n_filters + n_filters, n_filters),
            block(n_filters, n_filters))

        self.out = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1)
        )

    # forward method
    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        # F.upsample/F.sigmoid are the pre-0.4 PyTorch APIs; newer versions
        # use F.interpolate and torch.sigmoid.
        x = self.up1(F.upsample(torch.cat((x4, x5), dim=1), scale_factor=2))
        x = self.up2(F.upsample(torch.cat((x, x3), dim=1), scale_factor=2))
        x = self.up3(F.upsample(torch.cat((x, x2), dim=1), scale_factor=2))
        x = self.up4(F.upsample(torch.cat((x, x1), dim=1), scale_factor=2))
        x = F.sigmoid(self.out(x))
        return x  # b,1,w,h


class discriminator(nn.Module):
    def __init__(self, n_filters):
        super(discriminator, self).__init__()
        # 4 input channels, e.g. a 3-channel image concatenated with a
        # 1-channel segmentation map.
        self.down1 = nn.Sequential(
            block(4, n_filters),
            block(n_filters, n_filters),
            nn.MaxPool2d((2, 2)))
        self.down2 = nn.Sequential(
            block(n_filters, 2 * n_filters),
            block(2 * n_filters, 2 * n_filters),
            nn.MaxPool2d((2, 2)))
        self.down3 = nn.Sequential(
            block(2 * n_filters, 4 * n_filters),
            block(4 * n_filters, 4 * n_filters),
            nn.MaxPool2d((2, 2)))
        self.down4 = nn.Sequential(
            block(4 * n_filters, 8 * n_filters),
            block(8 * n_filters, 8 * n_filters),
            nn.MaxPool2d((2, 2)))
        self.down5 = nn.Sequential(
            block(8 * n_filters, 16 * n_filters),
            block(16 * n_filters, 16 * n_filters))
        self.out = nn.Linear(16 * n_filters, 1)

    def forward(self, x):
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.down5(x)
        # Global average pool to a (batch, channels) vector before the
        # final linear layer.
        x = F.avg_pool2d(x, kernel_size=x.size()[2:]).view(x.size()[0], -1)
        x = self.out(x)
        x = F.sigmoid(x)
        return x  # b,1


if __name__ == '__main__':
    from torch.autograd import Variable  # Variable lives in torch.autograd, not torch.nn.functional
    D = discriminator(32).cuda()
    t = Variable(torch.ones((2, 4, 512, 512)).cuda())
    print(D(t).size())
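This hunk defines the two networks but no training loop, so purely as orientation, here is a minimal sketch of one adversarial update consistent with their interfaces — the optimizers, BCE losses, learning rates, and the image/mask concatenation feeding the 4-channel discriminator are assumptions, not code from this commit:

import torch
from torch import nn
from gan import generator, discriminator

G, D = generator(32), discriminator(32)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)   # assumed hyperparameters
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()

def train_step(img, mask):
    # img: (b,3,h,w) fundus image; mask: (b,1,h,w) float vessel ground truth.
    fake = G(img)

    # Discriminator: real pairs -> 1, generated pairs -> 0. The cat() builds
    # the 4-channel input the discriminator's first block expects.
    d_real = D(torch.cat((img, mask), dim=1))
    d_fake = D(torch.cat((img, fake.detach()), dim=1))
    loss_d = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # Generator: fool the discriminator while matching the ground-truth mask.
    d_fake = D(torch.cat((img, fake), dim=1))
    loss_g = bce(d_fake, torch.ones_like(d_fake)) + bce(fake, mask)
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()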

gycutils/__init__.py

Whitespace-only changes.
