@@ -66,100 +66,28 @@ def getNN_indices(self, embedding, label):
        return nn_indices, nn_label, nn_indices_same_cls

-    def vis_pairs(self, wrong_indices, confuse_indices, wrong_samecls_indices,
-                  dl, base_dir='Grad_Test'):
-        '''Visualize all confusion pairs'''
-        assert len(wrong_indices) == len(confuse_indices)
-        assert len(wrong_indices) == len(wrong_samecls_indices)
-
-        os.makedirs('./{}/{}'.format(base_dir, self.dataset_name), exist_ok=True)
-        model_copy = self._load_model()
-        model_copy.eval()
-
-        for ind1, ind2, ind3 in zip(wrong_indices, confuse_indices, wrong_samecls_indices):
-            # cam_extractor1 = GradCAMCustomize(model_copy, target_layer=model_copy.module[0].base.layer4)  # to last layer
-            # cam_extractor2 = GradCAMCustomize(model_copy, target_layer=model_copy.module[0].base.layer4)  # to last layer
-
-            # Get the two embeddings first
-            img1 = to_pil_image(read_image(dl.dataset.im_paths[ind1]))
-            img2 = to_pil_image(read_image(dl.dataset.im_paths[ind2]))
-            img3 = to_pil_image(read_image(dl.dataset.im_paths[ind3]))
-
-            # cam_extractor1._hooks_enabled = True
-            # model_copy.zero_grad()
-            # emb1 = model_copy(dl.dataset.__getitem__(ind1)[0].unsqueeze(0).cuda())
-            # emb2 = model_copy(dl.dataset.__getitem__(ind2)[0].unsqueeze(0).cuda())
-            # activation_map2 = cam_extractor1(torch.dot(emb1.detach().squeeze(0), emb2.squeeze(0)))
-            # result2, _ = overlay_mask(img2, to_pil_image(activation_map2[0].detach().cpu(), mode='F'), alpha=0.5)
-            #
-            # cam_extractor2._hooks_enabled = True
-            # model_copy.zero_grad()
-            # emb1 = model_copy(dl.dataset.__getitem__(ind1)[0].unsqueeze(0).cuda())
-            # emb3 = model_copy(dl.dataset.__getitem__(ind3)[0].unsqueeze(0).cuda())
-            # activation_map3 = cam_extractor2(torch.dot(emb1.detach().squeeze(0), emb3.squeeze(0)))
-            # result3, _ = overlay_mask(img3, to_pil_image(activation_map3[0].detach().cpu(), mode='F'), alpha=0.5)
-
-            # Display it
-            fig = plt.figure()
-            fig.subplots_adjust(top=0.8)
-
-            ax = fig.add_subplot(1, 3, 1)
-            ax.imshow(img1)
-            ax.title.set_text('Ind = {} \n Class = {}'.format(ind1, dl.dataset.ys[ind1]))
-            plt.axis('off')
-
-            ax = fig.add_subplot(1, 3, 2)
-            ax.imshow(img2)
-            ax.title.set_text('Ind = {} \n Class = {}'.format(ind2, dl.dataset.ys[ind2]))
-            plt.axis('off')
-
-            ax = fig.add_subplot(1, 3, 3)
-            ax.imshow(img3)
-            ax.title.set_text('Ind = {} \n Class = {}'.format(ind3, dl.dataset.ys[ind3]))
-            plt.axis('off')
-
-            plt.savefig('./{}/{}/{}_{}.png'.format(base_dir, self.dataset_name,
-                                                   ind1, ind2))
-            plt.close()
-
-    def calc_relabel_dict(self, lookat_harmful, relabel_method,
-                          harmful_indices, helpful_indices, train_nn_indices, train_nn_indices_same_cls,
+    def calc_relabel_dict(self, lookat_harmful,
+                          harmful_indices, helpful_indices,
                           base_dir, pair_ind1, pair_ind2):

        assert isinstance(lookat_harmful, bool)
-        assert relabel_method in ['hard', 'soft_knn']
        if lookat_harmful:
            top_indices = harmful_indices  # top_harmful_indices = influence_values.argsort()[-50:]
        else:
            top_indices = helpful_indices
-        top_nn_indices = train_nn_indices[top_indices]
-        top_nn_samecls_indices = train_nn_indices_same_cls[top_indices]

-        if relabel_method == 'hard':  # relabel as its 1st NN
-            relabel_dict = {}
-            for kk in range(len(top_indices)):
-                if self.dl_tr.dataset.ys[top_nn_indices[kk]] != self.dl_tr.dataset.ys[top_nn_samecls_indices[kk]]:  # inconsistent label between global NN and same class NN
-                    relabel_dict[top_indices[kk]] = [self.dl_tr.dataset.ys[top_nn_samecls_indices[kk]],
-                                                     self.dl_tr.dataset.ys[top_nn_indices[kk]]]
-            with open('./{}/Allrelabeldict_{}_{}.pkl'.format(base_dir, pair_ind1, pair_ind2), 'wb') as handle:
-                pickle.dump(relabel_dict, handle)
+        relabel_dict = {}
+        unique_labels, unique_counts = torch.unique(self.train_label, return_counts=True)
+        median_shots_percls = unique_counts.median().item()
+        _, prob_relabel = kNN_label_pred(query_indices=top_indices, embeddings=self.train_embedding, labels=self.train_label,
+                                         nb_classes=self.dl_tr.dataset.nb_classes(), knn_k=median_shots_percls)

-        elif relabel_method == 'soft_knn':  # relabel by weighted kNN
-            relabel_dict = {}
-            unique_labels, unique_counts = torch.unique(self.train_label, return_counts=True)
-            median_shots_percls = unique_counts.median().item()
-            _, prob_relabel = kNN_label_pred(query_indices=top_indices, embeddings=self.train_embedding, labels=self.train_label,
-                                             nb_classes=self.dl_tr.dataset.nb_classes(), knn_k=median_shots_percls)
+        for kk in range(len(top_indices)):
+            relabel_dict[top_indices[kk]] = prob_relabel[kk].detach().cpu().numpy()

-            for kk in range(len(top_indices)):
-                relabel_dict[top_indices[kk]] = prob_relabel[kk].detach().cpu().numpy()
+        with open('./{}/Allrelabeldict_{}_{}_soft_knn.pkl'.format(base_dir, pair_ind1, pair_ind2), 'wb') as handle:
+            pickle.dump(relabel_dict, handle)

-            with open('./{}/Allrelabeldict_{}_{}_soft_knn.pkl'.format(base_dir, pair_ind1, pair_ind2), 'wb') as handle:
-                pickle.dump(relabel_dict, handle)
-
-
-        else:
-            raise NotImplemented

if __name__ == '__main__':
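Note on the hunk above: kNN_label_pred is imported elsewhere in the repository and its body is not part of this diff. As a rough guide to what the new soft-relabelling path relies on, here is a minimal sketch of a similarity-weighted kNN soft-label predictor matching the call-site keywords (query_indices, embeddings, labels, nb_classes, knn_k) and the returned (prediction, probability) pair; the helper name knn_soft_label_pred, the cosine similarity, and the softmax weighting are assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F

def knn_soft_label_pred(query_indices, embeddings, labels, nb_classes, knn_k):
    # Assumed stand-in for kNN_label_pred; labels are expected to be integer class ids (long tensor).
    device = embeddings.device
    query_indices = torch.as_tensor(query_indices, dtype=torch.long, device=device)
    embeddings = F.normalize(embeddings, dim=1)                   # cosine similarity via dot products
    sim = embeddings[query_indices] @ embeddings.t()              # (Q, N) query-to-train similarities
    sim[torch.arange(len(query_indices), device=device), query_indices] = float('-inf')  # exclude self-matches
    topk_sim, topk_idx = sim.topk(int(knn_k), dim=1)              # k nearest training neighbours per query
    votes = F.one_hot(labels[topk_idx], nb_classes).float()       # (Q, k, C) one-hot neighbour labels
    weights = torch.softmax(topk_sim, dim=1).unsqueeze(-1)        # similarity-weighted vote strength
    prob = (votes * weights).sum(dim=1)                           # (Q, C) soft labels, each row sums to 1
    return prob.argmax(dim=1), prob

Setting knn_k to the median number of samples per class, as calc_relabel_dict now does, means each query is voted on by roughly one class's worth of neighbours.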
@@ -199,7 +127,7 @@ def calc_relabel_dict(self, lookat_harmful, relabel_method,
    pair_ind1, pair_ind2 = int(pair_ind1), int(pair_ind2)
    if not os.path.exists('./{}/All_influence_{}_{}.npy'.format(base_dir, pair_ind1, pair_ind2)):
        # sanity check: # IS.viz_2sample(IS.dl_ev, pair_ind1, pair_ind2)
-        training_sample_by_influence, influence_values = IS.MC_estimate_forpairs(all_features, [pair_ind1], [pair_ind2])
+        training_sample_by_influence, influence_values = IS.MC_estimate_forpair(all_features, [pair_ind1], [pair_ind2])
        helpful_indices = np.where(influence_values < 0)[0]
        harmful_indices = np.where(influence_values > 0)[0]
        np.save('./{}/Allhelpful_indices_{}_{}'.format(base_dir, pair_ind1, pair_ind2), helpful_indices)
@@ -216,9 +144,8 @@ def calc_relabel_dict(self, lookat_harmful, relabel_method,
    assert len(IS.train_label) == len(train_nn_indices)

    '''Step 3: Save harmful indices as well as its neighboring indices'''
-    IS.calc_relabel_dict(lookat_harmful=lookat_harmful, relabel_method=relabel_method,
+    IS.calc_relabel_dict(lookat_harmful=lookat_harmful,
                         harmful_indices=harmful_indices, helpful_indices=helpful_indices,
-                         train_nn_indices=train_nn_indices, train_nn_indices_same_cls=train_nn_indices_same_cls,
                         base_dir=base_dir, pair_ind1=pair_ind1, pair_ind2=pair_ind2)
    exit()
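Usage sketch (not part of the commit): the relabel dictionary pickled by calc_relabel_dict maps a training-set index to a per-class probability vector. A downstream script could read it back as below; base_dir and the pair indices are placeholder example values, and collapsing to a hard label via argmax is just one possible way to consume the soft labels.

import pickle
import numpy as np

base_dir, pair_ind1, pair_ind2 = 'Influential_data', 0, 1   # illustrative values only

with open('./{}/Allrelabeldict_{}_{}_soft_knn.pkl'.format(base_dir, pair_ind1, pair_ind2), 'rb') as handle:
    relabel_dict = pickle.load(handle)                      # {train index: per-class probabilities}

for train_idx, prob in relabel_dict.items():
    hard_label = int(np.argmax(prob))                       # most likely relabel, if a hard label is needed
    print(train_idx, hard_label, float(prob.max()))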