Skip to content

Commit 5b30618

Browse files
committed
added docstring and unit tests
1 parent d4bed09 commit 5b30618

File tree

5 files changed

+171
-10
lines changed

5 files changed

+171
-10
lines changed

shorttext/spell/basespellcorrector.py

+17
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,25 @@
22
import shorttext.utils.classification_exceptions as ce
33

44
class SpellCorrector:
5+
""" Base class for all spell corrector.
6+
7+
This class is not implemented; this can be seen as an "abstract class."
8+
9+
"""
510
def train(self, text):
11+
""" Train the spell corrector with the given corpus.
12+
13+
:param text: training corpus
14+
:type text: str
15+
"""
616
raise ce.NotImplementedException()
717

818
def correct(self, word):
19+
""" Recommend a spell correction to given the word.
20+
21+
:param word: word to be checked
22+
:return: recommended correction
23+
:type word: str
24+
:rtype: str
25+
"""
926
return word

shorttext/spell/binarize.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
default_specialsignals = {'eos': '#', 'unk': '_', 'number': '@'}
1212
default_signaldenotions = {'<eos>': 'eos', '<unk>': 'unk'}
1313

14-
## TODO: need to refine the array settings
1514

1615
class SpellingToConcatCharVecEncoder:
1716
def __init__(self, alph):
@@ -30,6 +29,11 @@ def hasnum(word):
3029

3130

3231
class SCRNNBinarizer:
32+
""" A class used by Sakaguchi's spell corrector to convert text into numerical vectors.
33+
34+
No documentation for this class.
35+
36+
"""
3337
def __init__(self, alpha, signalchar_dict):
3438
self.signalchar_dict = signalchar_dict
3539
self.concatchar_encoder = SpellingToConcatCharVecEncoder(alpha)

shorttext/spell/norvig.py

+39
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,62 @@
88
from .editor import compute_set_edits1, compute_set_edits2
99

1010
class NorvigSpellCorrector(SpellCorrector):
11+
""" Spell corrector described by Peter Norvig in his blog. (https://norvig.com/spell-correct.html)
12+
13+
"""
1114
def __init__(self):
15+
""" Instantiate the class
16+
17+
"""
1218
self.train('')
1319

1420
def train(self, text):
21+
""" Given the text, train the spell corrector.
22+
23+
:param text: training corpus
24+
:type text: str
25+
"""
1526
self.words = re.findall(r'\w+', text.lower())
1627
self.WORDS = Counter(self.words)
1728
self.N = sum(self.WORDS.values())
1829

1930
def P(self, word):
31+
""" Compute the probability of the words randomly sampled from the training corpus.
32+
33+
:param word: a word
34+
:return: probability of the word sampled randomly in the corpus
35+
:type word: str
36+
:rtype: float
37+
"""
2038
return self.WORDS[word] / float(self.N)
2139

2240
def correct(self, word):
41+
""" Recommend a spelling correction to the given word
42+
43+
:param word: a word
44+
:return: recommended correction
45+
:type word: str
46+
:rtype: str
47+
"""
2348
return max(self.candidates(word), key=self.P)
2449

2550
def known(self, words):
51+
""" Filter away the words that are not found in the training corpus.
52+
53+
:param words: list of words
54+
:return: set of words that can be found in the training corpus
55+
:type words: list
56+
:rtype: set
57+
"""
2658
return set(w for w in words if w in self.WORDS)
2759

2860
def candidates(self, word):
61+
""" List potential candidates for corrected spelling to the given words.
62+
63+
:param word: a word
64+
:return: list of recommended corrections
65+
:type word: str
66+
:rtype: set or list
67+
"""
2968
return (self.known([word]) or self.known(compute_set_edits1(word)) or self.known(compute_set_edits2(word)) or [word])
3069

shorttext/spell/sakaguchi.py

+68-9
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,34 @@
1818

1919

2020
class SCRNNSpellCorrector(SpellCorrector):
21+
""" scRNN (semi-character-level recurrent neural network) Spell Corrector.
22+
23+
Reference:
24+
Keisuke Sakaguchi, Kevin Duh, Matt Post, Benjamin Van Durme, "Robsut Wrod Reocginiton via semi-Character Recurrent Neural Network," arXiv:1608.02214 (2016). [`arXiv
25+
<https://arxiv.org/abs/1608.02214>`_]
26+
27+
"""
2128
def __init__(self, operation,
2229
alph=default_alph,
2330
specialsignals=default_specialsignals,
2431
concatcharvec_encoder=None,
2532
batchsize=1,
2633
nb_hiddenunits=650):
34+
""" Instantiate the scRNN spell corrector.
35+
36+
:param operation: types of distortion of words in training (options: "NOISE-INSERT", "NOISE-DELETE", "NOISE-REPLACE", "JUMBLE-WHOLE", "JUMBLE-BEG", "JUMBLE-END", and "JUMBLE-INT")
37+
:param alph: default string of characters (Default: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz.,:;'*!?`$%&(){}[]-/\@_#")
38+
:param specialsignals: dictionary of special signals (Default built-in)
39+
:param concatcharvec_encoder: one-hot encoder for characters, initialize if None. (Default: None)
40+
:param batchsize: batch size. (Default: 1)
41+
:param nb_hiddenunits: number of hidden units. (Default: 650)
42+
:type operation: str
43+
:type alph: str
44+
:type specialsignals: dict
45+
:type concatcharvec_encoder: shorttext.spell.binarize.SpellingToConcatCharVecEncoder
46+
:type batchsize: int
47+
:type nb_hiddenunits: int
48+
"""
2749
self.operation = operation
2850
self.binarizer = SCRNNBinarizer(alph, specialsignals)
2951
self.concatcharvec_encoder = SpellingToConcatCharVecEncoder(alph) if concatcharvec_encoder==None else concatcharvec_encoder
@@ -33,6 +55,13 @@ def __init__(self, operation,
3355
self.nb_hiddenunits = nb_hiddenunits
3456

3557
def preprocess_text_train(self, text):
58+
""" A generator that output numpy vectors for the text for training.
59+
60+
:param text: text
61+
:return: generator that outputs the numpy vectors for training
62+
:type text: str
63+
:rtype: generator
64+
"""
3665
for token in nospace_tokenize(text):
3766
if self.operation.upper().startswith('NOISE'):
3867
xvec, _ = self.binarizer.noise_char(token, self.operation.upper()[6:])
@@ -43,13 +72,34 @@ def preprocess_text_train(self, text):
4372
yield xvec, yvec
4473

4574
def preprocess_text_correct(self, text):
75+
""" A generator that output numpy vectors for the text for correction.
76+
77+
ModelNotTrainedException is raised if the model has not been trained.
78+
79+
:param text: text
80+
:return: generator that outputs the numpy vectors for correction
81+
:type text: str
82+
:rtype: generator
83+
:raise: ModelNotTrainedException
84+
"""
4685
if not self.trained:
4786
raise ce.ModelNotTrainedException()
4887
for token in nospace_tokenize(text):
4988
xvec, _ = self.binarizer.change_nothing(token, self.operation)
5089
yield xvec
5190

52-
def train(self, text, nb_epoch=100, optimizer='rmsprop'):
91+
def train(self, text, nb_epoch=100, dropout_rate=0.01, optimizer='rmsprop'):
92+
""" Train the scRNN model.
93+
94+
:param text: training corpus
95+
:param nb_epoch: number of epochs (Default: 100)
96+
:param dropout_rate: dropout rate (Default: 0.01)
97+
:param optimizer: optimizer (Default: "rmsprop")
98+
:type text: str
99+
:type nb_epoch: int
100+
:type dropout_rate: float
101+
:type optimizer: str
102+
"""
53103
self.dictionary = Dictionary([nospace_tokenize(text), default_specialsignals.values()])
54104
self.onehotencoder.fit(np.arange(len(self.dictionary)).reshape((len(self.dictionary), 1)))
55105
xylist = [(xvec.transpose(), yvec.transpose()) for xvec, yvec in self.preprocess_text_train(text)]
@@ -59,26 +109,35 @@ def train(self, text, nb_epoch=100, optimizer='rmsprop'):
59109
# neural network here
60110
model = Sequential()
61111
model.add(LSTM(self.nb_hiddenunits, return_sequences=True, batch_input_shape=(None, self.batchsize, len(self.concatcharvec_encoder)*3)))
62-
model.add(Dropout(0.01))
112+
model.add(Dropout(dropout_rate))
63113
model.add(TimeDistributed(Dense(len(self.dictionary))))
64114
model.add(Activation('softmax'))
65115

66116
# compile... more arguments
67-
model.compile(loss='categorical_crossentropy', optimizer=optimizer
68-
#metrics=['accuracy'])
69-
)
70-
71-
print xtrain.shape
72-
print ytrain.shape
117+
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
73118

119+
# training
74120
model.fit(xtrain, ytrain, epochs=nb_epoch)
75121

76122
self.model = model
77123
self.trained = True
78124

79125
def correct(self, word):
126+
""" Recommend a spell correction to given the word.
127+
128+
:param word: a given word
129+
:return: recommended correction
130+
:type word: str
131+
:rtype: str
132+
"""
80133
xmat = np.array([xvec.transpose() for xvec in self.preprocess_text_correct(word)])
81134
yvec = self.model.predict(xmat)
82135

83136
maxy = yvec.argmax(axis=-1)
84-
return ' '.join([self.dictionary[y] for y in maxy[0]])
137+
return ' '.join([self.dictionary[y] for y in maxy[0]])
138+
139+
def loadmodel(self, prefix):
140+
pass
141+
142+
def savemodel(self, prefix):
143+
pass

test/test_sakaguchispell.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
import unittest
3+
4+
import shorttext.spell.sakaguchi as sk
5+
6+
class TestSCRNN(unittest.TestCase):
7+
def setUp(self):
8+
pass
9+
10+
def tearDown(self):
11+
pass
12+
13+
def generalproc(self, operation):
14+
corrector = sk.SCRNNSpellCorrector(operation)
15+
corrector.train('I am a nerd . Natural language processing is sosad .')
16+
self.assertEqual(corrector.correct('langudge'), 'language')
17+
18+
def test_NOISE_INSERT(self):
19+
self.generalproc('NOISE-INSERT')
20+
21+
def test_NOISE_DELETE(self):
22+
self.generalproc('NOISE-DELETE')
23+
24+
def test_NOISE_REPLACE(self):
25+
self.generalproc('NOISE-REPLACE')
26+
27+
def test_JUMBLE_WHOLE(self):
28+
self.generalproc('JUMBLE-WHOLE')
29+
30+
def test_JUMBLE_BEG(self):
31+
self.generalproc('JUMBLE-BEG')
32+
33+
def test_JUMBLE_END(self):
34+
self.generalproc('JUMBLE-END')
35+
36+
def test_JUMBLE_INT(self):
37+
self.generalproc('JUMBLE-INT')
38+
39+
40+
if __name__ == '__main__':
41+
unittest.main()
42+

0 commit comments

Comments
 (0)