Skip to content

Commit 7ab5b89

Browse files
committed
small fixes + formatting
1 parent 4bdb6b4 commit 7ab5b89

9 files changed

+42
-252
lines changed

__pycache__/evaluate.cpython-36.pyc

0 Bytes
Binary file not shown.

__pycache__/model_grc.cpython-36.pyc

-196 Bytes
Binary file not shown.

__pycache__/model_srnn.cpython-36.pyc

-210 Bytes
Binary file not shown.

evaluate.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import itertools
66
from sklearn.metrics import f1_score, accuracy_score
77

8+
89
def evaluate (net, x_data, y_data, seg_ind, batched_len_list, opt):
910
net.eval()
1011
batch_size_eval = opt.batch_size_eval
@@ -53,14 +54,11 @@ def evaluate (net, x_data, y_data, seg_ind, batched_len_list, opt):
5354

5455

5556
hidden = net.init_hidden(bs)
56-
5757
output_2d, hidden = net( x_batch_s, hidden, sorted_vals )
5858

5959
pack_y = torch.nn.utils.rnn.pack_padded_sequence(y_batch_s, sorted_vals)
6060
unpacked_y, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(pack_y)
6161

62-
63-
6462
sorted_vals = torch.LongTensor(sorted_vals)
6563
if opt.USE_CUDA == True:
6664
sorted_vals = sorted_vals.cuda()
@@ -83,12 +81,9 @@ def evaluate (net, x_data, y_data, seg_ind, batched_len_list, opt):
8381
all_labels_paths.extend( sent_batch_labels )
8482
all_seg_inds.extend(sent_batch_seg_inds)
8583

86-
87-
8884
print('Evaluating batch', i)
8985

9086
new_segments = []
91-
9287
for segment in all_segments:
9388
temp_seg = [(0,segment[0]-1)]
9489
for j, val in enumerate(segment[1:],1):
@@ -101,7 +96,7 @@ def evaluate (net, x_data, y_data, seg_ind, batched_len_list, opt):
10196
all_seg_inds,
10297
new_segments)
10398

104-
#Flatten to calculate f1 score on one run for word level
99+
# Flatten to calculate the word-level accuracy score in a single pass
105100
all_words_labels_flat = list(itertools.chain.from_iterable(all_words_labels))
106101
all_words_preds_flat = list(itertools.chain.from_iterable(all_words_preds))
107102

@@ -111,12 +106,10 @@ def evaluate (net, x_data, y_data, seg_ind, batched_len_list, opt):
111106
print ('F1 Tokenization', F1_tok)
112107
print ('Word level Accuracy: ' , word_acc)
113108

114-
# print ('Character level F1 score: ' , f1_char)
115109
return F1_pos_seg
116110

117111

118112
def convert_to_word(all_labels_paths, all_char_paths, seg_ind_s, segments_predicted):
119-
120113
word_2d_labels = []
121114
word_2d_preds = []
122115
count_correct_pos_seg = 0
@@ -133,7 +126,6 @@ def convert_to_word(all_labels_paths, all_char_paths, seg_ind_s, segments_predic
133126
start_ind = j+1
134127
idx_list.append(word_range)
135128

136-
137129
char_seg = all_char_paths[i]
138130
segments = [ char_seg[s:(e+1)] for s,e in idx_list]
139131
word_2d_preds.append( [ Counter(seg).most_common()[0][0] for seg in segments] )
@@ -160,9 +152,4 @@ def convert_to_word(all_labels_paths, all_char_paths, seg_ind_s, segments_predic
160152
pos_seg_recall = count_correct_pos_seg / total_clean_tokens
161153
F1_pos_seg = (2 * pos_seg_prec * pos_seg_recall) / (pos_seg_prec + pos_seg_recall)
162154

163-
164-
# print ('Tokenization recall' , token_recall)
165-
# print ('Tokenization precision' , token_prec)
166-
# print ('F1 Score Tokenization', F1_tok )
167-
# print ('F1 Score POS & Seg', F1_pos_seg)
168155
return F1_pos_seg, F1_tok, word_2d_labels, word_2d_preds

0 commit comments

Comments
 (0)