main.py
'''
This is the improved version of main_v1.py.
The main improvements are:
1. The input is now read from a customizable CSV file instead of being hard-coded in the script.
2. Training goes through a reusable, configurable train() function.
'''
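# A minimal sketch of the train.csv layout this script appears to assume, inferred
# from the "text"/"tag" column accesses and whitespace splitting in main() below.
# The example rows are illustrative only, not taken from the actual dataset:
#
#   text,tag
#   The dog ate the apple,DET NN V DET NN
#   Everybody read that book,NN V DET NN
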
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from pytorch_lstm_01intro.model_lstm_tagger import LSTMTagger
from pytorch_lstm_01intro.preprocess import seqs_to_dictionary
from pytorch_lstm_01intro.train import train, test

torch.manual_seed(1)

EMBEDDING_DIM = 6
HIDDEN_DIM = 6

def main():
    # Read in the raw data.
    training_data_raw = pd.read_csv("./train.csv")

    # Split the texts and tags into whitespace-separated tokens and pair them up.
    texts = [t.split() for t in training_data_raw["text"].tolist()]
    tags_list = [t.split() for t in training_data_raw["tag"].tolist()]
    training_data = list(zip(texts, tags_list))

    # Create the word and tag index mappings.
    word_to_ix, tag_to_ix = seqs_to_dictionary(training_data)
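    # The call above is assumed to assign each unique word and tag a contiguous
    # integer index (an assumption based on how its return values are used below;
    # the helper itself lives in pytorch_lstm_01intro.preprocess).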
    print(training_data)

    # Embedding and hidden sizes are usually 32 or 64 dims; they are kept small here.
    model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix),
                       is_nll_loss=True)
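    # NLLLoss expects log-probabilities as input, while CrossEntropyLoss expects raw
    # logits, so LSTMTagger presumably applies log_softmax to its output when
    # is_nll_loss=True (an assumption about the helper module, not verified here).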
    loss_function = nn.NLLLoss() if model.is_nll_loss else nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    # See what the scores are before training. Element (i, j) of the output is the
    # score of tag j for word i. No training is needed for this check, so the
    # forward pass is wrapped in torch.no_grad() inside test().
    testing_data = "The dog ate the book"
    print("tag_scores before training:")
    test(testing_data, model, word_to_ix)

    train(model, loss_function, training_data, word_to_ix, tag_to_ix, optimizer, epoch=200)

    # Expect something like: 0, 1, 2, 0, 1
    print("tag_scores after training:")
    tag_prob = test(testing_data, model, word_to_ix)
    print(tag_prob)


if __name__ == "__main__":
    main()