Commit a7bc376

add files

1 parent a4357b2 commit a7bc376

File tree

3 files changed: +238 -1 lines changed

README.md

Lines changed: 10 additions & 1 deletion

# pytorch-skipthoughts

Simple, plug & play PyTorch implementation of a Skip-Thoughts encoder.
Ported from the [Theano-based Sent2Vec encoder](https://github.com/ryankiros/skip-thoughts), which is based on the paper [Skip-Thought Vectors](http://arxiv.org/abs/1506.06726).

## To do

* add support for vocabulary extension
* implement decoders
* add training script for custom model creation
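
A minimal usage sketch, mirroring the `__main__` block added in skipthoughts.py in this commit (it assumes the pretrained Skip-Thoughts model files are available in a local `models` directory):

```python
from skipthoughts import Encoder

# Directory containing the pretrained embedding tables, GRU weights and dictionary.
encoder = Encoder('models')

sentences = ["Hey, how are you?", "This sentence is a lie"]
vectors = encoder.encode(sentences)  # (2, 4800): 2400-d uni-skip + 2400-d bi-skip per sentence
print(vectors)
```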

requirements.txt

Lines changed: 3 additions & 0 deletions

numpy==1.18.5
nltk==3.5
torch==1.9.1
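
Beyond these packages, the tokenization in skipthoughts.py loads NLTK's punkt sentence tokenizer, so the punkt data has to be downloaded once; a minimal sketch:

```python
import nltk

# Fetches the 'tokenizers/punkt/english.pickle' model used by SkipThoughts.preprocess.
nltk.download('punkt')
```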

skipthoughts.py

Lines changed: 225 additions & 0 deletions

import sys
import logging
from collections import OrderedDict

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
from nltk.tokenize import word_tokenize  # needs the NLTK 'punkt' data package (nltk.download('punkt'))

logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s")
consoleHandler = logging.StreamHandler(sys.stdout)
consoleHandler.setFormatter(logFormatter)
logger = logging.getLogger()
logger.addHandler(consoleHandler)
logger.setLevel(logging.NOTSET)


class SkipThoughts(nn.Module):
    def __init__(self, dirStr: str, dictionary: dict, fixedEmb: bool = False, normalized: bool = True):
        super(SkipThoughts, self).__init__()

        self.dirStr = dirStr
        self.fixed_emb = fixedEmb
        self.normalized = normalized
        self.dictionary = dictionary

    def preprocess(self, x):
        # Sentence-split and word-tokenize each input string, then map every token
        # to its dictionary index in a zero-padded index tensor (index 0 = padding).
        X = []
        sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
        for t in x:
            sents = sent_detector.tokenize(t)
            result = ''
            for s in sents:
                tokens = word_tokenize(s)
                result += ' ' + ' '.join(tokens)
            X.append(result)

        # Out-of-vocabulary tokens raise a KeyError; vocabulary extension is still on the to-do list.
        wordIdx = [[self.dictionary[word] for word in s.split()] for s in X]

        tensorWordIdx = torch.zeros(len(wordIdx), max([len(i) for i in wordIdx]))  # needs numpy base for large batches
        for i in range(len(tensorWordIdx)):
            tensorWordIdx[i, :len(wordIdx[i])] = torch.tensor(wordIdx[i], dtype=torch.int64)

        return tensorWordIdx.long()

    def loadEmbedding(self, dictionary: dict, filePath: str):
        logging.info(f"Loading table: {filePath}")
        # 620-dimensional word embeddings; row 0 is the padding vector and stays zero.
        embedding = nn.Embedding(num_embeddings=len(self.dictionary) + 1,
                                 embedding_dim=620,
                                 padding_idx=0,
                                 sparse=False)

        parameters = numpy.load(filePath, encoding='latin1', allow_pickle=True)
        weights = torch.zeros(len(dictionary) + 1, 620)
        for i in range(len(weights) - 1):
            weights[i + 1] = torch.from_numpy(parameters[i])
        embedding.load_state_dict({'weight': weights})
        return embedding


class UniSkipThoughts(SkipThoughts):
    def __init__(self, dirStr: str, dictionary: dict, dropout: float = 0, fixedEmb: bool = False, normalized: bool = True):
        super(UniSkipThoughts, self).__init__(dirStr, dictionary, fixedEmb, normalized)
        self.dropout = dropout

        self.embedding = self.loadEmbedding(self.dictionary, dirStr + '/utable.npy')

        if fixedEmb:
            self.embedding.weight.requires_grad = False

        # Uni-skip encoder: single-layer GRU with a 2400-dimensional hidden state.
        self.gru = nn.GRU(input_size=620,
                          hidden_size=2400,
                          batch_first=True,
                          dropout=self.dropout)
        self.loadModel(dirStr + "/uni_skip.npz")

    def selectResult(self, x, lengths):
        # Pick the hidden state at the last real (non-padding) time step of each sequence.
        X = torch.zeros(x.size(0), 2400)
        for i in range(len(x)):
            X[i] = x[i][lengths[i] - 1]
        return X

    def loadModel(self, modelPath: str):
        logging.info(f"Loading model: {modelPath}")
        # Repack the Theano parameters (W/U/b and Wx/Ux/bx) into PyTorch's concatenated GRU gate layout.
        params = numpy.load(modelPath, encoding='latin1', allow_pickle=True)
        states = OrderedDict()
        states['bias_ih_l0'] = torch.zeros(7200)
        states['bias_hh_l0'] = torch.zeros(7200)
        states['weight_ih_l0'] = torch.zeros(7200, 620)
        states['weight_hh_l0'] = torch.zeros(7200, 2400)
        states['weight_ih_l0'][:4800] = torch.from_numpy(params['encoder_W']).t()
        states['weight_ih_l0'][4800:] = torch.from_numpy(params['encoder_Wx']).t()
        states['bias_ih_l0'][:4800] = torch.from_numpy(params['encoder_b'])
        states['bias_ih_l0'][4800:] = torch.from_numpy(params['encoder_bx'])
        states['weight_hh_l0'][:4800] = torch.from_numpy(params['encoder_U']).t()
        states['weight_hh_l0'][4800:] = torch.from_numpy(params['encoder_Ux']).t()
        self.gru.load_state_dict(states)

    def forward(self, input):
        input = self.preprocess(input)
        # Sequence lengths are read off the padded index tensor (non-zero entries) so that they
        # match the word_tokenize output; a raw whitespace split can disagree around punctuation.
        lengths = (input != 0).sum(dim=1)
        x = self.embedding(input)
        y, hn = self.gru(x)
        y = self.selectResult(y, lengths)
        if self.normalized:
            y = F.normalize(y)
        return y


class BiSkipThoughts(SkipThoughts):
    def __init__(self, dirStr: str, dictionary: dict, dropout: float = 0, fixedEmb: bool = False, normalized: bool = True):
        super(BiSkipThoughts, self).__init__(dirStr, dictionary, fixedEmb, normalized)
        self.dropout = dropout

        self.embedding = self.loadEmbedding(self.dictionary, dirStr + '/btable.npy')

        if fixedEmb:
            self.embedding.weight.requires_grad = False

        # Bi-skip encoder: single-layer bidirectional GRU with a 1200-dimensional
        # hidden state per direction (2400 after concatenation).
        self.gru = nn.GRU(input_size=620,
                          hidden_size=1200,
                          batch_first=True,
                          dropout=self.dropout,
                          bidirectional=True)

        self.loadModel(dirStr + "/bi_skip.npz")

    def loadModel(self, modelPath: str):
        logging.info(f"Loading model: {modelPath}")
        # Repack the Theano parameters into PyTorch's GRU state dict for both directions.
        params = numpy.load(modelPath, encoding='latin1', allow_pickle=True)
        states = OrderedDict()
        states['bias_ih_l0'] = torch.zeros(3600)
        states['bias_hh_l0'] = torch.zeros(3600)  # must stay equal to 0
        states['weight_ih_l0'] = torch.zeros(3600, 620)
        states['weight_hh_l0'] = torch.zeros(3600, 1200)

        states['bias_ih_l0_reverse'] = torch.zeros(3600)
        states['bias_hh_l0_reverse'] = torch.zeros(3600)  # must stay equal to 0
        states['weight_ih_l0_reverse'] = torch.zeros(3600, 620)
        states['weight_hh_l0_reverse'] = torch.zeros(3600, 1200)

        states['weight_ih_l0'][:2400] = torch.from_numpy(params['encoder_W']).t()
        states['weight_ih_l0'][2400:] = torch.from_numpy(params['encoder_Wx']).t()
        states['bias_ih_l0'][:2400] = torch.from_numpy(params['encoder_b'])
        states['bias_ih_l0'][2400:] = torch.from_numpy(params['encoder_bx'])
        states['weight_hh_l0'][:2400] = torch.from_numpy(params['encoder_U']).t()
        states['weight_hh_l0'][2400:] = torch.from_numpy(params['encoder_Ux']).t()

        states['weight_ih_l0_reverse'][:2400] = torch.from_numpy(params['encoder_r_W']).t()
        states['weight_ih_l0_reverse'][2400:] = torch.from_numpy(params['encoder_r_Wx']).t()
        states['bias_ih_l0_reverse'][:2400] = torch.from_numpy(params['encoder_r_b'])
        states['bias_ih_l0_reverse'][2400:] = torch.from_numpy(params['encoder_r_bx'])
        states['weight_hh_l0_reverse'][:2400] = torch.from_numpy(params['encoder_r_U']).t()
        states['weight_hh_l0_reverse'][2400:] = torch.from_numpy(params['encoder_r_Ux']).t()
        self.gru.load_state_dict(states)

    def forward(self, input):
        x = self.preprocess(input)
        # Sequence lengths (non-padding entries) are needed for packing; they must live on the CPU.
        lengths = (x != 0).sum(dim=1)

        x = self.embedding(x)
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        y, hn = self.gru(x)

        # hn has shape (2, batch, 1200); concatenate the final forward and backward states per sentence.
        hn = hn.transpose(0, 1).contiguous()
        hn = hn.view(len(input), 2 * hn.size(2))

        if self.normalized:
            hn = F.normalize(hn)

        return hn


class Encoder(object):
    def __init__(self, dirStr: str, dropout: float = 0, fixedEmb: bool = False, normalized: bool = True):
        self.dirStr = dirStr
        self.dropout = dropout
        self.fixedEmb = fixedEmb
        self.normalized = normalized
        self.dictionary = self.loadDictionary(dirStr)
        self.uniSkip = UniSkipThoughts(dirStr, self.dictionary, dropout, fixedEmb, normalized)
        self.biSkip = BiSkipThoughts(dirStr, self.dictionary, dropout, fixedEmb, normalized)

    def loadDictionary(self, dirStr: str):
        logging.info("Loading dictionary")
        # One word per line; indices start at 1 because 0 is reserved for padding.
        with open(dirStr + '/dictionary.txt', 'r', encoding="utf8") as file:
            words = file.readlines()

        dictionary = {}
        for idx, word in enumerate(words):
            dictionary[word.strip()] = idx + 1
        return dictionary

    def encode(self, input: list):
        # Concatenate the 2400-d uni-skip and 2400-d bi-skip features into one 4800-d vector per sentence.
        uFeatures = self.uniSkip(input)
        bFeatures = self.biSkip(input)
        return torch.cat([uFeatures, bFeatures], 1)


if __name__ == '__main__':
    dirStr = 'models'

    encoder = Encoder(dirStr)

    test = ["Hey, how are you?", "This sentence is a lie"]

    result = encoder.encode(test)

    print(result)
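
Running the `__main__` demo above requires the pretrained files referenced by the constructors to be present in the `models` directory; a quick sanity-check sketch (file names taken from the code in this commit):

```python
import os

# Files opened by Encoder, UniSkipThoughts and BiSkipThoughts.
required = ['dictionary.txt', 'utable.npy', 'btable.npy', 'uni_skip.npz', 'bi_skip.npz']
missing = [f for f in required if not os.path.isfile(os.path.join('models', f))]
print('missing model files:', missing or 'none')
```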
