-
Notifications
You must be signed in to change notification settings - Fork 1
/
translate.py
105 lines (88 loc) · 4.6 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils.data import segment, pad_and_truncate, unsegment, produce_vocabulary
from utils.model import Encoder, Decoder, Seq2Seq
from utils.translation import translate_one_sentence
from torch.optim.swa_utils import AveragedModel
from tqdm import tqdm
import argparse
def _build_arg_parser():
    """Create the command-line parser for the translation script.

    Returns:
        argparse.ArgumentParser: parser exposing the mode switches
        (``--translate_sentence`` / ``--translate_file``), the file and
        checkpoint locations, and the decoding options that are forwarded
        to ``translate_one_sentence``.
    """
    parser = argparse.ArgumentParser()
    # Mode switches; both may be given in the same run.
    parser.add_argument('--translate_sentence', action='store_true')
    parser.add_argument('--translate_file', action='store_true')
    # Input / output locations.
    parser.add_argument('--src_file', type=str)
    parser.add_argument('--output_file', type=str)
    parser.add_argument('--joint_model', type=str)
    parser.add_argument('--joint_vocab', type=str)
    parser.add_argument('--src_vocab', type=str)
    parser.add_argument('--save_dir', type=str)
    parser.add_argument('--sentence', type=str)
    # Decoding options, passed through to translate_one_sentence unchanged.
    parser.add_argument('--beams', type=int, default=1)
    parser.add_argument('--msl', type=int, default=100)
    parser.add_argument('--desegment', action='store_true')
    parser.add_argument('--max_words', type=int, default=50)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--is_segmented', action='store_true')
    parser.add_argument('--strategy', type=str, default='bfs')
    parser.add_argument('--top_k', type=int, default=50)
    parser.add_argument('--top_p', type=float, default=0.92)
    parser.add_argument('--temperature', type=float, default=0.3)
    parser.add_argument('--use_topk', action='store_true')
    # Model / runtime options.
    parser.add_argument('--pad_token', type=str, default='<pad>')
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--use_swa', action='store_true')
    return parser


def _validate_args(parser, args):
    """Exit with a usage error when a path that is needed later is missing.

    Without this check a missing ``--save_dir`` only fails after the
    vocabulary has been built, and a missing ``--output_file`` only fails
    after every sentence of the source file has been translated.

    Args:
        parser: the parser that produced ``args``; used for ``parser.error``.
        args: the parsed command-line namespace.

    Raises:
        SystemExit: (status 2, via ``parser.error``) when ``--save_dir`` is
            absent, or when ``--translate_file`` is given without
            ``--src_file`` and ``--output_file``.
    """
    if args.save_dir is None:
        parser.error('--save_dir is required to locate the saved checkpoint')
    if args.translate_file:
        missing = [flag
                   for flag, value in (('--src_file', args.src_file),
                                       ('--output_file', args.output_file))
                   if value is None]
        if missing:
            parser.error('--translate_file requires ' + ' and '.join(missing))


def _load_model(args, vocab_sz, word2idx, device):
    """Rebuild the Seq2Seq model from ``args.save_dir`` and load its weights.

    Reads ``settings.bin`` for the constructor arguments, then loads either
    ``swa_model.bin`` (wrapped in ``AveragedModel``, when ``--use_swa`` is
    set) or ``model.bin``.

    Args:
        args: parsed command-line namespace (uses ``save_dir``, ``use_swa``
            and ``pad_token``).
        vocab_sz: vocabulary size passed to both Encoder and Decoder.
        word2idx: token-to-index mapping, used to look up the pad token.
        device: ``torch.device`` the model is moved to and the checkpoint
            tensors are mapped onto.

    Returns:
        The model in eval mode.

    Raises:
        KeyError: if ``args.pad_token`` is not in ``word2idx``.
        OSError: if a checkpoint file cannot be opened.
    """
    import os  # local import: only this helper needs it

    print('Loading saved checkpoint.')
    # SECURITY: torch.load unpickles the file; only point --save_dir at
    # checkpoints from a trusted source.
    # map_location makes a checkpoint saved on GPU loadable on a CPU-only
    # machine or when --no_cuda is given.
    with open(os.path.join(args.save_dir, 'settings.bin'), 'rb') as f:
        # Names mirror the saved settings tuple. They are presumably hidden
        # size, layer count, head count, feed-forward size, dropout and the
        # source/target max sequence lengths -- TODO confirm against the
        # training script. The last two entries are unused here but are still
        # unpacked so a tuple of unexpected length fails loudly.
        hd, nl, nh, pf, dp, smsl, tmsl, tw, _usw, _cri = torch.load(f, map_location=device)

    encoder = Encoder(vocab_sz, hd, nl, nh, pf, dp, smsl, fp16=False)
    decoder = Decoder(vocab_sz, hd, nl, nh, pf, dp, tmsl, fp16=False)
    pad_idx = word2idx[args.pad_token]
    model = Seq2Seq(encoder, decoder, pad_idx, pad_idx, tie_weights=tw).to(device)

    if args.use_swa:
        print("Using averaged model.")
        # The SWA checkpoint is loaded into the AveragedModel wrapper, so the
        # model must be wrapped before load_state_dict.
        model = AveragedModel(model)
        weights_name = 'swa_model.bin'
    else:
        print("Using saved model.")
        weights_name = 'model.bin'

    with open(os.path.join(args.save_dir, weights_name), 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=device))
    return model.eval()


def main():
    """Translate a single sentence and/or a whole file with a saved model.

    Parses the command line, seeds torch, builds the joint vocabulary,
    restores the model from ``--save_dir`` and then:

    * with ``--translate_sentence``: translates ``--sentence`` and prints it;
    * with ``--translate_file``: translates every line of ``--src_file`` and
      writes one translation per line to ``--output_file``.

    Raises:
        SystemExit: on invalid or missing command-line arguments.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Fail fast on missing paths instead of after the expensive work.
    _validate_args(parser, args)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Build vocabulary
    idx2word, word2idx, vocab_set, vocab_sz = produce_vocabulary(args.joint_vocab)

    # Load model
    model = _load_model(args, vocab_sz, word2idx, device)

    # Decoding options shared by both translation modes.
    decode_kwargs = dict(
        beams=args.beams, msl=args.msl, desegment=args.desegment,
        max_words=args.max_words, seed=args.seed, device=device,
        is_segmented=args.is_segmented, strategy=args.strategy,
        top_k=args.top_k, top_p=args.top_p, temperature=args.temperature,
        use_topk=args.use_topk,
    )

    def translate(sentence):
        """Return the translation of one sentence (attention is discarded)."""
        # Inference only: skip autograd bookkeeping. Assumes
        # translate_one_sentence does not need gradients.
        with torch.no_grad():
            out, _attn = translate_one_sentence(
                sentence, model, args.joint_model, args.src_vocab,
                idx2word, word2idx, vocab_set, **decode_kwargs)
        return out

    # Translate
    if args.translate_sentence:
        print('Beginning translation.')
        print(translate(args.sentence))

    if args.translate_file:
        print('Loading source file')
        # Explicit UTF-8 so results do not depend on the platform locale.
        with open(args.src_file, 'r', encoding='utf-8') as f:
            src_sentences = [line.strip() for line in f]

        print('Producing translations')
        translations = []
        for sentence in tqdm(src_sentences):
            out = translate(sentence)
            # Without --desegment the output is joined with spaces here, so it
            # is presumably a sequence of tokens in that case.
            if not args.desegment:
                out = ' '.join(out)
            translations.append(out)

        print('Writing to file')
        with open(args.output_file, 'w', encoding='utf-8') as f:
            f.writelines(line + '\n' for line in translations)
        # NOTE(review): assumed to belong to the file-translation branch.
        print('Done!')
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()