forked from jpuigcerver/PyLaia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pylaia-htr-decode-ctc
executable file
·165 lines (157 loc) · 5.61 KB
/
pylaia-htr-decode-ctc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python
from __future__ import absolute_import
import argparse
import os
import torch
import laia.common.logging as log
from laia.common.arguments import add_argument, args, add_defaults
from laia.common.arguments_types import str2bool
from laia.common.loader import ModelLoader, CheckpointLoader
from laia.data import ImageDataLoader, ImageFromListDataset
from laia.decoders import CTCGreedyDecoder
from laia.engine.feeders import ImageFeeder, ItemFeeder
from laia.experiments import Experiment
from laia.utils import SymbolsTable, ImageToTensor
def scale_segmentation(frames, img_size):
    """Rescale segmentation boundaries so that the last one maps to the last column.

    Args:
        frames: boundary positions produced by the decoder for one sample; the
            last entry is used as the scaling reference.
        img_size: size entry of the sample; index 1 is used as the horizontal
            extent (presumably ``(height, width)`` -- confirm against the feeder).

    Returns:
        A list with one integer per entry of ``frames``; empty if ``frames``
        is empty.
    """
    if not frames:
        # Nothing was segmented; avoid indexing an empty list.
        return []
    last_frame = frames[-1]
    last_column = img_size[1] - 1
    return [int(frame * last_column / last_frame) for frame in frames]


def char_boxes(symbols, bounds, height):
    """Attach a box to every decoded symbol.

    Args:
        symbols: decoded symbols of one sample, already converted to strings.
        bounds: rescaled boundaries; symbol ``i`` spans ``bounds[i]`` to
            ``bounds[i + 1]``, so ``len(symbols) + 1`` entries are required.
        height: first component of the sample size; ``height - 1`` is stored
            as the last box coordinate.

    Returns:
        A list of ``[symbol, bounds[i], 0, bounds[i + 1], height - 1]`` lists.

    Raises:
        IndexError: if ``bounds`` has fewer than ``len(symbols) + 1`` entries.
    """
    return [
        [sym, bounds[i], 0, bounds[i + 1], height - 1]
        for i, sym in enumerate(symbols)
    ]


def group_words(chars, space_sym="<space>"):
    """Merge consecutive character boxes into word boxes.

    Args:
        chars: list of 5-element character boxes as built by ``char_boxes``.
        space_sym: symbol that separates words; it never appears in the output.

    Returns:
        A list of ``[word, start_1, start_2, end_1, end_2]`` lists. The first
        word starts at ``(0, 0)``; every later word starts where the preceding
        separator ends. Runs of separators produce no empty words.
    """
    words = []
    letters = []
    start = (0, 0)
    for box in chars:
        if box[0] != space_sym:
            letters.append(box[0])
            continue
        if letters:
            # The word ends where the separator begins.
            words.append(["".join(letters), start[0], start[1], box[1], box[4]])
        letters = []
        start = (box[3], box[2])
    if letters:
        # A non-empty buffer means the final box was not a separator.
        tail = chars[-1]
        words.append(["".join(letters), start[0], start[1], tail[3], tail[4]])
    return words


def main():
    """Decode a list of images with a trained model and a greedy CTC decoder.

    Parses the command line, loads the model and the selected checkpoint from
    ``train_path``, runs every image through the model and prints one line per
    image, optionally prefixed by the image id and optionally carrying the
    character or word segmentation.
    """
    add_defaults("batch_size", "gpu", "train_path", logging_level="WARNING")
    add_argument(
        "syms",
        type=argparse.FileType("r"),
        help="Symbols table mapping from strings to integers",
    )
    add_argument(
        "img_dirs", type=str, nargs="+", help="Directory containing word images"
    )
    add_argument(
        "img_list",
        type=argparse.FileType("r"),
        help="File or list containing images to decode",
    )
    add_argument(
        "--model_filename", type=str, default="model", help="File name of the model"
    )
    add_argument(
        "--checkpoint",
        type=str,
        default="experiment.ckpt.lowest-valid-cer*",
        help="Name of the model checkpoint to use, can be a glob pattern",
    )
    add_argument(
        "--source",
        type=str,
        default="experiment",
        choices=["experiment", "model"],
        help="Type of class which generated the checkpoint",
    )
    add_argument(
        "--print_img_ids",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Print output with the associated image id",
    )
    add_argument(
        "--print_char_segm",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Print output with the corresponding character segmentation",
    )
    add_argument(
        "--print_word_segm",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Print output with the corresponding word segmentation",
    )
    add_argument(
        "--separator",
        type=str,
        default=" ",
        help="Use this string as the separator between the ids and the output",
    )
    add_argument("--join_str", type=str, help="Join the output using this")
    add_argument(
        "--use_letters", action="store_true", help="Print the output with letters"
    )
    add_argument(
        "--space",
        type=str,
        default="",
        help="Replace the '<space>' symbol with this string. Used with --use_letters",
    )
    opts = args()

    # Word boxes are built from character boxes, so "--print_word_segm"
    # requires "--print_char_segm" to be enabled as well.
    if opts.print_word_segm:
        opts.print_char_segm = True

    syms = SymbolsTable(opts.syms)
    # A non-zero --gpu value N selects "cuda:N-1"; a falsy value selects the CPU.
    device = torch.device("cuda:{}".format(opts.gpu - 1) if opts.gpu else "cpu")

    model = ModelLoader(
        opts.train_path, filename=opts.model_filename, device=device
    ).load()
    if model is None:
        log.error("Could not find the model")
        # SystemExit works even when the interactive `exit` helper is absent.
        raise SystemExit(1)
    state = CheckpointLoader(device=device).load_by(
        os.path.join(opts.train_path, opts.checkpoint)
    )
    model.load_state_dict(
        state if opts.source == "model" else Experiment.get_model_state_dict(state)
    )
    model = model.to(device)
    model.eval()

    dataset = ImageFromListDataset(
        opts.img_list, img_dirs=opts.img_dirs, img_transform=ImageToTensor()
    )
    dataset_loader = ImageDataLoader(
        dataset=dataset, image_channels=1, batch_size=opts.batch_size, num_workers=8
    )
    batch_input_fn = ImageFeeder(device=device, parent_feeder=ItemFeeder("img"))
    decoder = CTCGreedyDecoder()

    # Inference only: do not record autograd state for the forward passes.
    with torch.no_grad():
        for batch in dataset_loader:
            batch_input = batch_input_fn(batch)
            batch_output = model(batch_input)
            batch_decode = decoder(batch_output, opts.print_char_segm)

            if opts.print_char_segm:
                img_sizes = batch_input.sizes.tolist()
                batch_segm = [
                    scale_segmentation(frames, size)
                    for frames, size in zip(decoder.seg, img_sizes)
                ]

            for idx, (img_id, out) in enumerate(zip(batch["id"], batch_decode)):
                if opts.print_char_segm:
                    out = char_boxes(
                        [str(syms[val]) for val in out],
                        batch_segm[idx],
                        img_sizes[idx][0],
                    )
                    if opts.print_word_segm:
                        out = group_words(out)
                else:
                    # NOTE(review): nesting reconstructed from a flattened
                    # source; confirm that --space and --join_str are meant to
                    # apply to plain (non-segmented) output only.
                    if opts.use_letters:
                        out = [str(syms[val]) for val in out]
                    if opts.space:
                        out = [
                            opts.space if sym == "<space>" else sym for sym in out
                        ]
                    if opts.join_str is not None:
                        out = opts.join_str.join(str(x) for x in out)
                print(
                    "{}{}{}".format(img_id, opts.separator, out)
                    if opts.print_img_ids
                    else out
                )


if __name__ == "__main__":
    main()