text_deep_ocr_bucketing_resume.py
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import logging

import numpy as np
import mxnet as mx

from text_lstm import bi_lstm_unroll
from text_bucketing_iter import TextIter

BATCH_SIZE = 20
def ctc_label(p):
    """Collapse a raw CTC best path: merge repeats, then drop blanks (0)."""
    ret = []
    p1 = [0] + p
    for i in range(len(p)):
        c1 = p1[i]
        c2 = p1[i + 1]
        if c2 == 0 or c2 == c1:
            continue
        ret.append(c2)
    return ret
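
# Illustrative check of the collapse rule above (assuming class index 0 is
# the CTC blank, as the rest of this script does): the raw per-frame argmax
# path [0, 3, 3, 0, 5, 5, 5, 0] collapses to [3, 5] -- repeated symbols merge
# first, then blanks are dropped.
#   >>> ctc_label([0, 3, 3, 0, 5, 5, 5, 0])
#   [3, 5]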
def remove_blank(l):
    """Truncate a padded label sequence at the first blank (0)."""
    ret = []
    for i in range(len(l)):
        if l[i] == 0:
            break
        ret.append(l[i])
    return ret
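
# Illustrative: label rows are presumably padded with blanks (0) to a fixed
# width by TextIter, so a row such as [7, 2, 9, 0, 0] truncates to [7, 2, 9].
#   >>> remove_blank([7, 2, 9, 0, 0])
#   [7, 2, 9]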
def Accuracy(label, pred):
    """Exact-match accuracy of CTC-decoded predictions over one batch."""
    hit = 0.
    total = 0.
    for i in range(BATCH_SIZE):
        l = remove_blank(label[i])
        p = []
        # pred is time-major: seq_len blocks of BATCH_SIZE softmax rows,
        # so sample i occupies rows k * BATCH_SIZE + i for each timestep k.
        for k in range(len(pred) // BATCH_SIZE):
            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
        p = ctc_label(p)
        if len(p) == len(l):
            match = True
            for k in range(len(p)):
                if p[k] != int(l[k]):
                    match = False
                    break
            if match:
                hit += 1.0
        total += 1.0
    return hit / total
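
# This function is wired into training below via mx.metric.np(Accuracy),
# which wraps a plain (label, pred) -> float function as an MXNet
# CustomMetric. The time-major layout assumed here must match the unrolled
# LSTM's concatenated softmax output, otherwise the decoded sequences are
# scrambled across the batch.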
if __name__ == '__main__':
    # Network and task hyperparameters.
    num_hidden = 256
    num_lstm_layer = 2
    num_label = 10
    contexts = [mx.gpu(0)]
    def sym_gen(seq_len):
        sym = bi_lstm_unroll(seq_len,
                             num_hidden=num_hidden,
                             num_label=num_label,
                             dropout=0.75)
        data_names = ('data', 'l0_init_c', 'l1_init_c', 'l0_init_h', 'l1_init_h')
        label_names = ('label',)
        return sym, data_names, label_names
    # Zero-initialised LSTM states are supplied as extra inputs by the data
    # iterator, one cell/hidden pair per layer.
    init_c = [('l%d_init_c' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h
    path = 'crop_icdar2013_train.lst'
    path_test = 'crop_icdar2013_val.lst'
    data_root = '/cache/icdar2013_word'
    test_root = '/cache/icdar2013_word'
    buckets = [4 * i for i in range(1, num_label + 1)]
    data_train = TextIter(path, data_root, BATCH_SIZE, init_states, num_label, buckets=buckets)
    data_val = TextIter(path_test, test_root, BATCH_SIZE, init_states, num_label, buckets=buckets)
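
    # Bucketing note: the buckets are unrolled sequence lengths 4, 8, ..., 40.
    # Presumably TextIter assigns each cropped word image to the smallest
    # bucket that fits it; default_bucket_key below is 40 so the shared
    # parameters are allocated for the longest unrolled network.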
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    print('begin fit')

    def norm_stat(d):
        return mx.nd.norm(d) / np.sqrt(d.size)
    # Optional gradient/weight monitor; it only takes effect if passed to
    # fit(monitor=mon), which this script does not do.
    mon = mx.mon.Monitor(100, norm_stat)
    # Resume from an earlier checkpoint. The loaded symbol is discarded:
    # BucketingModule rebuilds symbols per bucket via sym_gen, and only the
    # parameter arrays are carried over.
    prefix = 'model/icdar2013'
    n_epoch_load = 100
    _, arg_params, aux_params = mx.model.load_checkpoint(prefix, n_epoch_load)

    model = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key=40,
        context=contexts)
    model.fit(
        train_data=data_train,
        eval_data=data_val,
        eval_metric=mx.metric.np(Accuracy),
        num_epoch=200,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.005,
                          'momentum': 0.9,
                          'wd': 0},
        arg_params=arg_params,
        aux_params=aux_params,
        epoch_end_callback=mx.callback.do_checkpoint(prefix),
        batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),
        # begin_epoch matches n_epoch_load so epoch numbering (and the
        # checkpoints written by do_checkpoint) continue from epoch 100.
        begin_epoch=n_epoch_load,
    )
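
# Usage sketch (assumptions: text_lstm.py, text_bucketing_iter.py, the ICDAR
# 2013 .lst files and cropped word images, and a model/icdar2013-0100.params
# checkpoint are all in place, and a GPU is available):
#   python text_deep_ocr_bucketing_resume.py
# Training resumes at epoch 100 and runs through epoch 200, saving a
# checkpoint under the model/icdar2013 prefix after every epoch.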