-
Notifications
You must be signed in to change notification settings - Fork 0
/
getfontdata.py
141 lines (119 loc) · 4.61 KB
/
getfontdata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Get font images with labels from different samples, prep for network."""
# Standard imports
import os
import numpy as np
from skimage import io
# Torch
import torch
import torchvision
import torchvision.transforms as T
# Project
import utils
def getFontNumbers(fontdir="fonts"):
"""From all the images in fontdir, get separated images of the ints."""
font_images = os.listdir(fontdir)
out_numbers = []
for font in font_images:
# Read-in image
im = io.imread(os.path.join(fontdir, font))
im = utils.rgb2gray(im)
# Crop left/right on projection hard cut
proj = np.sum(im < 100, axis=0) == 0
limits = []
for i, val in enumerate(proj[1:-2]):
if proj[i] != val and val == proj[i+2]:
if val:
limits += [i]
else:
limits += [i-1]
# Crop top/bottom
proj = list(np.sum(im < 100, axis=1) == 0)
try:
upper = proj.index(0)
except ValueError:
upper = 0
try:
lower = proj[upper:].index(1) + upper - 1
except ValueError:
lower = len(proj) - 1
if len(limits) == 19:
limits += [im.shape[1]-1]
assert len(limits) == 20
numbers = []
for i in range(9):
numbers += [im[upper:lower+1, limits[2*i]:limits[2*i+1]+1]]
out_numbers += [numbers]
labels = np.arange(1, 10).reshape((1, 9)).repeat(len(out_numbers), axis=0)
return out_numbers, labels
def prepareFontNumbers(mnist_data, fontdir="fonts", generate=False):
def getUpLow(im, axis):
proj = list(np.sum(im > 10, axis=axis) == 0)
try:
upper = proj.index(0)
except ValueError:
upper = 0
try:
lower = proj[upper:].index(1) + upper - 1
except ValueError:
lower = len(proj) - 1
return upper, lower
# Get average position of img borders in broader picture
if generate:
dist = np.zeros(4) # up, lo, left, rght
for im in mnist_data:
upper, lower = getUpLow(im, axis=1)
left, right = getUpLow(im, axis=0)
dist += upper, lower, left, right
dist /= mnist_data.shape[0]
else:
dist = [4.8, 23.5, 6.6, 21.1]
# Convert to left-top-right-bottom padding
x, y = mnist_data.shape[1:]
padding = np.round([dist[2]-1, dist[0]-1, y-dist[3], x-dist[1]])
# Now apply to font images
fontNumbers, fontLabels = getFontNumbers(fontdir)
N = len(fontNumbers)*9
images = np.zeros((N, 28, 28), dtype=np.uint8)
labels = np.zeros(N, dtype=np.uint8)
for i, font in enumerate(fontNumbers):
for j, number in enumerate(font):
im = 255 - number
# Scale
_, b, _, d = padding
A = im.shape[0] / (x - b - d)
# First pad, then resize
im = torch.from_numpy(im[None, ...])
im = torchvision.transforms.functional.pad(
im, tuple(np.array(np.round(padding*A), dtype=int)))
im = torchvision.transforms.Resize(x)(im) # not always (28, 28)
im = torchvision.transforms.CenterCrop(x)(im) # keep aspect ratio
n = i*9 + j
images[n, ...] = im.numpy()[0, ...]
labels[n] = fontLabels[i, j]
return images, labels
class FontData(torch.utils.data.Dataset):
    """Font-digit dataset in MNIST format, with optional augmentation.

    Wraps the (images, labels) arrays produced by prepareFontNumbers.
    With ``do_transform=True`` the data is grown by appending rotated,
    perspective-distorted, and blurred copies; each transform doubles the
    dataset (N -> 8N for the three transforms below).
    """

    def __init__(self, mnist_reference, generate=False, angle_max=20,
                 distort_scale=.5, do_transform=False):
        """Build the dataset.

        Parameters
        ----------
        mnist_reference : ndarray
            MNIST images, forwarded to prepareFontNumbers.
        generate : bool
            Forwarded to prepareFontNumbers (measure borders vs defaults).
        angle_max : int
            Maximum rotation angle for the rotation augmentation.
        distort_scale : float
            Distortion scale for RandomPerspective.
        do_transform : bool
            If True, append augmented copies of the data.
        """
        # Fix: the original `super(FontData).__init__()` builds an
        # *unbound* super object and initializes that object itself —
        # Dataset.__init__ was never actually called.
        super().__init__()
        self.images, self.labels = prepareFontNumbers(mnist_reference,
                                                      generate=generate)
        if not do_transform:
            return
        # Data augmentation
        transforms = [
            torch.nn.Sequential(
                T.RandomRotation(angle_max, expand=True), T.Resize(28)),
            T.RandomPerspective(distortion_scale=distort_scale, p=1, fill=0),
            T.GaussianBlur(5)
        ]

        def applyTransform(images, _transform):
            # NOTE(review): the whole stack is transformed as a single
            # (1, N, 28, 28) image, so one random draw applies the same
            # rotation/perspective to every sample — confirm intended.
            return _transform(torch.tensor(images[None, ...])).numpy()[0, ...]
        # Recursively apply on newly transformed & old data
        for transform in transforms:
            self.images = np.append(
                self.images, applyTransform(self.images, transform), axis=0)
            self.labels = np.append(self.labels, self.labels)

    def __getitem__(self, i):
        """Return (image, label) tensors; image has a channel dim (1, 28, 28)."""
        return torch.tensor(self.images[None, i]), torch.tensor(self.labels[i])

    def __len__(self):
        """Number of samples."""
        return len(self.labels)