Skip to content

Commit 895604a

Browse files
committed
Format file with PyCharm
1 parent aebcd44 commit 895604a

File tree

2 files changed

+203
-195
lines changed

2 files changed

+203
-195
lines changed

NEW_digit_recog.py

Lines changed: 106 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -7,57 +7,61 @@
77
"""
88

99
import numpy as np
10-
#from scipy.misc.pilutil import imresize
10+
# from scipy.misc.pilutil import imresize
1111
from needed import imresize
1212
from PIL import Image
13-
import cv2 #version 3.2.0
13+
import cv2 # version 3.2.0
1414
from skimage.feature import hog
1515
from matplotlib import pyplot as plt
1616
from sklearn.model_selection import train_test_split
1717
from sklearn.metrics import accuracy_score
1818
from sklearn.utils import shuffle
1919

20-
DIGIT_WIDTH = 10
20+
DIGIT_WIDTH = 10
2121
DIGIT_HEIGHT = 20
2222
IMG_HEIGHT = 28
2323
IMG_WIDTH = 28
24-
CLASS_N = 10 # 0-9
24+
CLASS_N = 10 # 0-9
2525

26-
#This method splits the input training image into small cells (of a single digit) and uses these cells as training data.
27-
#The default training image (MNIST) is a 1000x1000 size image and each digit is of size 10x20. so we divide 1000/10 horizontally and 1000/20 vertically.
26+
27+
# This method splits the input training image into small cells (of a single digit) and uses these cells as training data.
28+
# The default training image (MNIST) is a 1000x1000 size image and each digit is of size 10x20. so we divide 1000/10 horizontally and 1000/20 vertically.
2829
def split2d(img, cell_size, flatten=True):
2930
h, w = img.shape[:2]
3031
sx, sy = cell_size
31-
cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
32+
cells = [np.hsplit(row, w // sx) for row in np.vsplit(img, h // sy)]
3233
cells = np.array(cells)
3334
if flatten:
3435
cells = cells.reshape(-1, sy, sx)
3536
return cells
3637

38+
3739
def load_digits(fn):
3840
print('loading "%s for training" ...' % fn)
3941
digits_img = cv2.imread(fn, 0)
4042
digits = split2d(digits_img, (DIGIT_WIDTH, DIGIT_HEIGHT))
4143
resized_digits = []
4244
for digit in digits:
43-
resized_digits.append(imresize(digit,(IMG_WIDTH, IMG_HEIGHT)))
44-
labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
45+
resized_digits.append(imresize(digit, (IMG_WIDTH, IMG_HEIGHT)))
46+
labels = np.repeat(np.arange(CLASS_N), len(digits) / CLASS_N)
4547
return np.array(resized_digits), labels
4648

49+
4750
def pixels_to_hog_20(img_array):
4851
hog_featuresData = []
4952
for img in img_array:
50-
fd = hog(img,
51-
orientations=10,
52-
pixels_per_cell=(5,5),
53-
cells_per_block=(1,1))
53+
fd = hog(img,
54+
orientations=10,
55+
pixels_per_cell=(5, 5),
56+
cells_per_block=(1, 1))
5457
hog_featuresData.append(fd)
5558
hog_features = np.array(hog_featuresData, 'float64')
5659
return np.float32(hog_features)
5760

58-
#define a custom model in a similar class wrapper with train and predict methods
61+
62+
# define a custom model in a similar class wrapper with train and predict methods
5963
class KNN_MODEL():
60-
def __init__(self, k = 3):
64+
def __init__(self, k=3):
6165
self.k = k
6266
self.model = cv2.ml.KNearest_create()
6367

@@ -68,11 +72,12 @@ def predict(self, samples):
6872
retval, results, neigh_resp, dists = self.model.findNearest(samples, self.k)
6973
return results.ravel()
7074

75+
7176
class SVM_MODEL():
72-
def __init__(self, num_feats, C = 1, gamma = 0.1):
77+
def __init__(self, num_feats, C=1, gamma=0.1):
7378
self.model = cv2.ml.SVM_create()
7479
self.model.setType(cv2.ml.SVM_C_SVC)
75-
self.model.setKernel(cv2.ml.SVM_RBF) #SVM_LINEAR, SVM_RBF
80+
self.model.setKernel(cv2.ml.SVM_RBF) # SVM_LINEAR, SVM_RBF
7681
self.model.setC(C)
7782
self.model.setGamma(gamma)
7883
self.features = num_feats
@@ -81,147 +86,147 @@ def train(self, samples, responses):
8186
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
8287

8388
def predict(self, samples):
84-
results = self.model.predict(samples.reshape(-1,self.features))
89+
results = self.model.predict(samples.reshape(-1, self.features))
8590
return results[1].ravel()
8691

8792

8893
def get_digits(contours, hierarchy):
8994
hierarchy = hierarchy[0]
90-
bounding_rectangles = [cv2.boundingRect(ctr) for ctr in contours]
95+
bounding_rectangles = [cv2.boundingRect(ctr) for ctr in contours]
9196
final_bounding_rectangles = []
92-
#find the most common heirarchy level - that is where our digits's bounding boxes are
93-
u, indices = np.unique(hierarchy[:,-1], return_inverse=True)
97+
# find the most common heirarchy level - that is where our digits's bounding boxes are
98+
u, indices = np.unique(hierarchy[:, -1], return_inverse=True)
9499
most_common_heirarchy = u[np.argmax(np.bincount(indices))]
95-
96-
for r,hr in zip(bounding_rectangles, hierarchy):
97-
x,y,w,h = r
98-
#this could vary depending on the image you are trying to predict
99-
#we are trying to extract ONLY the rectangles with images in it (this is a very simple way to do it)
100-
#we use heirarchy to extract only the boxes that are in the same global level - to avoid digits inside other digits
101-
#ex: there could be a bounding box inside every 6,9,8 because of the loops in the number's appearence - we don't want that.
102-
#read more about it here: https://docs.opencv.org/trunk/d9/d8b/tutorial_py_contours_hierarchy.html
103-
if ((w*h)>250) and (10 <= w <= 200) and (10 <= h <= 200) and hr[3] == most_common_heirarchy:
104-
final_bounding_rectangles.append(r)
100+
101+
for r, hr in zip(bounding_rectangles, hierarchy):
102+
x, y, w, h = r
103+
# this could vary depending on the image you are trying to predict
104+
# we are trying to extract ONLY the rectangles with images in it (this is a very simple way to do it)
105+
# we use heirarchy to extract only the boxes that are in the same global level - to avoid digits inside other digits
106+
# ex: there could be a bounding box inside every 6,9,8 because of the loops in the number's appearence - we don't want that.
107+
# read more about it here: https://docs.opencv.org/trunk/d9/d8b/tutorial_py_contours_hierarchy.html
108+
if ((w * h) > 250) and (10 <= w <= 200) and (10 <= h <= 200) and hr[3] == most_common_heirarchy:
109+
final_bounding_rectangles.append(r)
105110

106111
return final_bounding_rectangles
107112

108113

109114
def proc_user_img(img_file, model):
110115
print('loading "%s for digit recognition" ...' % img_file)
111-
im = cv2.imread(img_file)
112-
blank_image = np.zeros((im.shape[0],im.shape[1],3), np.uint8)
116+
im = cv2.imread(img_file)
117+
blank_image = np.zeros((im.shape[0], im.shape[1], 3), np.uint8)
113118
blank_image.fill(255)
114119

115-
imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
120+
imgray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
116121
plt.imshow(imgray)
117-
kernel = np.ones((5,5),np.uint8)
118-
119-
ret,thresh = cv2.threshold(imgray,127,255,0)
120-
thresh = cv2.erode(thresh,kernel,iterations = 1)
121-
thresh = cv2.dilate(thresh,kernel,iterations = 1)
122-
thresh = cv2.erode(thresh,kernel,iterations = 1)
123-
124-
contours,hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
125-
126-
digits_rectangles = get_digits(contours,hierarchy) #rectangles of bounding the digits in user image
127-
122+
kernel = np.ones((5, 5), np.uint8)
123+
124+
ret, thresh = cv2.threshold(imgray, 127, 255, 0)
125+
thresh = cv2.erode(thresh, kernel, iterations=1)
126+
thresh = cv2.dilate(thresh, kernel, iterations=1)
127+
thresh = cv2.erode(thresh, kernel, iterations=1)
128+
129+
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
130+
131+
digits_rectangles = get_digits(contours, hierarchy) # rectangles of bounding the digits in user image
132+
128133
for rect in digits_rectangles:
129-
x,y,w,h = rect
130-
cv2.rectangle(im,(x,y),(x+w,y+h),(0,255,0),2)
131-
im_digit = imgray[y:y+h,x:x+w]
132-
im_digit = (255-im_digit)
133-
im_digit = imresize(im_digit,(IMG_WIDTH ,IMG_HEIGHT))
134+
x, y, w, h = rect
135+
cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 0), 2)
136+
im_digit = imgray[y:y + h, x:x + w]
137+
im_digit = (255 - im_digit)
138+
im_digit = imresize(im_digit, (IMG_WIDTH, IMG_HEIGHT))
134139

135-
hog_img_data = pixels_to_hog_20([im_digit])
140+
hog_img_data = pixels_to_hog_20([im_digit])
136141
pred = model.predict(hog_img_data)
137-
cv2.putText(im, str(int(pred[0])), (x,y),cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 3)
138-
cv2.putText(blank_image, str(int(pred[0])), (x,y),cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 0, 0), 5)
142+
cv2.putText(im, str(int(pred[0])), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 3)
143+
cv2.putText(blank_image, str(int(pred[0])), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 0, 0), 5)
139144

140145
plt.imshow(im)
141-
cv2.imwrite("original_overlay.png",im)
142-
cv2.imwrite("final_digits.png",blank_image)
143-
#cv2.destroyAllWindows()
146+
cv2.imwrite("original_overlay.png", im)
147+
cv2.imwrite("final_digits.png", blank_image)
148+
# cv2.destroyAllWindows()
144149

145150

146151
def get_contour_precedence(contour, cols):
147-
return contour[1] * cols + contour[0] #row-wise ordering
152+
return contour[1] * cols + contour[0] # row-wise ordering
148153

149154

150-
#this function processes a custom training image
151-
#see example : custom_train.digits.jpg
152-
#if you want to use your own, it should be in a similar format
155+
# this function processes a custom training image
156+
# see example : custom_train.digits.jpg
157+
# if you want to use your own, it should be in a similar format
153158
def load_digits_custom(img_file):
154159
train_data = []
155160
train_target = []
156161
start_class = 1
157162
im = cv2.imread(img_file)
158-
imgray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
163+
imgray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
159164
plt.imshow(imgray)
160-
kernel = np.ones((5,5),np.uint8)
161-
162-
ret,thresh = cv2.threshold(imgray,127,255,0)
163-
thresh = cv2.erode(thresh,kernel,iterations = 1)
164-
thresh = cv2.dilate(thresh,kernel,iterations = 1)
165-
thresh = cv2.erode(thresh,kernel,iterations = 1)
166-
167-
contours,hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
168-
digits_rectangles = get_digits(contours,hierarchy) #rectangles of bounding the digits in user image
169-
170-
#sort rectangles accoring to x,y pos so that we can label them
171-
digits_rectangles.sort(key=lambda x:get_contour_precedence(x, im.shape[1]))
172-
173-
for index,rect in enumerate(digits_rectangles):
174-
x,y,w,h = rect
175-
cv2.rectangle(im,(x,y),(x+w,y+h),(0,255,0),2)
176-
im_digit = imgray[y:y+h,x:x+w]
177-
im_digit = (255-im_digit)
178-
179-
im_digit = imresize(im_digit,(IMG_WIDTH, IMG_HEIGHT))
165+
kernel = np.ones((5, 5), np.uint8)
166+
167+
ret, thresh = cv2.threshold(imgray, 127, 255, 0)
168+
thresh = cv2.erode(thresh, kernel, iterations=1)
169+
thresh = cv2.dilate(thresh, kernel, iterations=1)
170+
thresh = cv2.erode(thresh, kernel, iterations=1)
171+
172+
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
173+
digits_rectangles = get_digits(contours, hierarchy) # rectangles of bounding the digits in user image
174+
175+
# sort rectangles accoring to x,y pos so that we can label them
176+
digits_rectangles.sort(key=lambda x: get_contour_precedence(x, im.shape[1]))
177+
178+
for index, rect in enumerate(digits_rectangles):
179+
x, y, w, h = rect
180+
cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 0), 2)
181+
im_digit = imgray[y:y + h, x:x + w]
182+
im_digit = (255 - im_digit)
183+
184+
im_digit = imresize(im_digit, (IMG_WIDTH, IMG_HEIGHT))
180185
train_data.append(im_digit)
181-
train_target.append(start_class%10)
186+
train_target.append(start_class % 10)
182187

183-
if index>0 and (index+1) % 10 == 0:
188+
if index > 0 and (index + 1) % 10 == 0:
184189
start_class += 1
185-
cv2.imwrite("training_box_overlay.png",im)
186-
190+
cv2.imwrite("training_box_overlay.png", im)
191+
187192
return np.array(train_data), np.array(train_target)
188193

189-
#------------------data preparation--------------------------------------------
190194

191-
TRAIN_MNIST_IMG = 'digits.png'
195+
# ------------------data preparation--------------------------------------------
196+
197+
TRAIN_MNIST_IMG = 'digits.png'
192198
TRAIN_USER_IMG = 'custom_train_digits.jpg'
193199
TEST_USER_IMG = 'test_image.png'
194200

195-
#digits, labels = load_digits(TRAIN_MNIST_IMG) #original MNIST data (not good detection)
196-
digits, labels = load_digits_custom(TRAIN_USER_IMG) #my handwritten dataset (better than MNIST on my handwritten digits)
201+
# digits, labels = load_digits(TRAIN_MNIST_IMG) #original MNIST data (not good detection)
202+
digits, labels = load_digits_custom(
203+
TRAIN_USER_IMG) # my handwritten dataset (better than MNIST on my handwritten digits)
197204

198-
print('train data shape',digits.shape)
199-
print('test data shape',labels.shape)
205+
print('train data shape', digits.shape)
206+
print('test data shape', labels.shape)
200207

201208
digits, labels = shuffle(digits, labels, random_state=256)
202209
train_digits_data = pixels_to_hog_20(digits)
203210
X_train, X_test, y_train, y_test = train_test_split(train_digits_data, labels, test_size=0.33, random_state=42)
204211

205-
#------------------training and testing----------------------------------------
212+
# ------------------training and testing----------------------------------------
206213

207-
model = KNN_MODEL(k = 3)
214+
model = KNN_MODEL(k=3)
208215
model.train(X_train, y_train)
209216
preds = model.predict(X_test)
210-
print('Accuracy: ',accuracy_score(y_test, preds))
217+
print('Accuracy: ', accuracy_score(y_test, preds))
211218

212-
model = KNN_MODEL(k = 4)
219+
model = KNN_MODEL(k=4)
213220
model.train(train_digits_data, labels)
214221
proc_user_img(TEST_USER_IMG, model)
215222

216-
217-
218-
model = SVM_MODEL(num_feats = train_digits_data.shape[1])
223+
model = SVM_MODEL(num_feats=train_digits_data.shape[1])
219224
model.train(X_train, y_train)
220225
preds = model.predict(X_test)
221-
print('Accuracy: ',accuracy_score(y_test, preds))
226+
print('Accuracy: ', accuracy_score(y_test, preds))
222227

223-
model = SVM_MODEL(num_feats = train_digits_data.shape[1])
228+
model = SVM_MODEL(num_feats=train_digits_data.shape[1])
224229
model.train(train_digits_data, labels)
225230
proc_user_img(TEST_USER_IMG, model)
226231

227-
#------------------------------------------------------------------------------
232+
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)