
Commit 4fb5ab0

Update run config

1 parent 1284d62 commit 4fb5ab0

19 files changed: +194 −46 lines changed

config.py

Lines changed: 0 additions & 5 deletions

@@ -1,7 +1,2 @@
 LABELS = ['drugname', 'other']
 labels_weight = [0.99, 0.01]
-
-# Image Encoder Config
-image_path = 'data/pills/'
-depth = 3
-size = 224

data/README.md

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+VAIPE-PP_01
+(Pill images corresponding to the prescription taken over one day)
+- Number of prescriptions: 1527
+- Number of patients: ??? patients

data/data.py

Lines changed: 13 additions & 9 deletions

@@ -15,13 +15,15 @@
 
 
 class PrescriptionPillData(Dataset):
-    def __init__(self, json_files, mode, sentences_tokenizer):
+    def __init__(self, json_files, mode, args):
+        self.args = args
         self.text_sentences_tokenizer = AutoTokenizer.from_pretrained(
-            sentences_tokenizer)
+            args.text_model_name)
         self.json_files = json_files
         self.mode = mode
-        self.transforms = get_transforms(self.mode)
-        self.all_pill_labels = get_all_pill_label("data/all_imgs/train")
+        self.transforms = get_transforms(self.mode, args.image_size)
+        self.all_pill_labels = get_all_pill_label(
+            args.data_folder + "all_imgs/train")
 
     def create_graph(self, bboxes, imgw, imgh, pills_class):
         G = nx.Graph()

@@ -128,7 +130,8 @@ def __getitem__(self, idx):
         # FOR IMAGE PILLS
         pills_image_folder_name = self.json_files[idx].split(
             "/")[-1].split(".")[0]
-        pills_image_path = CFG.image_path + self.mode + "/" + pills_image_folder_name
+        pills_image_path = self.args.data_folder + self.args.image_path + \
+            self.mode + "/" + pills_image_folder_name
         pills_image_folder = torchvision.datasets.ImageFolder(
             pills_image_path, transform=self.transforms)
         pills_class_to_idx = pills_image_folder.class_to_idx

@@ -164,19 +167,20 @@ def __getitem__(self, idx):
 
         data.pills_images = pills_images[0]
         data.pills_images_labels = torch.Tensor(pills_images_labels)
-        data.pills_images_labels_idx = torch.ones_like(data.pills_images_labels, dtype=int) * idx
+        data.pills_images_labels_idx = torch.ones_like(
+            data.pills_images_labels, dtype=int) * idx
         return data
 
 
-def get_transforms(mode="train"):
+def get_transforms(mode="train", size=224):
     if mode == "train":
-        transform = transforms.Compose([transforms.Resize((CFG.size, CFG.size)),
+        transform = transforms.Compose([transforms.Resize((size, size)),
                                         transforms.RandomRotation(10),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
     else:
-        transform = transforms.Compose([transforms.Resize((CFG.size, CFG.size)),
+        transform = transforms.Compose([transforms.Resize((size, size)),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
     return transform
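For reference, a minimal sketch of the reworked transform pipeline in isolation, assuming torchvision and Pillow are installed; the name eval_transform is hypothetical, the 224 stands in for args.image_size, and the dummy image is made up:

from PIL import Image
from torchvision import transforms

size = 224  # would normally come from args.image_size
eval_transform = transforms.Compose([transforms.Resize((size, size)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# A dummy RGB image of any resolution comes out as a (3, size, size) tensor.
img = Image.new("RGB", (640, 480))
print(eval_transform(img).shape)  # torch.Size([3, 224, 224])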
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train.py --run-name="PPMatching-Graph-Resnet18-Pretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=True --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train_text_img.py --run-name="PPMatching-NonGraph-Resnet18-NonPretrain-SBertMulti-NS-0.2" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1 --negative-ratio=0.2

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train_text_img.py --run-name="PPMatching-NonGraph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train_text_img.py --run-name="PPMatching-NonGraph-Resnet18-Pretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=True --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train.py --run-name="PPMatching-Graph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="TripletLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train_text_img.py --run-name="PPMatching-NonGraph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="TripletLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train.py --run-name="PPMatching-Graph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1 --data-folder="/mnt/disk2/thanhnt/ThanhNT-Data/2904_VAIPE-Matching/VAIPE-PP_02/data/"

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train_text_img.py --run-name="PPMatching-NonGraph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1 --data-folder="/mnt/disk2/thanhnt/ThanhNT-Data/2904_VAIPE-Matching/VAIPE-PP_02/data/"

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train.py --run-name="PPMatching-Graph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1 --data-folder="/mnt/disk2/thanhnt/ThanhNT-Data/2904_VAIPE-Matching/VAIPE-PP_03/data/"

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+python3 -u train_text_img.py --run-name="PPMatching-NonGraph-Resnet18-NonPretrain-SBertMulti" --image-model-name="resnet18" --image-embedding=512 --image-trainable=True --image-pretrained=False --matching-criterion="ContrastiveLoss" --text-model-name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" --text-embedding=384 --train-batch-size=4 --val-batch-size=1 --data-folder="/mnt/disk2/thanhnt/ThanhNT-Data/2904_VAIPE-Matching/VAIPE-PP_03/data/"

train.py

Lines changed: 8 additions & 7 deletions

@@ -94,16 +94,16 @@ def main(args):
     torch.cuda.manual_seed_all(args.seed)
 
     print(">>>> Preparing data...")
-    train_files = glob.glob(args.train_folder + "*.json")
-    val_files = glob.glob(args.val_folder + "*.json")
+    train_files = glob.glob(args.data_folder + args.train_folder + "*.json")
+    val_files = glob.glob(args.data_folder + args.val_folder + "*.json")
 
     train_loader = build_loaders(
-        train_files, mode="train", batch_size=args.train_batch_size, sentences_tokenizer=args.text_model_name)
+        train_files, mode="train", batch_size=args.train_batch_size, args=args)
     train_val_loader = build_loaders(
-        train_files, mode="test", batch_size=args.val_batch_size, sentences_tokenizer=args.text_model_name)
+        train_files, mode="train", batch_size=args.val_batch_size, args=args)
 
     val_loader = build_loaders(
-        val_files, mode="test", batch_size=args.val_batch_size, sentences_tokenizer=args.text_model_name)
+        val_files, mode="test", batch_size=args.val_batch_size, args=args)
 
     # Print data information
     print("Train files: ", len(train_files))

@@ -137,7 +137,8 @@ def main(args):
         val_acc = val(model, val_loader)
         print("Val accuracy: ", val_acc)
 
-        wandb.log({"train_loss": train_loss,"train_acc": train_val_acc, "val_acc": val_acc})
+        wandb.log({"train_loss": train_loss,
+                   "train_acc": train_val_acc, "val_acc": val_acc})
 
         # if val_acc > best_accuracy:
         #     best_accuracy = val_acc

@@ -148,7 +149,7 @@
 if __name__ == '__main__':
     parse_args = option()
 
-    wandb.init(entity="aiotlab", project="VAIPE-Pills-Prescription-Matching", group="Graph", name=parse_args.run_name,  # mode="disabled",
+    wandb.init(entity="aiotlab", project="VAIPE-Pills-Prescription-Matching", group="Graph-PP_02", name=parse_args.run_name,  # mode="disabled",
                config={
                    "train_batch_size": parse_args.train_batch_size,
                    "val_batch_size": parse_args.val_batch_size,

train_text_img.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+import glob
+import torch
+from tqdm import tqdm
+from models.image_text import ImageTextMatching
+from utils.metrics import ContrastiveLoss, TripletLoss
+import wandb
+from utils.utils import build_loaders, calculate_matching_loss
+from utils.option import option
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def train(model, train_loader, optimizer, matching_criterion, epoch, negative_ratio=None):
+    model.train()
+    train_loss = []
+    wandb.watch(model)
+    with tqdm(train_loader, desc=f"Train Epoch {epoch}") as train_bar:
+        for data in train_bar:
+            data = data.cuda()
+            optimizer.zero_grad()
+            pre_loss = []
+
+            image_features, sentences_features = model(data)
+
+            loss = calculate_matching_loss(
+                image_features, sentences_features, data.pills_label, data.pills_images_labels, matching_criterion, negative_ratio)
+
+            loss.backward()
+            optimizer.step()
+            pre_loss.append(loss.item())
+
+            train_loss.append(sum(pre_loss) / len(pre_loss))
+            train_bar.set_postfix(loss=train_loss[-1])
+
+    return sum(train_loss) / len(train_loss)
+
+
+def val(model, val_loader):
+    model.eval()
+    matching_acc = []
+    with torch.no_grad():
+        for data in tqdm(val_loader, desc="Validation"):
+            data = data.cuda()
+            correct = []
+            image_features, sentences_features = model(data)
+            # For Matching
+            similarity = image_features @ sentences_features.t()
+            _, predicted = torch.max(similarity, 1)
+            mapping_predicted = data.pills_label[predicted]
+
+            correct.append(mapping_predicted.eq(
+                data.pills_images_labels).sum().item() / len(data.pills_images_labels))
+
+            matching_acc.append(sum(correct) / len(correct))
+
+    final_accuracy = sum(matching_acc) / len(matching_acc)
+
+    return final_accuracy
+
+
+def main(args):
+    print("CUDA status: ", args.cuda)
+    torch.cuda.manual_seed_all(args.seed)
+
+    print(">>>> Preparing data...")
+    train_files = glob.glob(args.data_folder + args.train_folder + "*.json")
+    val_files = glob.glob(args.data_folder + args.val_folder + "*.json")
+
+    train_loader = build_loaders(
+        train_files, mode="train", batch_size=args.train_batch_size, args=args)
+
+    train_val_loader = build_loaders(
+        train_files, mode="train", batch_size=args.val_batch_size, args=args)
+
+    val_loader = build_loaders(
+        val_files, mode="test", batch_size=args.val_batch_size, args=args)
+
+    # Print data information
+    print("Train files: ", len(train_files))
+    print("Val files: ", len(val_files))
+
+    print(">>>> Preparing model...")
+    model = ImageTextMatching(args).cuda()
+
+    print(">>>> Preparing optimizer...")
+    if args.matching_criterion == "ContrastiveLoss":
+        matching_criterion = ContrastiveLoss()
+    elif args.matching_criterion == "TripletLoss":
+        matching_criterion = TripletLoss()
+
+    # Define optimizer
+    optimizer = torch.optim.AdamW(
+        model.parameters(), lr=args.lr, weight_decay=5e-4)
+
+    best_accuracy = 0
+    print(">>>> Training...")
+    for epoch in range(1, args.epochs + 1):
+        train_loss = train(model, train_loader, optimizer,
+                           matching_criterion, epoch, negative_ratio=args.negative_ratio)
+        print(">>>> Train Validation...")
+        # break
+        train_val_acc = val(model, train_val_loader)
+        print("Train accuracy: ", train_val_acc)
+        print(">>>> Test Validation...")
+        val_acc = val(model, val_loader)
+        print("Val accuracy: ", val_acc)
+
+        wandb.log({"train_loss": train_loss,
+                   "train_acc": train_val_acc, "val_acc": val_acc})
+        # if val_acc > best_accuracy:
+        #     best_accuracy = val_acc
+        #     print(">>>> Saving model...")
+        #     torch.save(model.state_dict(), args.save_folder + "best_model.pth")
+
+
+if __name__ == '__main__':
+    parse_args = option()
+
+    wandb.init(entity="aiotlab", project="VAIPE-Pills-Prescription-Matching", group="Text-Matching", name=parse_args.run_name,  # mode="disabled",
+               config={
+                   "train_batch_size": parse_args.train_batch_size,
+                   "val_batch_size": parse_args.val_batch_size,
+                   "epochs": parse_args.epochs,
+                   "lr": parse_args.lr,
+                   "seed": parse_args.seed
+               })
+
+    args = wandb.config
+    wandb.define_metric("val_acc", summary="max")
+    main(parse_args)
+    wandb.finish()
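The matching accuracy in val() boils down to an argmax over an image-text similarity matrix followed by a label lookup; a self-contained CPU sketch with made-up feature tensors and labels (shapes, names, and values are illustrative only):

import torch

# 3 pill crops vs. 4 prescription text lines, 8-dim features (made-up numbers)
image_features = torch.randn(3, 8)
sentences_features = torch.randn(4, 8)
pills_label = torch.tensor([5, 7, 2, 9])        # drug-name label of each text line
pills_images_labels = torch.tensor([7, 2, 2])   # ground-truth label of each pill crop

similarity = image_features @ sentences_features.t()   # (3, 4) score matrix
_, predicted = torch.max(similarity, 1)                 # best text line per crop
mapping_predicted = pills_label[predicted]              # line index -> drug label

accuracy = mapping_predicted.eq(pills_images_labels).sum().item() / len(pills_images_labels)
print(accuracy)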

utils/metrics.py

Lines changed: 3 additions & 3 deletions

@@ -27,7 +27,7 @@ def reset(self):
 
 
 class TripletLoss(nn.Module):
-    def __init__(self, margin=2.0):
+    def __init__(self, margin=1.0):
         super(TripletLoss, self).__init__()
         self.margin = margin
         self.cos = nn.CosineSimilarity(dim=-1, eps=1e-6)

@@ -39,8 +39,8 @@ def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.
         distance_positive = self.calc_cosinsimilarity(anchor, positive)
         distance_negative = self.calc_cosinsimilarity(anchor, negative)
 
-        losses = torch.relu(- distance_positive +
-                            distance_negative + self.margin)
+        losses = torch.relu(- torch.mean(distance_positive) +
+                            torch.mean(distance_negative) + self.margin)
 
         return losses.mean()
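The margin drops from 2.0 to 1.0 and the hinge now acts on the means of the positive and negative cosine similarities rather than element-wise; a small numeric sketch of the new expression, using made-up similarity values:

import torch

margin = 1.0
# cosine similarities of one anchor to its positive / negative candidates (made up)
distance_positive = torch.tensor([0.9, 0.7])        # mean = 0.8
distance_negative = torch.tensor([0.2, 0.4, 0.1])   # mean ≈ 0.2333

loss = torch.relu(- torch.mean(distance_positive) +
                  torch.mean(distance_negative) + margin)
print(loss)  # relu(-0.8 + 0.2333 + 1.0) ≈ 0.4333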

utils/option.py

Lines changed: 6 additions & 4 deletions

@@ -16,14 +16,15 @@ def option():
     parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
 
     # Data Init
+    parser.add_argument('--data-folder', type=str, default='data/')
     parser.add_argument('--train-folder', type=str,
-                        default="data/prescriptions/train/")
+                        default="prescriptions/train/")
     parser.add_argument('--val-folder', type=str,
-                        default="data/prescriptions/test/")
+                        default="prescriptions/test/")
 
-    parser.add_argument('--image-path', type=str, default="data/pills/")
+    parser.add_argument('--image-path', type=str, default="pills/")
     parser.add_argument('--depth', type=int, default=3)
-    parser.add_argument('--size', type=int, default=224)
+    parser.add_argument('--image-size', type=int, default=224)
 
     parser.add_argument('--train-batch-size', type=int, default=1)
     parser.add_argument('--val-batch-size', type=int, default=1)

@@ -59,6 +60,7 @@ def option():
     # matching
     parser.add_argument('--matching-criterion', type=str,
                         default="ContrastiveLoss")
+    parser.add_argument('--negative-ratio', type=float, default=None)
 
     # Model Save
     parser.add_argument('--save-model', type=bool, default=False)

utils/utils.py

Lines changed: 19 additions & 18 deletions

@@ -4,8 +4,8 @@
 import math
 
 
-def build_loaders(files, mode="train", batch_size=1, sentences_tokenizer="sentence-transformers/paraphrase-mpnet-base-v2"):
-    dataset = PrescriptionPillData(files, mode, sentences_tokenizer)
+def build_loaders(files, mode="train", batch_size=1, args=None):
+    dataset = PrescriptionPillData(files, mode, args)
     dataloader = DataLoader(
         dataset,
         batch_size=batch_size,

@@ -15,23 +15,29 @@ def build_loaders(files, mode="train", batch_size=1, sentences_tokenizer="senten
 
 
 def creat_batch_triplet(image_aggregation, text_embedding_drugname, text_embedding_labels, pills_images_labels):
-    anchor, positive, negative = torch.tensor([]).cuda(), torch.tensor([]).cuda(), torch.tensor([]).cuda()
+    anchor, positive, negative = torch.tensor([]).cuda(
+    ), torch.tensor([]).cuda(), torch.tensor([]).cuda()
 
     for idx, label in enumerate(pills_images_labels):
         positive_idx = text_embedding_labels.eq(label)
         negative_idx = text_embedding_labels.ne(label)
 
-        anchor = torch.cat((anchor, image_aggregation[idx].unsqueeze(0).unsqueeze(0)))
-        positive = torch.cat((positive, text_embedding_drugname[positive_idx].unsqueeze(0)))
+        anchor = torch.cat(
+            (anchor, image_aggregation[idx].unsqueeze(0).unsqueeze(0)))
+        positive = torch.cat(
+            (positive, text_embedding_drugname[positive_idx].unsqueeze(0)))
 
         if sum(negative_idx) == 0:
-            negative = torch.cat((negative, torch.zeros_like(image_aggregation[idx]).unsqueeze(0).unsqueeze(0)))
+            negative = torch.cat((negative, torch.zeros_like(
+                image_aggregation[idx]).unsqueeze(0).unsqueeze(0)))
         else:
-            negative = torch.cat((negative, text_embedding_drugname[negative_idx].unsqueeze(0)))
+            negative = torch.cat(
+                (negative, text_embedding_drugname[negative_idx].unsqueeze(0)))
 
     return anchor, positive, negative
 
-def calculate_matching_loss(image_aggregation, text_embedding_drugname, text_embedding_labels, pills_images_labels, matching_criterion):
+
+def calculate_matching_loss(image_aggregation, text_embedding_drugname, text_embedding_labels, pills_images_labels, matching_criterion, negative_ratio=None):
 
     loss = []
     for idx, label in enumerate(pills_images_labels):

@@ -42,12 +48,16 @@ def calculate_matching_loss(image_aggregation, text_embedding_drugname, text_emb
         positive = text_embedding_drugname[positive_idx]
         negative = text_embedding_drugname[negative_idx]
 
+        if negative_ratio is not None:
+            # get random negative samples
+            negative = negative[torch.randperm(
+                len(negative))[:math.ceil(len(negative) * negative_ratio)]]
+
         loss.append(matching_criterion(anchor, positive, negative))
 
     return torch.mean(torch.stack(loss))
 
 
-
 def creat_batch_triplet_random(image_features, text_embedding_drugname, text_embedding_labels, pills_images_labels, ratio=0.2):
     anchor, positive, negative = torch.tensor([]).cuda(
     ), torch.tensor([]).cuda(), torch.tensor([]).cuda()

@@ -75,12 +85,3 @@ def creat_batch_triplet_random(image_features, text_embedding_drugname, text_emb
 
     # print(anchor.shape, positive.shape, negative.shape)
     return anchor, positive, negative
-
-
-# create tensor size [10, 256]
-# image_features = torch.randn(5, 256).cuda()
-# text_embedding_drugname = torch.randn(4, 256).cuda()
-# text_embedding_labels = torch.tensor([1, 1, 0, -2, -1]).cuda()
-# pills_images_labels = torch.tensor([1, 0, 1, 0, 1]).cuda()
-# creat_batch_triplet_random(image_features, text_embedding_drugname,
-#                            text_embedding_labels, pills_images_labels, 0.2)
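The new negative_ratio path keeps only a random fraction of the negative text embeddings before the loss is computed; a minimal CPU sketch of just that sampling step (the 10×4 tensor and the 0.2 ratio are illustrative):

import math
import torch

negative = torch.randn(10, 4)   # all negative text embeddings for one pill image
negative_ratio = 0.2

if negative_ratio is not None:
    # keep a random ceil(10 * 0.2) = 2 of the 10 rows
    keep = torch.randperm(len(negative))[:math.ceil(len(negative) * negative_ratio)]
    negative = negative[keep]

print(negative.shape)  # torch.Size([2, 4])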
