
Commit 39ca570

authored
update to match the camera-ready version
1 parent 63d020b commit 39ca570


57 files changed (+371 -0 lines)
3.79 KB
Binary file not shown.

__pycache__/layers.cpython-38.pyc

6.77 KB
Binary file not shown.

3.77 KB
Binary file not shown.

__pycache__/models.cpython-38.pyc

2.25 KB
Binary file not shown.

cifar_mil_main.py

+87

import argparse
import os

import torch
import seaborn as sns
import matplotlib.pyplot as plt

from cifar_mil_trainer import Trainer


def get_args():

    parser = argparse.ArgumentParser(description='CIFAR MIL benchmarks')

    parser.add_argument("--project_name", default="CIFAR-MIL")
    # A bare flag avoids the argparse pitfall where type=bool parses any
    # non-empty string (including "False") as True.
    parser.add_argument('--wandb', action='store_true')

    # Model params
    parser.add_argument('--mode', default="softmax", choices=["softmax", "entmax", "sparsemax", "gsh"])
    parser.add_argument('--d_model', default=1024, type=int)
    parser.add_argument('--input_size', default=3 * 32 * 32, type=int)
    parser.add_argument('--model', default="pooling", type=str)
    parser.add_argument('--num_pattern', default=20, type=int)
    parser.add_argument('--n_heads', default=8, type=int)
    parser.add_argument('--scale', default=0.01, type=float)
    parser.add_argument('--update_steps', default=1, type=int)
    parser.add_argument('--dropout', default=0.7, type=float)

    # Training params
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--epoch', default=100, type=int)
    parser.add_argument('--seed', default=1111, type=int)

    # Data params
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--train_size', default=10000, type=int)
    parser.add_argument('--test_size', default=5000, type=int)
    parser.add_argument('--pos_per_bag', default=1, type=int)
    parser.add_argument('--bag_size', default=10, type=int)
    parser.add_argument('--tgt_num', default=0, type=int)

    return vars(parser.parse_args())


if __name__ == "__main__":

    torch.set_num_threads(3)
    config = get_args()
    trials = 5  # independent runs per attention mode
    torch.manual_seed(config["seed"])

    bag_size = config["bag_size"]
    # bag_size = [5, 10, 20, 50, 100, 200, 300]
    models = ["softmax", "sparsemax", "entmax", "gsh"]
    data_log = None

    for m in models:
        config["mode"] = m
        for t in range(trials):
            # Re-seed from fresh entropy so each trial is an independent run.
            torch.random.manual_seed(torch.random.seed())
            trainer = Trainer(config, t)
            trial_log = trainer.train()
            if data_log is None:
                data_log = trial_log
            else:
                # Concatenate the per-epoch lists across trials and models.
                for k in data_log:
                    data_log[k] = data_log[k] + trial_log[k]

    # savefig fails if the output directory does not exist yet.
    os.makedirs('./imgs/cifar', exist_ok=True)

    for metric, fname in [("train loss", "train_loss"), ("test loss", "test_loss"),
                          ("train acc", "train_acc"), ("test acc", "test_acc")]:
        sns.lineplot(data=data_log, x="epoch", y=metric, hue="model",
                     alpha=0.4, errorbar=None, linewidth=2)
        plt.tight_layout()
        plt.savefig(f'./imgs/cifar/{fname}_{bag_size}.pdf')
        plt.clf()
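
Because every trial appends its full epoch range onto the same flat dict, data_log["epoch"] contains repeated values, and sns.lineplot aggregates over them per hue. A minimal standalone sketch with synthetic numbers (not real training output) showing that behavior:

import seaborn as sns
import matplotlib.pyplot as plt

# Two synthetic "trials" of three epochs each, concatenated exactly as the
# sweep loop above does; all values are made up for illustration.
demo_log = {
    "epoch": [1, 2, 3, 1, 2, 3],
    "train loss": [0.9, 0.6, 0.5, 0.8, 0.7, 0.4],
    "model": ["softmax"] * 6,
}
# lineplot's default estimator (mean) collapses duplicate x values per hue;
# errorbar=None only suppresses the uncertainty band, not the aggregation.
sns.lineplot(data=demo_log, x="epoch", y="train loss", hue="model", errorbar=None)
plt.tight_layout()
plt.savefig("demo.png")  # illustrative output path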

cifar_mil_trainer.py

+160

from datasets.cifar10_bags import CIFARBags
from torch.utils.data import DataLoader
from layers import *
from models import *
import torch
import wandb


class Trainer:

    def __init__(self, config, trial) -> None:
        self.config = config
        self.trial = trial

        if self.config["wandb"]:
            # Log hyperparameters and run metadata to Weights & Biases.
            wandb.init(project=self.config["project_name"] + " good",
                       config=self.config)

    def _get_data(self):

        trainset = CIFARBags(target_number=self.config["tgt_num"],
                             bag_size=self.config["bag_size"],
                             num_bag=self.config["train_size"],
                             pos_per_bag=self.config["pos_per_bag"],
                             seed=self.config["seed"],
                             train=True)

        testset = CIFARBags(target_number=self.config["tgt_num"],
                            bag_size=self.config["bag_size"],
                            num_bag=self.config["test_size"],
                            pos_per_bag=self.config["pos_per_bag"],
                            seed=self.config["seed"],
                            train=False)

        train_loader = DataLoader(trainset, batch_size=self.config["batch_size"], shuffle=True)
        test_loader = DataLoader(testset, batch_size=self.config["batch_size"], shuffle=False)

        return train_loader, test_loader

    def _get_model(self):

        # The model class from the MNIST experiment is reused unchanged;
        # only input_size (3*32*32 for CIFAR) differs.
        model = MNISTModel(input_size=self.config["input_size"],
                           d_model=self.config["d_model"],
                           n_heads=self.config["n_heads"],
                           update_steps=self.config["update_steps"],
                           dropout=self.config["dropout"],
                           mode=self.config["mode"],
                           scale=self.config["scale"],
                           num_pattern=self.config['num_pattern'])

        return model.cuda()

    def _get_opt(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"], weight_decay=0.0)

    def _get_cri(self):
        return torch.nn.BCEWithLogitsLoss()

    def test_epoch(self, loader):

        self.model.eval()  # disable dropout during evaluation

        total_loss = 0.0
        total_cor, total_sample = 0, 0
        total_step = 0

        with torch.no_grad():
            for x, y in loader:

                total_sample += x.size(0)
                total_step += 1
                x, y = x.float().cuda(), y.float().cuda()
                pred = self.model(x)
                loss = self.cri(pred, y)

                # The model emits raw logits, so the decision boundary is 0
                # (equivalent to sigmoid(pred) > 0.5).
                output = (pred > 0).float()
                total_cor += (output == y).float().sum()
                total_loss += loss.item()

        return total_loss / total_step, total_cor / total_sample

    def train_epoch(self, loader):

        self.model.train()

        total_loss = 0.0
        total_cor, total_sample = 0, 0
        total_step = 0

        for x, y in loader:

            total_step += 1
            total_sample += x.size(0)

            self.opt.zero_grad()
            x, y = x.float().cuda(), y.float().cuda()
            pred = self.model(x)
            loss = self.cri(pred, y)
            loss.backward()
            self.opt.step()

            output = (pred > 0).float()
            total_cor += (output == y).float().sum()
            total_loss += loss.item()

        return total_loss / total_step, total_cor / total_sample

    def train(self):

        train_loader, test_loader = self._get_data()
        self.model = self._get_model()
        self.opt = self._get_opt()
        self.cri = self._get_cri()

        best_test_acc = -1

        data_log = {
            'train loss': [],
            'train acc': [],
            'test loss': [],
            'test acc': [],
            'epoch': [],
            'model': []
        }

        self.sche = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt, self.config["epoch"], eta_min=0, last_epoch=-1)

        for epoch in range(1, self.config["epoch"] + 1):

            train_loss, train_acc = self.train_epoch(train_loader)
            test_loss, test_acc = self.test_epoch(test_loader)
            self.sche.step()

            data_log['train loss'].append(train_loss)
            data_log['test loss'].append(test_loss)
            data_log['train acc'].append(train_acc.item())
            data_log['test acc'].append(test_acc.item())
            data_log['epoch'].append(epoch)
            data_log['model'].append(self.config['mode'])

            if test_acc >= best_test_acc:
                best_test_acc = test_acc

            if self.config["wandb"]:
                wandb.log({
                    "step": epoch,
                    "train loss": train_loss,
                    "train acc": train_acc.item() * 100,
                    "test loss": test_loss,
                    "test acc": test_acc.item() * 100
                }, step=epoch)

        if self.config["wandb"]:
            wandb.log({"best test acc": best_test_acc})
            wandb.log({"logs": data_log})
            wandb.finish()

        return data_log
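
For a quick smoke test outside the sweep script, the trainer can be driven directly. This is a sketch, assuming a CUDA device and that layers.py / models.py (not part of this diff) define MNISTModel with the signature called above; the config keys mirror the argparse defaults in cifar_mil_main.py, with epoch and dataset sizes shrunk for speed:

from cifar_mil_trainer import Trainer

# Illustrative config; keys match cifar_mil_main.py's get_args().
config = {
    "wandb": False, "project_name": "CIFAR-MIL",
    "mode": "softmax", "d_model": 1024, "input_size": 3 * 32 * 32,
    "model": "pooling", "num_pattern": 20, "n_heads": 8, "scale": 0.01,
    "update_steps": 1, "dropout": 0.7,
    "lr": 1e-3, "epoch": 2, "seed": 1111,
    "batch_size": 64, "train_size": 200, "test_size": 100,
    "pos_per_bag": 1, "bag_size": 10, "tgt_num": 0,
}
log = Trainer(config, trial=0).train()
print(log["test acc"])  # one entry per epoch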

cifar_run.sh

+5

python3 cifar_mil_main.py --bag_size 20
python3 cifar_mil_main.py --bag_size 50
python3 cifar_mil_main.py --bag_size 5
python3 cifar_mil_main.py --bag_size 10
python3 cifar_mil_main.py --bag_size 2
2.89 KB
Binary file not shown.

datasets/cifar10_bags.py

+119

"""PyTorch dataset object that loads CIFAR-10 as bags."""

import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
import random


class CIFARBags(data_utils.Dataset):
    def __init__(
            self,
            target_number=9,
            bag_size=10,
            num_bag=500,
            pos_per_bag=1,
            seed=1,
            train=True):
        self.target_number = target_number
        self.bag_size = bag_size
        self.pos_per_bag = pos_per_bag
        self.train = train
        self.num_bag = num_bag

        self.r = np.random.RandomState(seed)
        # random.sample below draws from the stdlib RNG, so seed it too;
        # otherwise the bags differ across runs regardless of `seed`.
        random.seed(seed)

        self.num_in_train = 50000
        self.num_in_test = 10000

        if self.train:
            self.train_bags_list, self.train_labels_list = self._create_bags()
        else:
            self.test_bags_list, self.test_labels_list = self._create_bags()

    def _create_bags(self):

        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        if self.train:
            loader = data_utils.DataLoader(
                datasets.CIFAR10(
                    '../datasets',
                    train=True,
                    download=True,
                    transform=transform_train),
                batch_size=self.num_in_train,
                shuffle=False)
        else:
            loader = data_utils.DataLoader(
                datasets.CIFAR10(
                    '../datasets',
                    train=False,
                    download=True,
                    transform=transform_test),
                batch_size=self.num_in_test,
                shuffle=False)

        # The loader yields a single batch holding the entire split.
        for (batch_data, batch_labels) in loader:
            all_imgs = batch_data
            all_labels = batch_labels

        bags_list = []
        labels_list = []

        pos_idx = [i for i, j in enumerate(all_labels) if j == self.target_number]
        neg_idx = [i for i, j in enumerate(all_labels) if j != self.target_number]

        self.all_pos_img = [all_imgs[i] for i in pos_idx]
        self.all_neg_img = [all_imgs[i] for i in neg_idx]

        for i in range(self.num_bag):

            # Each iteration builds one positive bag (pos_per_bag target-class
            # images padded with negatives) and one all-negative bag.
            _pos_idx = (random.sample(pos_idx, self.pos_per_bag)
                        + random.sample(neg_idx, self.bag_size - self.pos_per_bag))
            _neg_idx = random.sample(neg_idx, self.bag_size)
            assert len(_pos_idx) == len(_neg_idx)

            # Bag label is 1 iff the bag contains the target class.
            bags_list.append(all_imgs[_neg_idx])
            labels_list.append(0)
            bags_list.append(all_imgs[_pos_idx])
            labels_list.append(1)

        return bags_list, torch.tensor(labels_list)

    def __len__(self):
        if self.train:
            return len(self.train_labels_list)
        return len(self.test_labels_list)

    def __getitem__(self, index):
        if self.train:
            bag = self.train_bags_list[index]
            label = self.train_labels_list[index]
        else:
            bag = self.test_bags_list[index]
            label = self.test_labels_list[index]

        return bag, label
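
A minimal sanity check of the bag construction (illustrative values; CIFAR-10 is downloaded to ../datasets on first use, and the label convention matches the code above):

from torch.utils.data import DataLoader
from datasets.cifar10_bags import CIFARBags

bags = CIFARBags(target_number=0, bag_size=10, num_bag=100,
                 pos_per_bag=1, seed=1111, train=True)
print(len(bags))  # 200: one positive and one negative bag per num_bag iteration
loader = DataLoader(bags, batch_size=4, shuffle=True)
x, y = next(iter(loader))
print(x.shape)  # torch.Size([4, 10, 3, 32, 32]): 4 bags of 10 CIFAR images each
print(y)        # 1 marks bags containing the target class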

imgs/cifar/test_acc_10.pdf

14.8 KB
Binary file not shown.

imgs/cifar/test_acc_2.pdf

14.5 KB
Binary file not shown.

imgs/cifar/test_acc_20.pdf

15.5 KB
Binary file not shown.

imgs/cifar/test_acc_5.pdf

14.8 KB
Binary file not shown.

imgs/cifar/test_acc_50.pdf

14.8 KB
Binary file not shown.

imgs/cifar/test_loss_10.pdf

15.5 KB
Binary file not shown.

imgs/cifar/test_loss_2.pdf

14.9 KB
Binary file not shown.

imgs/cifar/test_loss_20.pdf

14.7 KB
Binary file not shown.

imgs/cifar/test_loss_5.pdf

14.7 KB
Binary file not shown.

imgs/cifar/test_loss_50.pdf

14.4 KB
Binary file not shown.

imgs/cifar/train_acc_10.pdf

14.9 KB
Binary file not shown.

imgs/cifar/train_acc_2.pdf

15.1 KB
Binary file not shown.

imgs/cifar/train_acc_20.pdf

15.7 KB
Binary file not shown.

imgs/cifar/train_acc_5.pdf

15.3 KB
Binary file not shown.

imgs/cifar/train_acc_50.pdf

15.4 KB
Binary file not shown.

imgs/cifar/train_loss_10.pdf

15 KB
Binary file not shown.

imgs/cifar/train_loss_2.pdf

15.1 KB
Binary file not shown.

imgs/cifar/train_loss_20.pdf

15 KB
Binary file not shown.

imgs/cifar/train_loss_5.pdf

15.1 KB
Binary file not shown.

imgs/cifar/train_loss_50.pdf

15 KB
Binary file not shown.

imgs/test_acc_10.png

54.7 KB

imgs/test_acc_100.png

29 KB

imgs/test_acc_20.png

16.1 KB

imgs/test_acc_30.png

8.45 KB

imgs/test_acc_5.png

11.9 KB

imgs/test_acc_50.png

20.1 KB

imgs/test_acc_80.png

32.7 KB

imgs/test_loss_10.png

60.9 KB

imgs/test_loss_100.png

33.1 KB

imgs/test_loss_20.png

9.81 KB

imgs/test_loss_30.png

10.9 KB

imgs/test_loss_5.png

-2.52 KB

imgs/test_loss_50.png

19.8 KB

imgs/test_loss_80.png

35.5 KB

imgs/train_acc_10.png

34.2 KB

imgs/train_acc_100.png

12.7 KB

imgs/train_acc_20.png

3.15 KB

imgs/train_acc_30.png

3.5 KB

imgs/train_acc_5.png

849 Bytes

imgs/train_acc_50.png

10.5 KB

imgs/train_acc_80.png

17.5 KB

imgs/train_loss_10.png

37.2 KB

imgs/train_loss_100.png

14.7 KB

imgs/train_loss_20.png

4.72 KB

imgs/train_loss_30.png

4.67 KB

imgs/train_loss_5.png

3.02 KB

imgs/train_loss_50.png

12.2 KB

imgs/train_loss_80.png

18.2 KB
