
Commit 98015df: refactor
Parent commit: 0249fbc
13 files changed, +149 / -251 lines

Influence_function/EIF_utils.py
Lines changed: 4 additions & 37 deletions

@@ -6,37 +6,6 @@
 import torch.nn.functional as F
 import math

-# @torch.no_grad()
-# def calc_loss_train_relabel(model, dl, relabel_candidate, criterion, indices=None):
-#
-#     l_all = {}
-#     model.eval()
-#     for ct, (x, t, ind) in enumerate(dl):
-#         torch.cuda.empty_cache()
-#         if ind.item() in indices:
-#             y = relabel_candidate[ind.item()]
-#             x = x.expand(len(y), x.size()[1], x.size()[2], x.size()[3])
-#             m = model(x)
-#             l = criterion.debug(m, None, y) # (nb_classes, )
-#             l_all[ind.item()] = l.detach().cpu().numpy()
-#         pass
-#     l_final = []
-#     for ind in indices:
-#         l_final.append(l_all[ind])
-#     l_final = np.asarray(l_final)
-#     return l_final # (N, nb_classes)
-#
-# def loss_change_train_relabel(model, criterion, dl_tr, relabel_candidate, params_prev, params_cur, indices):
-#
-#     weight_orig = model.module[-1].weight.data # cache original parameters
-#     model.module[-1].weight.data = params_prev
-#     l_prev = calc_loss_train_relabel(model, dl_tr, relabel_candidate, criterion, indices) # (N, nb_classes)
-#
-#     model.module[-1].weight.data = params_cur
-#     l_cur = calc_loss_train_relabel(model, dl_tr, relabel_candidate, criterion, indices) # (N, nb_classes)
-#
-#     model.module[-1].weight.data = weight_orig # dont forget to revise the weights back to the original
-#     return l_prev, l_cur

 @torch.no_grad()
 def calc_loss_train(model, dl, criterion, indices=None):

@@ -45,7 +14,6 @@ def calc_loss_train(model, dl, criterion, indices=None):
     '''
     l = []
     model.eval()
-
     for ct, (x, t, _) in tqdm(enumerate(dl)):
         x, t = x.cuda(), t.cuda()
         m = model(x)

@@ -72,7 +40,7 @@ def loss_change_train(model, criterion, dl_tr, params_prev, params_cur):

 def calc_inter_dist_pair(feat_cls1, feat_cls2):
     '''
-    Calculate d(confusion pair)
+    Calculate d(p_c)
     '''
     feat_cls1 = F.normalize(feat_cls1, p=2, dim=-1) # L2 normalization
     feat_cls2 = F.normalize(feat_cls2, p=2, dim=-1)

@@ -86,15 +54,14 @@ def calc_inter_dist_pair(feat_cls1, feat_cls2):

 def grad_confusion_pair(model, all_features, wrong_indices, confusion_indices):
     '''
-    Calculate \partial d(confusion pair) / \partial theta
+    Calculate \triangle d(p_c) / \triangle theta
     '''
     cls_features = all_features[wrong_indices]
     confuse_cls_features = all_features[confusion_indices]

     model.zero_grad()
     model.eval()
-    cls_features = cls_features.cuda()
-    confuse_cls_features = confuse_cls_features.cuda()
+    cls_features, confuse_cls_features = cls_features.cuda(), confuse_cls_features.cuda()

     feature1 = model.module[-1](cls_features) # (N', 512)
     feature2 = model.module[-1](confuse_cls_features) # (N', 512)

@@ -110,7 +77,7 @@ def grad_confusion_pair(model, all_features, wrong_indices, confusion_indices):
 def grad_confusion(model, all_features, cls, confusion_classes,
                    pred, label, nn_indices):
     '''
-    Calculate \partial avg{d(confusion pair)} / \partial theta
+    Calculate \triangle Avg{d(p_c)} / \triangle theta
     '''
     pred = pred.detach().cpu().numpy()
     label = label.detach().cpu().numpy()
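Note on the renamed docstrings above: d(p_c) is the distance between a confused class pair measured on L2-normalized embeddings, and grad_confusion_pair differentiates it through the last block (model.module[-1]). The following is a minimal illustrative sketch of that idea, not the repository's exact implementation; the mean of (1 - cosine similarity) used as the reduction and the toy Linear head are assumptions.

import torch
import torch.nn.functional as F

def pairwise_confusion_distance(feat1, feat2):
    # L2-normalize both embedding sets, then measure how far apart the two
    # groups sit: 1 - cosine similarity, averaged over all cross pairs.
    feat1 = F.normalize(feat1, p=2, dim=-1)
    feat2 = F.normalize(feat2, p=2, dim=-1)
    sim = feat1 @ feat2.t()           # (N1, N2) cosine similarities
    return (1.0 - sim).mean()         # larger value = better separated classes

# Toy usage: gradient of the pair distance w.r.t. a linear embedding head,
# mirroring how grad_confusion_pair backpropagates through model.module[-1].
head = torch.nn.Linear(128, 512)      # hypothetical stand-in for the last block
x1, x2 = torch.randn(8, 128), torch.randn(8, 128)
d = pairwise_confusion_distance(head(x1), head(x2))
d.backward()
print(head.weight.grad.shape)         # torch.Size([512, 128])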

Influence_function/IF_utils.py
Lines changed: 0 additions & 1 deletion

@@ -59,7 +59,6 @@ def inverse_hessian_product(model, criterion, v, dl_tr,
        hv = hessian_vector_product(loss, params, cur_estimate) # get hvp
        # Inverse Hessian product Update: v + (I - Hessian_at_x) * cur_estimate
        cur_estimate = [_v + (1 - damping) * _h_e - _hv.detach().cpu() / scale for _v, _h_e, _hv in zip(v, cur_estimate, hv)]
-       pass

    inverse_hvp = [b.detach().cpu() / scale for b in cur_estimate] # "In the loop, we scale the Hessian down by scale, which means that the estimate of the inverse Hessian-vector product will be scaled up by scale. The last division corrects for this scaling."
    return inverse_hvp # I didn't divide it by number of recursions
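The loop this stray pass was removed from is the stochastic estimator described by the comments above: cur_estimate is repeatedly updated as v + (I - damping*I - H/scale) * cur_estimate, and the final division by scale undoes scaling the Hessian down. Below is a self-contained sketch of that recursion under assumptions (a loss_fn closure supplying a fresh mini-batch loss, all tensors on one device); it is not the repository's exact inverse_hessian_product.

import torch

def inverse_hvp_lissa(loss_fn, params, v, num_steps=100, damping=0.01, scale=25.0):
    # Sketch: h_0 = v;  h_t = v + (1 - damping) * h_{t-1} - (H @ h_{t-1}) / scale
    # The final division by `scale` corrects for scaling the Hessian down by `scale`.
    cur_estimate = [t.clone() for t in v]
    for _ in range(num_steps):
        loss = loss_fn()  # in practice, the loss of a fresh mini-batch each iteration
        grads = torch.autograd.grad(loss, params, create_graph=True)
        # Hessian-vector product H @ cur_estimate via double backprop
        hv = torch.autograd.grad(grads, params, grad_outputs=cur_estimate)
        cur_estimate = [_v + (1 - damping) * _h - _hv.detach() / scale
                        for _v, _h, _hv in zip(v, cur_estimate, hv)]
    return [h / scale for h in cur_estimate]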

Influence_function/influence_function.py

Lines changed: 105 additions & 49 deletions
Large diffs are not rendered by default.

Influence_function/sample_relabel.py
Lines changed: 13 additions & 86 deletions

@@ -66,100 +66,28 @@ def getNN_indices(self, embedding, label):

        return nn_indices, nn_label, nn_indices_same_cls

-    def vis_pairs(self, wrong_indices, confuse_indices, wrong_samecls_indices,
-                  dl, base_dir='Grad_Test'):
-        '''Visualize all confusion pairs'''
-        assert len(wrong_indices) == len(confuse_indices)
-        assert len(wrong_indices) == len(wrong_samecls_indices)
-
-        os.makedirs('./{}/{}'.format(base_dir, self.dataset_name), exist_ok=True)
-        model_copy = self._load_model()
-        model_copy.eval()
-
-        for ind1, ind2, ind3 in zip(wrong_indices, confuse_indices, wrong_samecls_indices):
-            # cam_extractor1 = GradCAMCustomize(model_copy, target_layer=model_copy.module[0].base.layer4) # to last layer
-            # cam_extractor2 = GradCAMCustomize(model_copy, target_layer=model_copy.module[0].base.layer4) # to last layer
-
-            # Get the two embeddings first
-            img1 = to_pil_image(read_image(dl.dataset.im_paths[ind1]))
-            img2 = to_pil_image(read_image(dl.dataset.im_paths[ind2]))
-            img3 = to_pil_image(read_image(dl.dataset.im_paths[ind3]))
-
-            # cam_extractor1._hooks_enabled = True
-            # model_copy.zero_grad()
-            # emb1 = model_copy(dl.dataset.__getitem__(ind1)[0].unsqueeze(0).cuda())
-            # emb2 = model_copy(dl.dataset.__getitem__(ind2)[0].unsqueeze(0).cuda())
-            # activation_map2 = cam_extractor1(torch.dot(emb1.detach().squeeze(0), emb2.squeeze(0)))
-            # result2, _ = overlay_mask(img2, to_pil_image(activation_map2[0].detach().cpu(), mode='F'), alpha=0.5)
-            #
-            # cam_extractor2._hooks_enabled = True
-            # model_copy.zero_grad()
-            # emb1 = model_copy(dl.dataset.__getitem__(ind1)[0].unsqueeze(0).cuda())
-            # emb3 = model_copy(dl.dataset.__getitem__(ind3)[0].unsqueeze(0).cuda())
-            # activation_map3 = cam_extractor2(torch.dot(emb1.detach().squeeze(0), emb3.squeeze(0)))
-            # result3, _ = overlay_mask(img3, to_pil_image(activation_map3[0].detach().cpu(), mode='F'), alpha=0.5)
-
-            # Display it
-            fig = plt.figure()
-            fig.subplots_adjust(top=0.8)
-
-            ax = fig.add_subplot(1, 3, 1)
-            ax.imshow(img1)
-            ax.title.set_text('Ind = {} \n Class = {}'.format(ind1, dl.dataset.ys[ind1]))
-            plt.axis('off')
-
-            ax = fig.add_subplot(1, 3, 2)
-            ax.imshow(img2)
-            ax.title.set_text('Ind = {} \n Class = {}'.format(ind2, dl.dataset.ys[ind2]))
-            plt.axis('off')
-
-            ax = fig.add_subplot(1, 3, 3)
-            ax.imshow(img3)
-            ax.title.set_text('Ind = {} \n Class = {}'.format(ind3, dl.dataset.ys[ind3]))
-            plt.axis('off')
-
-            plt.savefig('./{}/{}/{}_{}.png'.format(base_dir, self.dataset_name,
-                                                   ind1, ind2))
-            plt.close()
-
-    def calc_relabel_dict(self, lookat_harmful, relabel_method,
-                          harmful_indices, helpful_indices, train_nn_indices, train_nn_indices_same_cls,
+    def calc_relabel_dict(self, lookat_harmful,
+                          harmful_indices, helpful_indices,
                           base_dir, pair_ind1, pair_ind2):

        assert isinstance(lookat_harmful, bool)
-        assert relabel_method in ['hard', 'soft_knn']
        if lookat_harmful:
            top_indices = harmful_indices # top_harmful_indices = influence_values.argsort()[-50:]
        else:
            top_indices = helpful_indices
-        top_nn_indices = train_nn_indices[top_indices]
-        top_nn_samecls_indices = train_nn_indices_same_cls[top_indices]

-        if relabel_method == 'hard': # relabel as its 1st NN
-            relabel_dict = {}
-            for kk in range(len(top_indices)):
-                if self.dl_tr.dataset.ys[top_nn_indices[kk]] != self.dl_tr.dataset.ys[top_nn_samecls_indices[kk]]: # inconsistent label between global NN and same class NN
-                    relabel_dict[top_indices[kk]] = [self.dl_tr.dataset.ys[top_nn_samecls_indices[kk]],
-                                                     self.dl_tr.dataset.ys[top_nn_indices[kk]]]
-            with open('./{}/Allrelabeldict_{}_{}.pkl'.format(base_dir, pair_ind1, pair_ind2), 'wb') as handle:
-                pickle.dump(relabel_dict, handle)
+        relabel_dict = {}
+        unique_labels, unique_counts = torch.unique(self.train_label, return_counts=True)
+        median_shots_percls = unique_counts.median().item()
+        _, prob_relabel = kNN_label_pred(query_indices=top_indices, embeddings=self.train_embedding, labels=self.train_label,
+                                         nb_classes=self.dl_tr.dataset.nb_classes(), knn_k=median_shots_percls)

-        elif relabel_method == 'soft_knn': # relabel by weighted kNN
-            relabel_dict = {}
-            unique_labels, unique_counts = torch.unique(self.train_label, return_counts=True)
-            median_shots_percls = unique_counts.median().item()
-            _, prob_relabel = kNN_label_pred(query_indices=top_indices, embeddings=self.train_embedding, labels=self.train_label,
-                                             nb_classes=self.dl_tr.dataset.nb_classes(), knn_k=median_shots_percls)
+        for kk in range(len(top_indices)):
+            relabel_dict[top_indices[kk]] = prob_relabel[kk].detach().cpu().numpy()

-            for kk in range(len(top_indices)):
-                relabel_dict[top_indices[kk]] = prob_relabel[kk].detach().cpu().numpy()
+        with open('./{}/Allrelabeldict_{}_{}_soft_knn.pkl'.format(base_dir, pair_ind1, pair_ind2), 'wb') as handle:
+            pickle.dump(relabel_dict, handle)

-            with open('./{}/Allrelabeldict_{}_{}_soft_knn.pkl'.format(base_dir, pair_ind1, pair_ind2), 'wb') as handle:
-                pickle.dump(relabel_dict, handle)
-
-
-        else:
-            raise NotImplemented


if __name__ == '__main__':

@@ -199,7 +127,7 @@ def calc_relabel_dict(self, lookat_harmful, relabel_method,
    pair_ind1, pair_ind2 = int(pair_ind1), int(pair_ind2)
    if not os.path.exists('./{}/All_influence_{}_{}.npy'.format(base_dir, pair_ind1, pair_ind2)):
        # sanity check: # IS.viz_2sample(IS.dl_ev, pair_ind1, pair_ind2)
-        training_sample_by_influence, influence_values = IS.MC_estimate_forpairs(all_features, [pair_ind1], [pair_ind2])
+        training_sample_by_influence, influence_values = IS.MC_estimate_forpair(all_features, [pair_ind1], [pair_ind2])
        helpful_indices = np.where(influence_values < 0)[0]
        harmful_indices = np.where(influence_values > 0)[0]
        np.save('./{}/Allhelpful_indices_{}_{}'.format(base_dir, pair_ind1, pair_ind2), helpful_indices)

@@ -216,9 +144,8 @@ def calc_relabel_dict(self, lookat_harmful, relabel_method,
    assert len(IS.train_label) == len(train_nn_indices)

    '''Step 3: Save harmful indices as well as its neighboring indices'''
-    IS.calc_relabel_dict(lookat_harmful=lookat_harmful, relabel_method=relabel_method,
+    IS.calc_relabel_dict(lookat_harmful=lookat_harmful,
                         harmful_indices=harmful_indices, helpful_indices=helpful_indices,
-                         train_nn_indices=train_nn_indices, train_nn_indices_same_cls=train_nn_indices_same_cls,
                         base_dir=base_dir, pair_ind1=pair_ind1, pair_ind2=pair_ind2)
    exit()
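The surviving branch of calc_relabel_dict relabels each selected training sample with a soft class-probability vector predicted by its k nearest training neighbours, with k set to the median number of shots per class. kNN_label_pred itself is defined elsewhere in the repository, so the sketch below only shows one plausible weighted-kNN prediction of this kind; the cosine-similarity weighting and the self-exclusion step are assumptions.

import torch
import torch.nn.functional as F

def knn_soft_labels(query_indices, embeddings, labels, nb_classes, knn_k):
    # Soft class probabilities for each query from its k nearest neighbours
    # (cosine similarity on L2-normalized embeddings, the query itself excluded).
    # labels is assumed to be a 1-D LongTensor of class ids.
    emb = F.normalize(embeddings, p=2, dim=-1)
    query_indices = torch.as_tensor(query_indices, dtype=torch.long)
    sims = emb[query_indices] @ emb.t()                           # (Q, N)
    sims[torch.arange(len(query_indices)), query_indices] = -1.0  # drop self-matches
    topk_sim, topk_idx = sims.topk(knn_k, dim=-1)                 # (Q, k)
    weights = topk_sim.clamp(min=0).unsqueeze(-1)                 # (Q, k, 1)
    votes = F.one_hot(labels[topk_idx], nb_classes).float()       # (Q, k, C)
    prob = (weights * votes).sum(dim=1)                           # (Q, C)
    return prob / prob.sum(dim=-1, keepdim=True).clamp(min=1e-12)

# As in calc_relabel_dict, knn_k could be set to the median shots per class:
# knn_k = torch.unique(labels, return_counts=True)[1].median().item()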

evaluation/neighborhood.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

experiments/EIF_pair_confusion.py
Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@
            print('skip')
            continue
        # sanity check: IS.viz_2sample(IS.dl_ev, wrong_ind, confuse_ind)
-        mean_deltaL_deltaD = IS.MC_estimate_forpairs([wrong_ind, confuse_ind], num_thetas=1, steps=50)
+        mean_deltaL_deltaD = IS.MC_estimate_forpair([wrong_ind, confuse_ind], num_thetas=1, steps=50)

        influence_values = np.asarray(mean_deltaL_deltaD)
        training_sample_by_influence = influence_values.argsort() # ascending
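For context on the two lines above: MC_estimate_forpair returns one influence estimate per training sample, and argsort ranks them in ascending order; in these scripts a negative value is treated as helpful and a positive value as harmful. A toy illustration with made-up numbers:

import numpy as np

influence_values = np.asarray([0.3, -1.2, 0.0, 0.7, -0.4])  # toy EIF estimates, one per training sample
ranking = influence_values.argsort()                         # ascending: most helpful first here
helpful_indices = np.where(influence_values < 0)[0]          # sign convention used in these scripts
harmful_indices = np.where(influence_values > 0)[0]
print(ranking, helpful_indices, harmful_indices)             # [1 4 2 0 3] [1 4] [0 3]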

experiments/EIFvsIF_mislabel_evaluation.py
Lines changed: 17 additions & 13 deletions

@@ -20,18 +20,18 @@
# loss_type = 'SoftTriple_noisy_{}'.format(noisy_level); dataset_name = 'cars_noisy'; config_name = 'cars'; seed = 4
loss_type = 'SoftTriple_noisy_{}'.format(noisy_level); dataset_name = 'inshop_noisy'; config_name = 'inshop'; seed = 3

-'''============ Our Influence function =================='''
+'''============================================= Our Empirical Influence function =============================================================='''
IS = MCScalableIF(dataset_name, seed, loss_type, config_name, test_crop)
basedir = 'MislabelExp_Influential_data'
os.makedirs(basedir, exist_ok=True)

for num_thetas in [1, 2, 3]:

    '''Mislabelled data detection'''
-    if os.path.exists("{}/{}_{}_helpful_testcls{}_SIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level)):
-        helpful_indices = np.load("{}/{}_{}_helpful_testcls{}_SIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level))
-        harmful_indices = np.load("{}/{}_{}_harmful_testcls{}_SIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level))
-        influence_values = np.load("{}/{}_{}_influence_values_testcls{}_SIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level))
+    if os.path.exists("{}/{}_{}_helpful_testcls{}_EIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level)):
+        helpful_indices = np.load("{}/{}_{}_helpful_testcls{}_EIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level))
+        harmful_indices = np.load("{}/{}_{}_harmful_testcls{}_EIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level))
+        influence_values = np.load("{}/{}_{}_influence_values_testcls{}_EIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level))
    else:
        confusion_class_pairs = IS.get_confusion_class_pairs()

@@ -46,9 +46,9 @@

        helpful_indices = np.where(influence_values < 0)[0] # cache all helpful
        harmful_indices = np.where(influence_values > 0)[0] # cache all harmful
-        np.save("{}/{}_{}_helpful_testcls{}_SIF_theta{}_{}".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level), helpful_indices)
-        np.save("{}/{}_{}_harmful_testcls{}_SIF_theta{}_{}".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level), harmful_indices)
-        np.save("{}/{}_{}_influence_values_testcls{}_SIF_theta{}_{}".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level), influence_values)
+        np.save("{}/{}_{}_helpful_testcls{}_EIF_theta{}_{}".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level), helpful_indices)
+        np.save("{}/{}_{}_harmful_testcls{}_EIF_theta{}_{}".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level), harmful_indices)
+        np.save("{}/{}_{}_influence_values_testcls{}_EIF_theta{}_{}".format(basedir, IS.dataset_name, IS.loss_type, 0, num_thetas, noisy_level), influence_values)

    training_sample_by_influence = influence_values.argsort() # ascending, harmful first
    # mislabelled indices ground-truth

@@ -68,7 +68,7 @@
# TODO climbing plot
'''Weighted KNN'''
start_time = time.time()
-harmful_indices = np.load("{}/{}_{}_harmful_testcls{}_SIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, 1, noisy_level))
+harmful_indices = np.load("{}/{}_{}_harmful_testcls{}_EIF_theta{}_{}.npy".format(basedir, IS.dataset_name, IS.loss_type, 0, 1, noisy_level))
relabel_dict = {}
unique_labels, unique_counts = torch.unique(IS.train_label, return_counts=True)
median_shots_percls = unique_counts.median().item()

@@ -92,8 +92,9 @@
            ct_correct += 1
print(ct_correct, total_ct)

+'''======================================================================================================================================='''

-'''============ Original Influence function =================='''
+'''======================================== Original Influence function =========================================================================='''
IS = OrigIF(dataset_name, seed, loss_type, config_name, test_crop)
basedir = 'MislabelExp_Influential_data'
os.makedirs(basedir, exist_ok=True)

@@ -138,7 +139,9 @@

plt.plot(cum_overlap, label='IF')

-'''Relabelled data accuracy (only relabel harmful)'''
+'''======================================================================================================================================='''
+
+'''=============================================Random================================================================'''
overlap = np.isin(np.arange(len(IS.dl_tr.dataset)), gt_mislabelled_indices)
cum_overlap = np.cumsum(overlap)

@@ -147,6 +150,7 @@
plt.tight_layout()
plt.savefig('./images/mislabel_{}_{}_alltheta_noisylevel{}.pdf'.format(dataset_name, loss_type, noisy_level),
            bbox_inches='tight')
-# plt.savefig('./images/mislabel_{}_{}_alltheta_noisylevel{}.png'.format(dataset_name, loss_type, noisy_level),
-#             bbox_inches='tight')
+
+'''======================================================================================================================================='''
+

experiments/IF_group_confusion.py
Lines changed: 0 additions & 4 deletions

@@ -46,11 +46,9 @@
                helpful_indices)
        np.save("Influential_data_baselines/{}_{}_harmful_testcls{}".format(IS.dataset_name, IS.loss_type, pair_idx),
                harmful_indices)
-        exit()

        '''Actually train with downweighted harmful and upweighted helpful training'''
        os.system("./scripts/run_{}_IF_{}.sh".format(dataset_name, loss_type))
-        exit()

        '''Other: get confusion (before VS after)'''
        IS.model = IS._load_model() # reload the original weights

@@ -75,5 +73,3 @@
        inter_dist_after, _ = grad_confusion(IS.model, features, wrong_cls, confuse_classes,
                                             IS.testing_nn_label, IS.testing_label, IS.testing_nn_indices)
        print("After d(G_p): ", inter_dist_after)
-
-        exit()
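The removed exit() calls previously stopped this script before the reweighted retraining and the before/after d(G_p) comparison could run. The retraining itself happens inside run_{dataset}_IF_{loss}.sh, which is not part of this diff, so the snippet below is only a hedged sketch of a per-sample reweighting of the kind the comment describes (down-weight harmful, up-weight helpful); the function name, weight values, and set-based lookup are illustrative assumptions.

import torch

def reweighted_loss(per_sample_loss, sample_indices, helpful_indices, harmful_indices,
                    up_weight=2.0, down_weight=0.5):
    # per_sample_loss: (B,) unreduced losses; sample_indices: (B,) dataset indices
    # helpful_indices / harmful_indices: Python sets of dataset indices
    weights = torch.ones_like(per_sample_loss)
    for i, idx in enumerate(sample_indices.tolist()):
        if idx in harmful_indices:
            weights[i] = down_weight
        elif idx in helpful_indices:
            weights[i] = up_weight
    return (weights * per_sample_loss).sum() / weights.sum()

# Example:
losses = torch.tensor([0.9, 0.2, 1.5, 0.4])
idxs = torch.tensor([10, 11, 12, 13])
print(reweighted_loss(losses, idxs, helpful_indices={11}, harmful_indices={12}))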
