
Commit 47786cc

Author: wabywang(王本友)
Commit message: new_dataloader
1 parent fa52749 commit 47786cc

File tree

6 files changed: +325 -29 lines changed


dataHelper.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-

import os
import numpy as np
import string
from collections import Counter
import pandas as pd
from tqdm import tqdm
import random
import time
import pickle
from utils import log_time_delta
from dataloader import Dataset


class Alphabet(dict):
    """Maps symbols to integer ids; text alphabets reserve PADDING/UNK/END."""
    def __init__(self, start_feature_id=1, alphabet_type="text"):
        self.fid = start_feature_id
        if alphabet_type == "text":
            self.add('[PADDING]')
            self.add('[UNK]')
            self.add('[END]')
            self.unknow_token = self.get('[UNK]')
            self.end_token = self.get('[END]')
            self.padding_token = self.get('[PADDING]')

    def add(self, item):
        idx = self.get(item, None)
        if idx is None:
            idx = self.fid
            self[item] = idx
            # self[idx] = item
            self.fid += 1
        return idx

    def addAll(self, words):
        for word in words:
            self.add(word)

    def dump(self, fname, path="temp"):
        if not os.path.exists(path):
            os.mkdir(path)
        with open(os.path.join(path, fname), "w") as out:
            for k in sorted(self.keys()):
                out.write("{}\t{}\n".format(k, self[k]))


class BucketIterator(object):
    """Yields fixed-size mini-batches from a pandas DataFrame."""
    def __init__(self, data, opt=None, batch_size=2, shuffle=True):
        self.shuffle = shuffle
        self.data = data
        self.batch_size = batch_size
        if opt is not None:
            self.setup(opt)

    def setup(self, opt):
        self.data = opt.data
        self.batch_size = opt.batch_size
        self.shuffle = opt.__dict__.get("shuffle", self.shuffle)

    def __iter__(self):
        if self.shuffle:
            self.data = self.data.sample(frac=1).reset_index(drop=True)
        batch_nums = int(len(self.data) / self.batch_size)
        for i in range(batch_nums):
            yield self.data[i * self.batch_size:(i + 1) * self.batch_size]
        # trailing remainder batch (overlaps the last full batch when len(data) % batch_size != 0)
        yield self.data[-1 * self.batch_size:]


@log_time_delta
def getSubVectors(vectors, vocab, dim):
    # build an embedding matrix for every word in the alphabet
    embedding = np.zeros((len(vocab), dim))
    count = 0
    for word in vocab:
        if word in vectors:
            count += 1
            embedding[vocab[word]] = vectors[word]
        else:
            embedding[vocab[word]] = np.random.uniform(-0.5, +0.5, dim)  # vectors['[UNKNOW]']
    print('words in embedding', count)
    return embedding


@log_time_delta
def load_text_vec(alphabet, filename="", embedding_size=-1):
    vectors = {}
    with open(filename, encoding='utf-8') as f:
        for line in tqdm(f):
            items = line.strip().split(' ')
            if len(items) == 2:
                # word2vec-format header line: "<vocab_size> <embedding_size>"
                vocab_size, embedding_size = items[0], int(items[1])
                print('embedding_size', embedding_size)
                print('vocab_size in pretrained embedding', vocab_size)
            else:
                word = items[0]
                if word in alphabet:
                    vectors[word] = items[1:]
    print('words need to be found ', len(alphabet))
    print('words found in word2vec embedding ', len(vectors.keys()))

    if embedding_size == -1:
        embedding_size = len(vectors[list(vectors.keys())[0]])
    return vectors, embedding_size


def getEmbeddingFile(name):
    # "glove" "w2v"
    return r"D:\dataset\glove\glove.6B.300d.txt"


def getDataSet(dataset):
    data_dir = ".data/clean/demo"
    files = [os.path.join(data_dir, data_name) for data_name in ['train.txt', 'test.txt', 'dev.txt']]
    return files


def loadData(opt):
    datas = []

    alphabet = Alphabet(start_feature_id=0)
    label_alphabet = Alphabet(start_feature_id=0, alphabet_type="label")
    for filename in getDataSet(opt.dataset):
        df = pd.read_csv(filename, header=None, sep="\t", names=["text", "label"]).fillna('0')
        df["text"] = df["text"].str.lower().str.split()
        datas.append(df)

    df = pd.concat(datas)

    from functools import reduce
    word_set = reduce(lambda x, y: set(x) | set(y), df["text"])
    alphabet.addAll(word_set)
    label_set = set(df["label"])
    label_alphabet.addAll(label_set)

    if opt.max_seq_len == -1:
        opt.max_seq_len = df.apply(lambda row: row["text"].__len__(), axis=1).max()

    # map words to ids and pad every sequence to max_seq_len
    for data in datas:
        data["text"] = data["text"].apply(
            lambda text: [alphabet.get(word, alphabet.unknow_token) for word in text]
                         + [alphabet.padding_token] * int(opt.max_seq_len - len(text)))
        data["label"] = data["label"].apply(lambda text: label_alphabet.get(text))

    glove_file = getEmbeddingFile(opt.__dict__.get("embedding", "glove_6b_300"))
    loaded_vectors, embedding_size = load_text_vec(alphabet, glove_file)
    vectors = getSubVectors(loaded_vectors, alphabet, embedding_size)

    opt.vocab_size = len(alphabet)
    opt.label_size = len(label_alphabet)
    opt.embedding_dim = embedding_size
    opt.embeddings = vectors

    alphabet.dump(opt.dataset + ".alphabet")
    return map(BucketIterator, datas)  # map(lambda x: BucketIterator(x), datas)


if __name__ == "__main__":
    import opts
    opt = opts.parse_opt()
    opt.max_seq_len = -1
    import dataloader
    dataset = dataloader.getDataset(opt)
    # datas=loadData(opt)
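For orientation, a toy sketch of how the new Alphabet and BucketIterator are expected to fit together. This is an illustration written against the code above, not part of the commit, and it assumes dataHelper.py is importable (which also pulls in utils and the dataloader package).

# Toy usage sketch (assumption, not part of the commit).
import pandas as pd
from dataHelper import Alphabet, BucketIterator

alphabet = Alphabet(start_feature_id=0)          # reserves [PADDING]/[UNK]/[END] as ids 0..2
alphabet.addAll(["good", "bad", "movie"])

df = pd.DataFrame({
    "text": [[alphabet.get("good"), alphabet.get("movie")],
             [alphabet.get("bad"), alphabet.get("movie")]],
    "label": [1, 0],
})

# batch_size=1 yields each row, plus the trailing remainder batch (see __iter__ above)
for batch in BucketIterator(df, batch_size=1, shuffle=False):
    print(batch["text"].tolist(), batch["label"].tolist())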

dataloader/Dataset.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
import os, urllib.request


class Dataset(object):
    def __init__(self, opt=None):
        if opt is not None:
            self.setup(opt)
        self.root = ".data_waby"
        self.urls = []

    def setup(self, opt):
        # self.http_proxy='http://dev-proxy.oa.com:8080'
        self.name = opt.dataset
        self.dirname = opt.dataset

    def process(self):
        dirname = self.download()
        print("processing dirname: " + dirname)
        return dirname

    def download_from_url(self, url, path, schedule=None, http_proxy="http://dev-proxy.oa.com:8080"):
        if schedule is None:
            # simple progress reporter: prints roughly every 10%
            schedule = lambda a, b, c: print("%.1f" % (100.0 * a * b / c), end='\r', flush=True) if (int(a * b / c) * 100) % 10 == 0 else None
        if http_proxy is not None:
            proxy = urllib.request.ProxyHandler({'http': http_proxy})
            # construct a new opener using the proxy settings
            opener = urllib.request.build_opener(proxy)
            # install the opener on the module level
            urllib.request.install_opener(opener)
        urllib.request.urlretrieve(url, path, schedule)
        return path

    def download(self, check=None):
        """Download and unzip an online archive (.zip, .gz, or .tgz).

        Arguments:
            root (str): Folder to download data to.
            check (str or None): Folder whose existence indicates
                that the dataset has already been downloaded, or
                None to check the existence of root/{cls.name}.

        Returns:
            dataset_path (str): Path to extracted dataset.
        """
        import zipfile, tarfile

        path = os.path.join(self.root, self.name)
        check = path if check is None else check
        if not os.path.isdir(check):
            for url in self.urls:
                if isinstance(url, tuple):
                    url, filename = url
                else:
                    filename = os.path.basename(url)
                zpath = os.path.join(path, filename)
                if not os.path.isfile(zpath):
                    if not os.path.exists(os.path.dirname(zpath)):
                        os.makedirs(os.path.dirname(zpath))
                    print('downloading {}'.format(filename))
                    self.download_from_url(url, zpath)
                ext = os.path.splitext(filename)[-1]
                if ext == '.zip':
                    with zipfile.ZipFile(zpath, 'r') as zfile:
                        print('extracting')
                        zfile.extractall(path)
                elif ext in ['.gz', '.tgz']:
                    with tarfile.open(zpath, 'r:gz') as tar:
                        dirs = [member for member in tar.getmembers()]
                        tar.extractall(path=path, members=dirs)
            return os.path.join(path, os.path.splitext(filename)[-2])


if __name__ == "__main__":
    import opts
    opt = opts.parse_opt()
    opt.max_seq_len = -1
    from dataloader import Dataset
    x = Dataset(opt)
    x.process()
    # datas=loadData(opt)
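For context, a minimal subclass sketch of how a concrete dataset is expected to plug into Dataset by filling in urls. The class name and URL below are illustrative assumptions; the IMDBDataset that dataloader/__init__.py imports is not shown in this diff but presumably follows the same pattern.

# Illustrative subclass (assumption): name and URL are examples only.
from dataloader.Dataset import Dataset

class DemoDataset(Dataset):
    def __init__(self, opt=None):
        super(DemoDataset, self).__init__(opt)
        # archives listed here are fetched and extracted by Dataset.download()
        self.urls = ["http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"]

# DemoDataset(opt).process() would download into .data_waby/<opt.dataset>/ and extract the archive.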

dataloader/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-


from .IMDB import IMDBDataset


def getDataset(opt):
    if opt.dataset == "imdb":
        dataset = IMDBDataset(opt)
    else:
        raise Exception("dataset not supported: {}".format(opt.dataset))
    return dataset
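A hedged call-site sketch, mirroring the __main__ blocks elsewhere in this commit; opts.parse_opt() and the "imdb" switch come from the code above, the rest is assumed.

# Hypothetical call site (assumption): wire parsed options into the dataset factory.
import opts
import dataloader

opt = opts.parse_opt()
opt.dataset = "imdb"                  # the only dataset getDataset() supports so far
dataset = dataloader.getDataset(opt)
dataset.process()                     # download + extract (assuming IMDBDataset inherits Dataset.process)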

models/LSTM.py

Lines changed: 26 additions & 26 deletions
@@ -43,29 +43,29 @@ def forward(self, sentence):
         lstm_out, self.hidden = self.lstm(x, self.hidden)
         y = self.hidden2label(lstm_out[-1])
         return y
-    def forward1(self, sentence):
-
-        return torch.zeros(sentence.size()[0], self.opt.label_size)
-#    def __call__(self, **args):
-#        self.forward(args)
-def test():
-
-    import numpy as np
-
-    word_embeddings = nn.Embedding(10000, 300)
-    lstm = nn.LSTM(300, 100)
-    h0 = Variable(torch.zeros(1, 128, 100))
-    c0 = Variable(torch.zeros(1, 128, 100))
-    hidden = (h0, c0)
-    sentence = Variable(torch.LongTensor(np.zeros((128, 30), dtype=np.int64)))
-    embeds = word_embeddings(sentence)
-    torch.tile(sentence)
-    sentence.size()[0]
-
-
-
-#    x = Variable(torch.zeros(30, 128, 300))
-    x = embeds.view(sentence.size()[1], self.batch_size, -1)
-    embeds = embeds.permute(1, 0, 2)
-    lstm_out, hidden = lstm(embeds, hidden)
-#
+#    def forward1(self, sentence):
+#
+#        return torch.zeros(sentence.size()[0], self.opt.label_size)
+##    def __call__(self, **args):
+##        self.forward(args)
+#    def test():
+#
+#        import numpy as np
+#
+#        word_embeddings = nn.Embedding(10000, 300)
+#        lstm = nn.LSTM(300, 100)
+#        h0 = Variable(torch.zeros(1, 128, 100))
+#        c0 = Variable(torch.zeros(1, 128, 100))
+#        hidden = (h0, c0)
+#        sentence = Variable(torch.LongTensor(np.zeros((128, 30), dtype=np.int64)))
+#        embeds = word_embeddings(sentence)
+#        torch.tile(sentence)
+#        sentence.size()[0]
+#
+#
+#
+##        x = Variable(torch.zeros(30, 128, 300))
+#        x = embeds.view(sentence.size()[1], self.batch_size, -1)
+#        embeds = embeds.permute(1, 0, 2)
+#        lstm_out, hidden = lstm(embeds, hidden)
+##

push.bash

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 git add *.py
 git add models/*.py
+git add dataloader/*.py
 git commit -m $1
 git pull
 git push
