forked from demelin/nmt_toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
translate.py
119 lines (99 loc) · 5.17 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import logging
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.serialization import default_restore_location
from seq2seq import models, utils
from seq2seq.data.dictionary import Dictionary
from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler
def get_args():
    """ Defines generation-specific hyper-parameters.

    Parses the process command line and returns the resulting ``argparse.Namespace``.

    ``--cuda`` accepts either a bare flag (``--cuda``) or an explicit boolean value
    (``--cuda True``, ``--cuda false``, ``--cuda 0``, ...). Without a converter, argparse would store
    the raw string, and a non-empty string such as ``'False'`` is truthy, which would silently enable
    the GPU code path.
    """
    def str_to_bool(value):
        """ Converts a textual boolean ('true', 'no', '1', ...) into a real bool. """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser('Sequence to Sequence Model')
    # nargs='?' + const=True keeps `--cuda True` working while also allowing a bare `--cuda`
    parser.add_argument('--cuda', nargs='?', const=True, default=False, type=str_to_bool, help='Use a GPU')
    parser.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed')

    # Add data arguments
    parser.add_argument('--data', default='data-bin', help='path to data directory')
    parser.add_argument('--checkpoint-path', default='checkpoints/checkpoint_best.pt', help='path to the model file')
    parser.add_argument('--batch-size', default=None, type=int, help='maximum number of sentences in a batch')
    parser.add_argument('--output', default='model_translations.txt', type=str,
                        help='path to the output file destination')
    parser.add_argument('--max-len', default=25, type=int, help='maximum length of generated sequence')
    return parser.parse_args()
def _greedy_decode(model, src_tokens, src_lengths, eos_idx, unk_idx, max_len):
    """ Greedily decodes one batch and returns the predicted target indices as a 2-D tensor.

    The decoder is re-run on the whole prefix at every step; the prefix always starts with an EOS
    token that serves as the "go" symbol. Whenever the top-scoring token is <UNK>, the second-best
    token is taken instead.

    :param model: object exposing ``encoder(src_tokens, src_lengths)`` and ``decoder(prev_words, encoder_out)``
    :param src_tokens: source batch; its first dimension gives the number of sentences
    :param src_lengths: source lengths, passed through to the encoder unchanged
    :param eos_idx: index of the end-of-sentence token in the target dictionary
    :param unk_idx: index of the unknown-word token in the target dictionary
    :param max_len: number of decoding steps to run
    :return: tensor with one row per source sentence and ``max_len`` columns
    """
    with torch.no_grad():
        # Compute the encoder output
        encoder_out = model.encoder(src_tokens, src_lengths)
        go_slice = torch.ones(src_tokens.shape[0], 1).fill_(eos_idx).type_as(src_tokens)
        prev_words = go_slice
        # Start from an empty prediction so that max_len == 0 yields empty hypotheses instead of None
        next_words = go_slice[:, :0]
        for _ in range(max_len):
            # Compute the decoder output by repeatedly feeding it the decoded sentence prefix
            decoder_out, _ = model.decoder(prev_words, encoder_out)
            # Suppress <UNK>s
            _, next_candidates = torch.topk(decoder_out, 2, dim=-1)
            best_candidates = next_candidates[:, :, 0]
            backoff_candidates = next_candidates[:, :, 1]
            next_words = torch.where(best_candidates == unk_idx, backoff_candidates, best_candidates)
            prev_words = torch.cat([go_slice, next_words], dim=1)
    return next_words


def _truncate_at_eos(sentences, eos_idx):
    """ Cuts every index array right before its first EOS; arrays without an EOS are kept whole. """
    truncated = []
    for sent in sentences:
        eos_positions = np.where(sent == eos_idx)[0]
        # A hypothesis that never produced EOS within the step budget is still a usable (if cut-off)
        # translation, so keep it rather than replacing it with an empty sentence.
        truncated.append(sent[:eos_positions[0]] if len(eos_positions) > 0 else sent)
    return truncated


def main(args):
    """ Main translation function.

    Loads the checkpoint at ``args.checkpoint_path``, merges the arguments stored in it into ``args``,
    loads both dictionaries and the ``test`` split from ``args.data``, greedily decodes every batch
    and, unless ``args.output`` is None, writes one hypothesis per line to ``args.output`` in
    sentence-id order.

    :param args: namespace produced by ``get_args()``
    """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    # NOTE(review): on a name clash the value stored in the checkpoint wins over the command-line
    # value (the checkpoint dict is unpacked last) - confirm this precedence is intended.
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(test_dataset, 9999999,
                                                                         args.batch_size, 1, 0, shuffle=False,
                                                                         seed=args.seed))

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for sample in progress_bar:
        src_tokens, src_lengths = sample['src_tokens'], sample['src_lengths']
        if args.cuda:
            # The model was moved to the GPU above, so its inputs have to follow it.
            # NOTE(review): assumes the encoder accepts GPU-resident lengths - confirm against the model code.
            src_tokens = src_tokens.cuda()
            if torch.is_tensor(src_lengths):
                src_lengths = src_lengths.cuda()

        next_words = _greedy_decode(model, src_tokens, src_lengths,
                                    tgt_dict.eos_idx, tgt_dict.unk_idx, args.max_len)

        # Segment into sentences (.cpu() is required before .numpy() when decoding ran on the GPU)
        decoded_batch = next_words.cpu().numpy()
        output_sentences = list(decoded_batch)
        sample_ids = sample['id'].data
        assert len(output_sentences) == len(sample_ids)

        # Remove everything from the first EOS onwards
        output_sentences = _truncate_at_eos(output_sentences, tgt_dict.eos_idx)

        # Convert arrays of indices into strings of words and save translations
        for sent_id, sent in zip(sample_ids, output_sentences):
            all_hyps[int(sent_id)] = tgt_dict.string(sent)

    # Write to file
    if args.output is not None:
        # Explicit UTF-8 so that non-ASCII translations do not depend on the platform default encoding
        with open(args.output, 'w', encoding='utf-8') as out_file:
            # Sentence ids are looked up as 0..N-1; a missing id raises KeyError
            for sent_id in range(len(all_hyps)):
                out_file.write(all_hyps[sent_id] + '\n')
if __name__ == '__main__':
    # Script entry point: parse the command line, then run translation with the parsed options.
    main(get_args())