forked from cristinae/fairseq
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add instructions to reproduce Understanding Back-translation at Scale (…
…facebookresearch#1021) Summary: Pull Request resolved: fairinternal/fairseq-py#1021 Differential Revision: D20077161 Pulled By: myleott fbshipit-source-id: da7f38dbac9551f29a88be3f421f8e38d9a81133
- Loading branch information
1 parent
2728f9b
commit b152183
Showing
8 changed files
with
680 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import fileinput | ||
import hashlib | ||
from multiprocessing import Pool | ||
import sys | ||
|
||
|
||
def get_hashes_and_lines(raw_line):
    """Return the MD5 hex digest of a raw line together with the line itself.

    Args:
        raw_line: the line as a bytes-like object (``hashlib.md5`` does not
            accept ``str``).

    Returns:
        A ``(digest, raw_line)`` tuple, so a caller can deduplicate on the
        digest while still having the original bytes available to write out.
    """
    # Named ``digest`` rather than ``hash`` to avoid shadowing the builtin.
    digest = hashlib.md5(raw_line).hexdigest()
    return digest, raw_line
|
||
|
||
def main():
    """Write each distinct input line to stdout exactly once.

    Input lines are read in binary mode from the files given on the command
    line (``fileinput`` falls back to stdin when none are given) and hashed
    in a pool of ``--workers`` processes. A line is emitted only the first
    time its hash is seen. Results are consumed with ``imap_unordered``, so
    the output order is not guaranteed to match the input order. Progress is
    reported on stderr: the running count every 1,000,000 results and a dot
    every 100,000.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers', type=int, default=10)
    parser.add_argument('files', nargs='*', help='input files')
    args = parser.parse_args()

    seen = set()
    # The pool is used as a context manager so the worker processes are
    # always torn down, including when an exception interrupts the loop.
    # All results are consumed inside the block, before the pool exits.
    with fileinput.input(args.files, mode='rb') as h, \
            Pool(args.workers) as pool:
        results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
        for i, (digest, raw_line) in enumerate(results):
            if digest not in seen:
                seen.add(digest)
                # Write the original bytes untouched (no decode/encode).
                sys.stdout.buffer.write(raw_line)
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
    # Terminate the progress line.
    print(file=sys.stderr, flush=True)
|
||
|
||
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import fileinput | ||
|
||
from tqdm import tqdm | ||
|
||
|
||
def main():
    """Extract (source, target) back-translation pairs from generation output.

    Reads lines from the given files via ``fileinput``. For every ``S-*``
    line the second tab-separated field is remembered as the target; for the
    first ``H-*`` line that follows, the third tab-separated field is taken
    as the source. Pairs that pass the optional length/ratio filters are
    written line-aligned to ``<output>.<srclang>`` and ``<output>.<tgtlang>``.
    Further ``H-*`` lines for the same ``S-*`` line are ignored.
    """
    parser = argparse.ArgumentParser(description=(
        'Extract back-translations from the stdout of fairseq-generate. '
        'If there are multiply hypotheses for a source, we only keep the first one. '
    ))
    parser.add_argument('--output', required=True, help='output prefix')
    parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
    parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
    parser.add_argument('--minlen', type=int, help='min length filter')
    parser.add_argument('--maxlen', type=int, help='max length filter')
    parser.add_argument('--ratio', type=float, help='ratio filter')
    parser.add_argument('files', nargs='*', help='input files')
    args = parser.parse_args()

    def validate(src, tgt):
        """Return True if the pair passes every filter that was requested."""
        # Lengths are counted in space-separated tokens; '' counts as 0.
        srclen = len(src.split(' ')) if src != '' else 0
        tgtlen = len(tgt.split(' ')) if tgt != '' else 0
        if args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen):
            return False
        if args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen):
            return False
        if args.ratio is not None:
            shorter = min(srclen, tgtlen)
            # An empty side has no finite length ratio; reject the pair
            # instead of raising ZeroDivisionError.
            if shorter == 0 or max(srclen, tgtlen) / float(shorter) > args.ratio:
                return False
        return True

    def safe_index(toks, index, default):
        """Return ``toks[index]``, or ``default`` if the index is out of range."""
        try:
            return toks[index]
        except IndexError:
            return default

    # No target is pending until the first S-* line has been seen; without
    # this, an H-* line appearing first would raise UnboundLocalError.
    tgt = None
    with open(args.output + '.' + args.srclang, 'w') as src_h, \
            open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
        for line in tqdm(fileinput.input(args.files)):
            if line.startswith('S-'):
                # rstrip() can remove a trailing empty field, hence safe_index.
                tgt = safe_index(line.rstrip().split('\t'), 1, '')
            elif line.startswith('H-'):
                if tgt is not None:
                    src = safe_index(line.rstrip().split('\t'), 2, '')
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    # Clear the pending target so only the first hypothesis
                    # per S-* line is kept.
                    tgt = None
|
||
|
||
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
Oops, something went wrong.