import argparse
import torch
from collections import OrderedDict
from datetime import datetime
from lark import Lark, Token, UnexpectedToken
import util.io as io
from model.parser import LALR, filter_unary
from util.log import Logger
from util.nlp import NLP, Vocab, parse_grammar
def validate_args(args):
"""
    Verifies preprocess.py script arguments.
:param args: Script arguments.
:returns: 'True' if arguments are valid,
else 'False'.
"""
valid = True
def validate_split_args(src_split, tgt_split):
# Verify both source and target data paths
# are set for this split.
if bool(src_split[1]) ^ bool(tgt_split[1]):
args_ids = ((tgt_split[0], src_split[0])
if src_split[1] is None else
(src_split[0], tgt_split[0]))
logger['log'].log(
f'[ERR {datetime.now()}] ERROR: Invalid arguments, '
f'\'{args_ids[0]}\' requires \'{args_ids[1]}\''
)
return False
else:
return True
    src_train = ('src_train', args.src_train)
    tgt_train = ('tgt_train', args.tgt_train)
    valid = validate_split_args(src_train, tgt_train) and valid
    src_dev = ('src_dev', args.src_dev)
    tgt_dev = ('tgt_dev', args.tgt_dev)
    valid = validate_split_args(src_dev, tgt_dev) and valid
    src_test = ('src_test', args.src_test)
    tgt_test = ('tgt_test', args.tgt_test)
    valid = validate_split_args(src_test, tgt_test) and valid
    return valid
def preprocess(args):
"""
The main preprocessing function. Generates the
encoder and decoder vocabularies and dataset
objects. Stores the following files under the
path and name passed with the 'save_data' argument:
- <save_data>.<split>.pt
- <save_data>.lang.pt
These files serve as input to the 'train.py' script.
:param args: Script arguments.
"""
# Preprocess grammar and build terminal operators.
grammar = io.grammar(args.grammar)
filtered, operators = parse_grammar(grammar)
datasets = io.data(
src_train=args.src_train,
tgt_train=args.tgt_train,
src_dev=args.src_dev,
tgt_dev=args.tgt_dev,
src_test=args.src_test,
tgt_test=args.tgt_test,
)
try:
# Generate the parser.
lark = Lark(
filtered,
keep_all_tokens=True,
parser='lalr',
start=args.start
)
except Exception as e:
logger['log'].log(
f'[ERR {datetime.now()}] ERROR: '
f'{e.args[0]}. (Wrong start rule argument?)'
)
return
if args.check:
logger['log'].log(
f'[INFO {datetime.now()}] validating datasets'
)
nlp = NLP(lark, operators)
validate_datasets(nlp, datasets)
else:
logger['log'].log(
f'[INFO {datetime.now()}] building vocabularies'
)
nlp = NLP(lark, operators)
vocab, vocab_dicts = __build_vocab(nlp, datasets)
nlp.vocab = vocab
datasets = __preprocess_datasets(nlp, vocab, datasets)
__save_data(grammar, vocab_dicts, datasets)
def __build_vocab(nlp, datasets):
"""
Generates the encoder vocabulary (natural language tokens),
decoder vocabulary (programming language tokens) and stack
vocabulary (terminal and non-terminal symbols, tokens) by
parsing each source and target example in each split.
:param nlp: nl processing and parsing utils.
:param datasets: all dataset splits with source and
target samples.
"""
inp_i2w, inp_w2i = __input_vocab(nlp, datasets)
out_i2w, out_w2i = __output_vocab(nlp, datasets)
vocab_dicts = {
'src': {
'i2w': inp_i2w,
'w2i': inp_w2i,
},
'tgt': {
'i2w': out_i2w,
'w2i': out_w2i,
}
}
src_vocab = Vocab(vocab_dicts['src'])
tgt_vocab = Vocab(vocab_dicts['tgt'])
vocab = {
'src': src_vocab,
'tgt': tgt_vocab
}
nlp.collect_tokens(vocab)
stack_i2w, stack_w2i = __stack_vocab(nlp)
vocab_dicts.update({'stack': {
'i2w': stack_i2w,
'w2i': stack_w2i
}})
stack_vocab = Vocab({
'i2w': stack_i2w,
'w2i': stack_w2i
})
op_i2w, op_w2i = __operator_vocab(nlp, tgt_vocab)
vocab_dicts.update({'operator': {
'i2w': op_i2w,
'w2i': op_w2i
}})
op_vocab = Vocab({
'i2w': op_i2w,
'w2i': op_w2i
})
vocab.update({'stack': stack_vocab})
vocab.update({'operator': op_vocab})
return vocab, vocab_dicts
def __preprocess_datasets(nlp, vocab, datasets):
"""
Extracts a set of fields from each sample pair in each
dataset split (see 'build_fields').
:param nlp: nl processing and parsing utils.
:param vocab: encoder, decoder and stack vocabularies.
:param datasets: all dataset splits with source and
target samples.
:returns: preprocessed datasets.
"""
result = {}
dataset_count = 0
for dataset_name in datasets:
dataset_count += 1
result[dataset_name] = []
dataset = datasets[dataset_name]
src = dataset['src']
tgt = dataset['tgt']
samples = zip(src, tgt)
data_len = len(tgt)
count = 0
now = datetime.now()
logger['line'].update(
f'[INFO {now}] {count:<6}/{data_len:>6} '
f'preprocessing {dataset_name}'
)
for sample in samples:
fields = __build_fields(nlp, sample, vocab)
count += 1
logger['line'].update(
f'[INFO {now}] {count:<6}/{data_len:>6} '
f'preprocessing {dataset_name}'
)
if fields is None:
# Fields were not parsable, skip sample.
continue
result[dataset_name].append(fields)
# TODO: Hack, fix logger.
if dataset_count == len(datasets):
logger['line'].close()
else:
# newline
logger['log'].log('')
return result
def validate_datasets(nlp, datasets):
"""
Checks whether there is an equal number of
source and target samples in a dataset split.
Each target sample in each split is parsed to
verify the sample is syntactically correct.
:param nlp: nl processing and parsing utils.
:param datasets: all dataset splits with source and
target samples.
"""
for dataset_name in datasets:
dataset = datasets[dataset_name]
logger['log'].log(
f'[INFO {datetime.now()}] validating dataset '
f'\'{dataset_name}\''
)
sources = dataset['src']
targets = dataset['tgt']
success = True
# Check if equal number of source samples
# and target samples in dataset.
if len(sources) != len(targets):
success = False
logger['log'].log(
f'[WARN {datetime.now()}] sample count mismatch, '
f'{len(sources)} source samples and {len(targets)} '
'target samples'
)
# Parse each target sample and verify
# it is a syntactically valid sample.
for i in range(len(targets)):
try:
nlp.lark.parse(targets[i])
except Exception:
success = False
logger['log'].log(
f'[WARN {datetime.now()}] parsing error '
f'while parsing line {i+1}'
)
if success:
logger['log'].log(
f'[INFO {datetime.now()}] \'{dataset_name}\' '
                'data OK'
)
def __save_data(grammar, vocab, datasets):
"""
    Saves the following files under the path and name
    specified by the 'save_data' argument:
- <save_data>.<split>.pt
- <save_data>.lang.pt
:param grammar: the raw grammar file.
:param vocab: encoder, decoder, operator and stack
vocabularies.
    :param datasets: all preprocessed dataset splits.
"""
lang = {
'grammar': grammar,
'start': args.start,
'vocab': vocab
}
lang_path = f'{args.save_data}.lang.pt'
torch.save(lang, lang_path)
logger['log'].log(
f'[INFO {datetime.now()}] vocab stored in '
f'\'{lang_path}\''
)
for k, v in datasets.items():
data_path = f'{args.save_data}.{k}.pt'
torch.save(v, data_path)
logger['log'].log(
f'[INFO {datetime.now()}] {k} dataset '
f'stored in \'{data_path}\''
)
def __input_vocab(nlp, datasets):
"""
Extracts the set of natural language tokens from
each source sample in each dataset split and builds
the encoder vocabulary.
:param nlp: nl processing and parsing utils.
:param datasets: all dataset splits with source and
target samples.
:returns: 'i2w' and 'w2i' dictionaries taking
indices to tokens and vice versa.
"""
vocab = OrderedDict()
# Meta-Symbols.
marks = nlp.mark.inp.values()
marks = {m: None for m in marks}
vocab.update(marks)
for dataset_name in datasets:
src = datasets[dataset_name]['src']
for sample in src:
tokens = nlp.normalize(sample)
tokens = {t: None for t in tokens}
vocab.update(tokens)
vocab = [t for t in vocab]
i2w = {i: t for i, t in enumerate(vocab)}
w2i = {t: i for i, t in enumerate(vocab)}
return i2w, w2i
def __output_vocab(nlp, datasets):
"""
Extracts the set of source code tokens from each target
sample in each dataset split and builds the decoder
vocabulary.
:param nlp: nl processing and parsing utils.
:param datasets: all dataset splits with source and
target samples.
:returns: 'i2w' and 'w2i' dictionaries taking
indices to tokens and vice versa.
"""
vocab = OrderedDict()
# Meta-Symbols.
marks = nlp.mark.out.values()
marks = {repr(m): None for m in marks}
vocab.update(marks)
def op_repr(op):
token = Token(op.name, f'<{op.type}>')
return repr(token)
for terminal in nlp.TERMINALS.values():
# Add all recorded tokens for each terminal.
tokens = terminal.tokens
tokens = {repr(t): None for t in tokens}
vocab.update(tokens)
for dataset_name in datasets:
tgt = datasets[dataset_name]['tgt']
for sample in tgt:
# Parse each target sample and
# update vocabulary dict with tokens.
try:
tokens = nlp.tokenize(sample)
except UnexpectedToken:
# Skip sample if not parsable.
continue
tokens = {
(repr(t) if t.type not in nlp.OPERATOR
else op_repr(nlp.OPERATOR[t.type])): None
for t in tokens
}
vocab.update(tokens)
vocab = [t for t in vocab]
i2w = {i: t for i, t in enumerate(vocab)}
w2i = {t: i for i, t in enumerate(vocab)}
return i2w, w2i
def __stack_vocab(nlp):
"""
Collects the set of symbols that can occur on
the value stack of the parser. Builds the vocabulary
for the stack encoder.
:param nlp: nl processing and parsing utils.
:returns: 'i2w' and 'w2i' dictionaries taking
indices to tokens and vice versa.
"""
vocab = OrderedDict()
marks = nlp.mark.out.values()
marks = {repr(m): None for m in marks}
vocab.update(marks)
nonterminals = nlp.NONTERMINALS.values()
symbols = {repr(nt.nt): None for nt in nonterminals}
vocab.update(symbols)
terminals = nlp.TERMINALS.values()
symbols = {repr(t): None for t in terminals}
vocab.update(symbols)
tokens = nlp.TOKENS.values()
tokens = {repr(t): None for t in tokens}
vocab.update(tokens)
i2w = {i: t for i, t in enumerate(vocab)}
w2i = {t: i for i, t in enumerate(vocab)}
return i2w, w2i
def __operator_vocab(nlp, tgt_vocab):
"""
    Collects operators from the target vocabulary and stores them
    in a small, distinct vocabulary.
    :param nlp: nl processing and parsing utils.
    :param tgt_vocab: the decoder (target) vocabulary.
:returns: 'i2w' and 'w2i' dictionaries taking
indices to tokens and vice versa.
"""
i2w = {}
w2i = {}
for op_name in nlp.OPERATOR:
op = nlp.OPERATOR[op_name]
t = repr(op.tokens[0])
i = tgt_vocab.w2i(t)
i2w.update({i: t})
w2i.update({t: i})
return i2w, w2i
def __build_fields(nlp, sample, vocab):
"""
    Preprocesses a single sample pair in various ways and constructs
    a number of fields used during training.
    :param nlp: nl processing and parsing utils.
    :param sample: current source and target pair to be
                   preprocessed.
    :param vocab: encoder, decoder, stack and operator vocabularies.
    :returns: dictionary containing the generated sample
              fields.
"""
src = sample[0]
tgt = sample[1]
src_tokens = nlp.normalize(src, delimiters=True)
try:
tgt_tokens = nlp.tokenize(tgt)
except UnexpectedToken:
# Abort if target sample is not parsable.
return None
tgt_tokens = filter_unary(nlp, tgt_tokens)
# Create a mini sample vocab for copying.
sample_i2w = {i: t for i, t in enumerate(src_tokens)}
sample_w2i = {t: i for i, t in enumerate(src_tokens)}
sample_vocab = {'i2w': sample_i2w, 'w2i': sample_w2i}
# Create alignment vector specifying which target
# tokens should be copied from the input sequence
# and replace operator tokens in target sequence.
alignment = nlp.alignment(src_tokens, tgt_tokens, sample_vocab)
# Replace target tokens of operator type with
# respective placeholder.
for i in range(len(tgt_tokens)):
if tgt_tokens[i].type in nlp.OPERATOR:
op = nlp.OPERATOR[tgt_tokens[i].type]
tgt_tokens[i] = Token(op.name, f'<{op.type}>')
# Nonterminals in the value stack for each decoding step.
parser = LALR(nlp)
stack_seq = nlp.stack_sequence(parser.value_stack, filter_token=True)
stack_i = nlp.stack2indices(stack_seq, delimiters=True)
value_stacks = [stack_i]
for token in tgt_tokens[1:]:
parser.parse(token)
stack_seq = nlp.stack_sequence(parser.value_stack, filter_token=True)
stack_i = nlp.stack2indices(stack_seq, delimiters=True)
value_stacks.append(stack_i)
    # Zero-pad each value stack on the right to the maximum stack length.
max_stack_len = max(len(vs) for vs in value_stacks)
stack_lens = [len(vs) for vs in value_stacks]
for i in range(len(value_stacks)):
out = [0] * max_stack_len
out[:len(value_stacks[i])] = value_stacks[i]
value_stacks[i] = out
src_i = [vocab['src'].w2i(t) for t in src_tokens]
tgt_i = [vocab['tgt'].w2i(repr(t)) for t in tgt_tokens]
sample_fields = {
'src': src,
'tgt': tgt,
'src_i': src_i,
'tgt_i': tgt_i,
'sample_vocab': sample_vocab,
'alignment': alignment,
'value_stacks': value_stacks,
'stack_lens': stack_lens
}
return sample_fields
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--grammar', type=str, required=True,
help='Lark grammar for parsing target samples.')
parser.add_argument('--start', type=str, required=True,
help='The start rule of the grammar.')
parser.add_argument('--src_train', type=str, default=None,
help='Training dataset source samples.')
parser.add_argument('--tgt_train', type=str, default=None,
help='Training dataset target samples.')
parser.add_argument('--src_test', type=str, default=None,
help='Test dataset source samples.')
parser.add_argument('--tgt_test', type=str, default=None,
help='Test dataset target samples.')
parser.add_argument('--src_dev', type=str, default=None,
help='Development dataset source samples.')
parser.add_argument('--tgt_dev', type=str, default=None,
help='Development dataset target samples.')
parser.add_argument('--save_data', type=str, required=True,
help='Path and name for saving preprocessed data.')
parser.add_argument('--check', action='store_true', default=False,
help='Check dataset and parse target samples.')
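    # Illustrative invocation (a sketch; the file names below are
    # placeholders, not files shipped with this repository):
    #   python preprocess.py --grammar grammar.lark --start start \
    #       --src_train train.src --tgt_train train.tgt \
    #       --src_dev dev.src --tgt_dev dev.tgt \
    #       --save_data data/prep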
args = parser.parse_args()
log = Logger()
line = log.add_text('')
log.start()
logger = {
'log': log,
'line': line
}
if validate_args(args):
preprocess(args)