
Commit 76aa718

generation working
1 parent 5a7135f commit 76aa718

3 files changed: +163 -112 lines changed

final.py

Lines changed: 109 additions & 59 deletions
@@ -179,6 +179,11 @@
 """
 from enum import Enum, auto
 import graph
+import random
+import urllib.request
+import itertools
+import pickle
+from six import string_types


 class Tokenization(Enum):
@@ -220,7 +225,7 @@ def __init__(self, level, tokenization=None):
                          f"be of acceptable type")
         # instance attributes
         self._level = level
-        self._tokenization = tokenization
+        self._mode = tokenization
         # initialize our Markov Chain
         self.chain = graph.Graph()
         # initialize our probability sum container to hold the total sum of
@@ -234,11 +239,11 @@ def level(self):
         return self._level

     @property
-    def tokenization(self):
+    def mode(self):
         """Getter that returns the given RandomWriter's tokenization
         attribute.
         """
-        return self._tokenization
+        return self._mode

     def add_chain(self, data):
         """Function to add a new state to our Markov Chain, uses our Graph's
@@ -276,8 +281,29 @@ def add_conn(self, source, dest, token):
         """
         # add an edge from source's vert obj to dest's vert obj
         self.chain[source].add_edge(self.chain[dest], token)
-        # add 1 to the number of outgoing edges from the source destination
-        self._incr_prob_sum(source)
+
+    def _choose_rand_edge(self, vertex):
+        """Randomly traverse the Markov Chain that we have already
+        constructed with our train algorithm and choose a single outgoing edge.
+
+        Use our weighted random traversal algorithm: choose a number between
+        1 and the sum of the outgoing edge weights; if that number is less
+        than or equal to the weight of the edge we are currently evaluating,
+        then we have found our next path, otherwise subtract that edge's
+        weight from the randomly chosen number and evaluate the new
+        difference against the next edge.
+        """
+        # draw a value from the total weight of this vertex's outgoing edges
+        rand_val = random.randint(1, self.prob_sums[vertex.data])
+        # iterate over our current outgoing edges
+        for choice in vertex.outgoing_edges:
+            curr_edge = vertex.get_edge(choice)
+            # current edge's weight covers the random val, success
+            if rand_val <= curr_edge.weight:
+                return curr_edge
+            # otherwise subtract that edge's weight from rand_val and continue
+            else:
+                rand_val -= curr_edge.weight

     def generate(self):
         """Generate tokens using the model.
@@ -291,8 +317,25 @@ def generate(self):
         new starting node at random and continuing.

         """
-        # TODO: GENERATOR OBJ FOR ONCE WE HAVE OUR GRAPH
-        raise NotImplementedError
+        # randomly select a starting state until we find one w/ outgoing edges
+        state = random.choice(list(self.chain.vertices))
+        vertex = self.chain[state]
+        # ensure that we at least pick a starting state w/ outgoing edges
+        while not vertex.outgoing_edges:
+            state = random.choice(list(self.chain.vertices))
+            vertex = self.chain[state]
+
+        # continue to traverse and generate output indefinitely
+        while True:
+            # choose an edge weighted-randomly
+            curr_edge = self._choose_rand_edge(vertex)
+            yield curr_edge.token
+            # move to the next vertex, following the edge we just yielded
+            vertex = curr_edge.dest_vertex
+            # handle the case where the new vertex has no outgoing edges
+            while not vertex.outgoing_edges:
+                state = random.choice(list(self.chain.vertices))
+                vertex = self.chain[state]

     def generate_file(self, filename, amount):
         """Write a file using the model.
@@ -310,8 +353,21 @@ def generate_file(self, filename, amount):

         Make sure to open the file in the appropriate mode.
         """
-        # TODO: OUTPUT GENERATOR YIELDED RANDOMIZED TOKENS TO FILE
-        raise NotImplementedError
+        # open the file in byte mode if byte tokenized
+        fi = open(filename, "wb") if self.mode is Tokenization.byte else \
+            open(filename, "w", encoding="utf-8")
+        # only get the first "amount" elements in our generated data
+        for token in itertools.islice(self.generate(), amount):
+            # make sure we correctly format our output
+            if self.mode is Tokenization.word:
+                fi.write(token + " ")
+            elif self.mode is Tokenization.none or self.mode is None:
+                fi.write(str(token) + " ")
+            elif self.mode is Tokenization.byte:
+                fi.write(bytes([token]))
+            else:
+                fi.write(str(token))
+        fi.close()

     def save_pickle(self, filename_or_file_object):
         """Write this model out as a Python pickle.
@@ -324,7 +380,18 @@ def save_pickle(self, filename_or_file_object):
         in binary mode.

         """
-        raise NotImplementedError
+        # file object
+        if hasattr(filename_or_file_object, "write"):
+            # save this RandomWriter to a pickle
+            pickle.dump(self, filename_or_file_object)
+        # file name
+        elif isinstance(filename_or_file_object, string_types):
+            # open the file in the correct mode
+            with open(filename_or_file_object, "wb") as fi:
+                pickle.dump(self, fi)
+        else:
+            raise ValueError(f"Error: {filename_or_file_object} is not a "
+                             f"filename or file object")

     @classmethod
     def load_pickle(cls, filename_or_file_object):
@@ -341,7 +408,18 @@ def load_pickle(cls, filename_or_file_object):
         in binary mode.

         """
-        raise NotImplementedError
+        # file object
+        if hasattr(filename_or_file_object, "read"):
+            # load a RandomWriter from the pickle
+            return pickle.load(filename_or_file_object)
+        # file name
+        elif isinstance(filename_or_file_object, string_types):
+            # open the file in the correct mode
+            with open(filename_or_file_object, "rb") as fi:
+                return pickle.load(fi)
+        else:
+            raise ValueError(f"Error: {filename_or_file_object} is not a "
+                             f"filename or file object")

     def train_url(self, url):
         """Compute the probabilities based on the data downloaded from url.
@@ -356,7 +434,21 @@ def train_url(self, url):
         Do not duplicate any code from train_iterable.

         """
-        raise NotImplementedError
+        # ensure that the mode supports URL-based training
+        if self.mode is Tokenization.none or self.mode is None:
+            raise ValueError("Error: this type of training is only supported "
+                             "if the tokenization mode is not none")
+
+        # open the url and read in the data
+        with urllib.request.urlopen(url) as f:
+            # if byte mode, we don't have to decode
+            if self.mode is Tokenization.byte:
+                data = f.read()
+            # otherwise, make sure we decode as utf-8
+            else:
+                data = f.read().decode("utf-8")
+        # train on the data
+        self.train_iterable(data)

     def _gen_tokenized_data(self, data, size):
         """Helper function to generate tokens of proper length based on the
@@ -370,7 +462,7 @@ def _gen_tokenized_data(self, data, size):
         else if tokenization is byte then data is a bytestream

         NOTE: code taken from Arthur Peters' windowed() function in
-        final_tests.py, Thanks Mr. Peters!
+        final_tests.py

         TODO: handle k = 0 level case
         """
@@ -389,41 +481,7 @@ def _gen_tokenized_data(self, data, size):
             window.append(elem)
             # if the window has reached specified size, yield the proper state
             if len(window) == size:
-                # tokenize by string
-                if self.tokenization is Tokenization.character or \
-                        self.tokenization is Tokenization.word:
-                    yield "".join(window)
-                # tokenize by byte
-                elif self.tokenization is Tokenization.byte:
-                    yield b"".join(window)
-                # simply yield another iterable
-                else:
-                    yield tuple(window)
-
-    def _data_type_check(self, data):
-        """Helper function to make sure that the data is in the correct form
-        for this RandomWriter's Tokenization
-
-        If the tokenization mode is none, data must be an iterable. If
-        the tokenization mode is character or word, then data must be
-        a string. Finally, if the tokenization mode is byte, then data
-        must be a bytes. If the type is wrong raise TypeError.
-        """
-        # if in character or word tokenization, data must be a str
-        if self.tokenization is Tokenization.character or self.tokenization \
-                is Tokenization.word:
-            return isinstance(data, str)
-        # if in byte tokenization, data must by raw bytes
-        elif self.tokenization is Tokenization.byte:
-            return isinstance(data, bytes)
-        # if in none tokenization, data must be an iterable
-        elif self.tokenization is Tokenization.none or self.tokenization is \
-                Tokenization.none.value:
-            return hasattr(data, '__iter__')
-        # something went wrong with the constructor
-        else:
-            raise TypeError("Error: this RandomWriter does not have a proper "
-                            "tokenization")
+                yield tuple(window)

     def _build_markov_chain(self, states):
         """Helper function to help build the Markov Chain graph and compute
@@ -456,6 +514,8 @@ def _build_markov_chain(self, states):
             self.add_chain(new_state)
             # add a connection from the old chain to the new chain
             self.add_conn(old_state, new_state, new_state[-1])
+            # update the old state's total probability sum
+            self._incr_prob_sum(old_state)
             # iterate to the next state
             old_state = new_state

@@ -466,19 +526,9 @@ def train_iterable(self, data):
         simpler to store it don't worry about it. For most input types
         you will not need to store it.
         """
-        # type check on the input data
-        if not self._data_type_check(data):
-            raise TypeError(f"Error: data is not in correct form for this "
-                            f"RW's mode -> {self.tokenization}")
-
         # if we are in word tokenization then split data along white spaces
-        if self.tokenization is Tokenization.word:
+        if self.mode is Tokenization.word:
             states = self._gen_tokenized_data(data.split(), self.level)
-        # if we are in byte tokenization then split data by byte
-        elif self.tokenization is Tokenization.byte:
-            states = self._gen_tokenized_data((data[i:i+1] for i in range(
-                len(data))), self.level)
-        # otherwise, iterate over data normally to create new gen of states
         else:
             states = self._gen_tokenized_data(data, self.level)
         # build our Markov Chain based on the generator we've constructed from

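Taken together, the new final.py code gives RandomWriter a usable end-to-end flow: train from a URL, generate tokens, write them to a file, and round-trip the model through a pickle. A minimal usage sketch, assuming the module is importable as final and the constructor signature RandomWriter(level, tokenization=None) shown in the diff; the URL, file names, and token count are illustrative placeholders, not part of the commit:

from final import RandomWriter, Tokenization

# level-2 word model; train_url raises ValueError when the mode is none
rw = RandomWriter(2, Tokenization.word)
rw.train_url("https://www.gutenberg.org/files/11/11-0.txt")  # placeholder URL

# write the first 200 generated tokens to a text file
rw.generate_file("sample_output.txt", 200)

# persist the trained model and load it back
rw.save_pickle("rw_model.pickle")
restored = RandomWriter.load_pickle("rw_model.pickle")
restored.generate_file("sample_output_2.txt", 200)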
graph.py

Lines changed: 2 additions & 4 deletions
@@ -26,7 +26,6 @@ def __init__(self):
         vertex to the specific instantiation of the Vertex that
         encapsulates that state

-        TODO: Make sure to incorporate the empty graph case.
         """
         # Our container of all the vertices in this Graph, this dictionary
         # should map data to the vertices that encapsulate the data
@@ -84,7 +83,8 @@ def print_graph(self):
            edge_obj = vert_obj.get_edge(edge)
            print(f"\t\tedge: {edge}")
            print(f"\t\t\t(token: {edge}, e_object: {edge_obj})")
-            print(f"\t\t\tweight: {edge_obj}, dest: {edge_obj.dest_vertex}")
+            print(f"\t\t\tweight: {edge_obj.weight}, dest: "
+                  f"{edge_obj.dest_vertex}")


 """My implementation of a Vertex that will be used as containers to store
@@ -106,8 +106,6 @@ def __init__(self, data):
         # Our container for edges that leave this vertex which is
         # encapsulated by dictionary mappings of tokens to Edges
         self._outgoing_edges = {}
-        # TODO: may need to add another dict for fast lookups that maps Edge
-        # objects to probabilities

     @property
     def data(self):

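The cumulative-weight walk that _choose_rand_edge performs can also be illustrated on its own, independent of the Graph and Vertex classes. A self-contained sketch of the same technique over a plain dict of token weights; the function name and sample weights here are illustrative, not part of the repository:

import random

def choose_weighted(weights):
    """Pick a key from a dict of positive integer weights.

    Draw a value between 1 and the total weight, then walk the entries,
    subtracting each weight until the drawn value is covered.
    """
    total = sum(weights.values())
    rand_val = random.randint(1, total)
    for key, weight in weights.items():
        if rand_val <= weight:
            return key
        rand_val -= weight

# "the" (weight 2) is chosen roughly twice as often as "cat" or "sat"
print(choose_weighted({"the": 2, "cat": 1, "sat": 1}))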