forked from bfsujason/bertalign
-
Notifications
You must be signed in to change notification settings - Fork 0
/
aligner.py
114 lines (95 loc) · 4.62 KB
/
aligner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
from bertalign import model
from bertalign.corelib import *
from bertalign.utils import *
class Bertalign:
    """Align the sentences of a source text with those of a target text.

    The constructor cleans both texts, detects their languages, splits them
    into sentences and embeds the sentences with ``model.transform``.
    :meth:`align_sents` runs a two-step alignment and stores the beads in
    ``self.result``; :meth:`print_sents` and :meth:`write_sents_to_file`
    render that result.
    """

    # Closing-quote endings in the source text that are rewritten as
    # "<marker>。" before sentence splitting, each paired with its default marker.
    _QUOTE_ENDINGS = (("。」", "%"), ("!」", "#"), ("?」", "&"))

    # Markers tried, in order, when "<default marker>。" already occurs in the
    # source text (e.g. "增长了50%。"). Reusing such a placeholder would make the
    # restore step rewrite genuine input text.
    _FALLBACK_MARKERS = "@$^~*|"

    # (placeholder, original ending) pairs in protection order. __init__ sets a
    # per-instance value; this class-level default is the mapping produced when
    # no default placeholder occurs in the input.
    _quote_placeholders = (("%。", "。」"), ("#。", "!」"), ("&。", "?」"))

    def __init__(self,
                 src,
                 tgt,
                 max_align=5,
                 top_k=3,
                 win=5,
                 skip=-0.1,
                 margin=True,
                 len_penalty=True,
                 is_split=False,
                 ):
        """Clean, split and embed *src* and *tgt* in preparation for alignment.

        Args:
            src: Source text.
            tgt: Target text.
            max_align: Passed to ``get_alignment_types`` for the second pass;
                ``max_align - 1`` is also passed to ``model.transform``.
            top_k: Passed as ``k`` to ``find_top_k_sents`` in the first pass.
            win: Passed to ``find_second_search_path``.
            skip: Passed to ``second_pass_align``.
            margin: Passed to ``second_pass_align``.
            len_penalty: Passed to ``second_pass_align``.
            is_split: If True, every (cleaned) line is taken as one sentence
                instead of calling ``split_sents``.
        """
        self.max_align = max_align
        self.top_k = top_k
        self.win = win
        self.skip = skip
        self.margin = margin
        self.len_penalty = len_penalty

        # Rewrite quote endings such as "。」" so they end in "。" (presumably so
        # split_sents ends the sentence there); the mapping is kept so output
        # can restore the original endings.
        src, self._quote_placeholders = self._protect_quotes(src)
        src = clean_text(src)
        tgt = clean_text(tgt)
        src_lang = detect_lang(src)
        tgt_lang = detect_lang(tgt)
        if is_split:
            src_sents = src.splitlines()
            tgt_sents = tgt.splitlines()
        else:
            src_sents = split_sents(src, src_lang)
            tgt_sents = split_sents(tgt, tgt_lang)
        src_num = len(src_sents)
        tgt_num = len(tgt_sents)
        src_lang = LANG.ISO[src_lang]
        tgt_lang = LANG.ISO[tgt_lang]
        print("Source language: {}, Number of sentences: {}".format(src_lang, src_num))
        print("Target language: {}, Number of sentences: {}".format(tgt_lang, tgt_num))
        print("Embedding source and target text using {} ...".format(model.model_name))
        src_vecs, src_lens = model.transform(src_sents, max_align - 1)
        tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1)
        # Ratio of summed source to summed target lengths, taken from row 0 of
        # the length arrays (presumably per-sentence character counts).
        char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.src_sents = src_sents
        self.tgt_sents = tgt_sents
        self.src_num = src_num
        self.tgt_num = tgt_num
        self.src_lens = src_lens
        self.tgt_lens = tgt_lens
        self.char_ratio = char_ratio
        self.src_vecs = src_vecs
        self.tgt_vecs = tgt_vecs

    @classmethod
    def _protect_quotes(cls, text):
        """Replace the quote endings in *text* with "<marker>。" placeholders.

        Each ending in ``_QUOTE_ENDINGS`` uses its default marker unless
        "<marker>。" already occurs in the original *text*; then the first
        unused marker from ``_FALLBACK_MARKERS`` whose placeholder does not
        occur is used. If no marker qualifies, that ending is left untouched.

        Returns:
            ``(protected_text, placeholders)`` where *placeholders* is a tuple
            of ``(placeholder, original_ending)`` pairs in protection order.
        """
        original = text
        defaults = {marker for _, marker in cls._QUOTE_ENDINGS}
        used = set()
        placeholders = []
        for ending, default in cls._QUOTE_ENDINGS:
            # Fallbacks never take another ending's default marker.
            candidates = [default] + [m for m in cls._FALLBACK_MARKERS if m not in defaults]
            chosen = next(
                (m for m in candidates if m not in used and m + "。" not in original),
                None,
            )
            if chosen is None:
                continue
            used.add(chosen)
            placeholder = chosen + "。"
            text = text.replace(ending, placeholder)
            placeholders.append((placeholder, ending))
        return text, tuple(placeholders)

    def _restore_src_line(self, line):
        """Undo ``_protect_quotes`` on one rendered source line."""
        # Undo in the reverse order of protection.
        for placeholder, ending in reversed(self._quote_placeholders):
            line = line.replace(placeholder, ending)
        return line

    def align_sents(self):
        """Run the two-step alignment and store the beads in ``self.result``.

        The first step uses only 0-1, 1-0 and 1-1 alignment types; its result
        constrains the search path of the second step, which uses alignment
        types up to ``self.max_align``.
        """
        print("Performing first-step alignment ...")
        D, I = find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k)
        first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1
        first_w, first_path = find_first_search_path(self.src_num, self.tgt_num)
        first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)
        first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types)
        print("Performing second-step alignment ...")
        second_alignment_types = get_alignment_types(self.max_align)
        second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num)
        second_pointers = second_pass_align(self.src_vecs, self.tgt_vecs, self.src_lens, self.tgt_lens,
                                            second_w, second_path, second_alignment_types,
                                            self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty)
        second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types)
        print("Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang))
        self.result = second_alignment

    def _aligned_pairs(self):
        """Yield ``(source_line, target_line)`` for each bead in ``self.result``.

        Source lines have their quote endings restored.

        Raises:
            AttributeError: if :meth:`align_sents` has not been run.
        """
        result = getattr(self, "result", None)
        if result is None:
            raise AttributeError("no alignment result; call align_sents() first")
        for bead in result:
            src_line = self._restore_src_line(self._get_line(bead[0], self.src_sents))
            tgt_line = self._get_line(bead[1], self.tgt_sents)
            yield src_line, tgt_line

    def print_sents(self):
        """Print each aligned pair: source line, target line, then a blank line."""
        for src_line, tgt_line in self._aligned_pairs():
            print(src_line + "\n" + tgt_line + "\n")

    def write_sents_to_file(self, output_file):
        """Write each aligned pair (source line, then target line) to *output_file* as UTF-8."""
        with open(output_file, 'w', encoding='utf-8') as file:
            for src_line, tgt_line in self._aligned_pairs():
                file.write(src_line + "\n" + tgt_line + "\n")

    @staticmethod
    def _get_line(bead, lines):
        """Join ``lines[bead[0]]`` through ``lines[bead[-1]]`` with spaces.

        An empty *bead* yields ``''``. Assumes the indices in *bead* are
        ascending and contiguous — TODO confirm against the back-track output.
        """
        line = ''
        if len(bead) > 0:
            line = ' '.join(lines[bead[0]:bead[-1]+1])
        return line