From b152183d745d025514a4d4d03a3671cd85acbdcc Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 25 Feb 2020 06:59:53 -0800 Subject: [PATCH] Add instructions to reproduce Understanding Back-translation at Scale (#1021) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1021 Differential Revision: D20077161 Pulled By: myleott fbshipit-source-id: da7f38dbac9551f29a88be3f421f8e38d9a81133 --- examples/backtranslation/README.md | 249 ++++++++++++++++++ examples/backtranslation/deduplicate_lines.py | 41 +++ examples/backtranslation/extract_bt_data.py | 59 +++++ .../backtranslation/prepare-de-monolingual.sh | 98 +++++++ .../backtranslation/prepare-wmt18en2de.sh | 135 ++++++++++ examples/backtranslation/sacrebleu.sh | 37 +++ examples/backtranslation/tokenized_bleu.sh | 46 ++++ fairseq/tasks/translation.py | 23 +- 8 files changed, 680 insertions(+), 8 deletions(-) create mode 100644 examples/backtranslation/deduplicate_lines.py create mode 100644 examples/backtranslation/extract_bt_data.py create mode 100644 examples/backtranslation/prepare-de-monolingual.sh create mode 100644 examples/backtranslation/prepare-wmt18en2de.sh create mode 100644 examples/backtranslation/sacrebleu.sh create mode 100644 examples/backtranslation/tokenized_bleu.sh diff --git a/examples/backtranslation/README.md b/examples/backtranslation/README.md index bc32675de7..73675f1125 100644 --- a/examples/backtranslation/README.md +++ b/examples/backtranslation/README.md @@ -37,6 +37,255 @@ en2de_ensemble.translate('Hello world!') # 'Hallo Welt!' ``` +## Training your own model (WMT'18 English-German) + +The following instructions can be adapted to reproduce the models from the paper. + + +#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model + +First download and preprocess the data: +```bash +# Download and prepare the data +cd examples/backtranslation/ +bash prepare-wmt18en2de.sh +cd ../.. + +# Binarize the data +TEXT=examples/backtranslation/wmt18_en_de +fairseq-preprocess \ + --joined-dictionary \ + --source-lang en --target-lang de \ + --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ + --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \ + --workers 20 + +# Copy the BPE code into the data-bin directory for future use +cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code +``` + +(Optionally) Train a baseline model (English-German) using just the parallel data: +```bash +CHECKPOINT_DIR=checkpoints_en_de_parallel +fairseq-train --fp16 \ + data-bin/wmt18_en_de \ + --source-lang en --target-lang de \ + --arch transformer_wmt_en_de_big --share-all-embeddings \ + --dropout 0.3 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --max-tokens 3584 --update-freq 16 \ + --max-update 30000 \ + --save-dir $CHECKPOINT_DIR +# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a +# different number of GPUs. 
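+# A rough rule of thumb (an assumption, not prescribed by the paper): keep the
+# effective batch size, n_gpus * --max-tokens * --update-freq, roughly constant.
+# Here that is 8 * 3584 * 16 ≈ 459k tokens per update, so with 4 GPUs you would
+# use --update-freq 32, and with 16 GPUs --update-freq 8.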
+```
+
+Average the last 10 checkpoints:
+```bash
+python scripts/average_checkpoints.py \
+    --inputs $CHECKPOINT_DIR \
+    --num-epoch-checkpoints 10 \
+    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
+```
+
+Evaluate BLEU:
+```bash
+# tokenized BLEU on newstest2017:
+bash examples/backtranslation/tokenized_bleu.sh \
+    wmt17 \
+    en-de \
+    data-bin/wmt18_en_de \
+    data-bin/wmt18_en_de/code \
+    $CHECKPOINT_DIR/checkpoint.avg10.pt
+# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
+# compare to 29.46 in Table 1, which is also for tokenized BLEU
+
+# generally it's better to report (detokenized) sacrebleu though:
+bash examples/backtranslation/sacrebleu.sh \
+    wmt17 \
+    en-de \
+    data-bin/wmt18_en_de \
+    data-bin/wmt18_en_de/code \
+    $CHECKPOINT_DIR/checkpoint.avg10.pt
+# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
+```
+
+
+#### Step 2. Back-translate monolingual German data
+
+Train a reverse model (German-English) to do the back-translation:
+```bash
+CHECKPOINT_DIR=checkpoints_de_en_parallel
+fairseq-train --fp16 \
+    data-bin/wmt18_en_de \
+    --source-lang de --target-lang en \
+    --arch transformer_wmt_en_de_big --share-all-embeddings \
+    --dropout 0.3 --weight-decay 0.0 \
+    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
+    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+    --max-tokens 3584 --update-freq 16 \
+    --max-update 30000 \
+    --save-dir $CHECKPOINT_DIR
+# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
+# different number of GPUs.
+```
+
+Let's evaluate the back-translation (BT) model to make sure it is well trained:
+```bash
+bash examples/backtranslation/sacrebleu.sh \
+    wmt17 \
+    de-en \
+    data-bin/wmt18_en_de \
+    data-bin/wmt18_en_de/code \
+    $CHECKPOINT_DIR/checkpoint_best.pt
+# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
+# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
+```
+
+Next prepare the monolingual data:
+```bash
+# Download and prepare the monolingual data
+# By default the script samples 25M monolingual sentences, which after
+# deduplication should be just over 24M sentences. These are split into 25
+# shards, each with 1M sentences (except for the last shard).
+cd examples/backtranslation/
+bash prepare-de-monolingual.sh
+cd ../..
+
+# Binarize each shard of the monolingual data
+TEXT=examples/backtranslation/wmt18_de_mono
+for SHARD in $(seq -f "%02g" 0 24); do \
+    fairseq-preprocess \
+        --only-source \
+        --source-lang de --target-lang en \
+        --joined-dictionary \
+        --srcdict data-bin/wmt18_en_de/dict.de.txt \
+        --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
+        --destdir data-bin/wmt18_de_mono/shard${SHARD} \
+        --workers 20; \
+    cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
+done
+```
+
+Now we're ready to perform back-translation over the monolingual data. The
+following command generates via sampling, but it's possible to use greedy
+decoding (`--beam 1`), beam search (`--beam 5`),
+top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
+```bash
+mkdir backtranslation_output
+for SHARD in $(seq -f "%02g" 0 24); do \
+    fairseq-generate --fp16 \
+        data-bin/wmt18_de_mono/shard${SHARD} \
+        --path $CHECKPOINT_DIR/checkpoint_best.pt \
+        --skip-invalid-size-inputs-valid-test \
+        --max-tokens 4096 \
+        --sampling --beam 1 \
+    > backtranslation_output/sampling.shard${SHARD}.out; \
+done
+```
+
+After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
+the back-translations and apply length ratio filters:
+```bash
+python examples/backtranslation/extract_bt_data.py \
+    --minlen 1 --maxlen 250 --ratio 1.5 \
+    --output backtranslation_output/bt_data --srclang en --tgtlang de \
+    backtranslation_output/sampling.shard*.out
+
+# Ensure lengths are the same:
+# wc -l backtranslation_output/bt_data.{en,de}
+#   21795614 backtranslation_output/bt_data.en
+#   21795614 backtranslation_output/bt_data.de
+#   43591228 total
+```
+
+Binarize the filtered BT data and combine it with the parallel data:
+```bash
+TEXT=backtranslation_output
+fairseq-preprocess \
+    --source-lang en --target-lang de \
+    --joined-dictionary \
+    --srcdict data-bin/wmt18_en_de/dict.en.txt \
+    --trainpref $TEXT/bt_data \
+    --destdir data-bin/wmt18_en_de_bt \
+    --workers 20
+
+# We want to train on the combined data, so we'll symlink the parallel + BT data
+# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
+# and the BT data as "train1", so that fairseq will combine them automatically
+# and so that we can use the `--upsample-primary` option to upsample the
+# parallel data (if desired).
+PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
+BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
+COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
+mkdir -p $COMB_DATA
+for LANG in en de; do \
+    ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
+    for EXT in bin idx; do \
+        ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
+        ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
+        ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
+        ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
+    done; \
+done
+```
+
+
+#### Step 3. Train an English-German model over the combined parallel + BT data
+
+Finally we can train a model over the parallel + BT data:
+```bash
+CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
+fairseq-train --fp16 \
+    data-bin/wmt18_en_de_para_plus_bt \
+    --upsample-primary 16 \
+    --source-lang en --target-lang de \
+    --arch transformer_wmt_en_de_big --share-all-embeddings \
+    --dropout 0.3 --weight-decay 0.0 \
+    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
+    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+    --max-tokens 3584 --update-freq 16 \
+    --max-update 100000 \
+    --save-dir $CHECKPOINT_DIR
+# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
+# different number of GPUs.
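+# Note: `--upsample-primary 16` replicates the primary "train" split (the
+# parallel data linked above) 16 times relative to the "train1" BT split when
+# the two are concatenated; see `sample_ratios` in fairseq/tasks/translation.py.
+# If you generate a different amount of BT data, you may want to adjust this ratio.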
+``` + +Average the last 10 checkpoints: +```bash +python scripts/average_checkpoints.py \ + --inputs $CHECKPOINT_DIR \ + --num-epoch-checkpoints 10 \ + --output $CHECKPOINT_DIR/checkpoint.avg10.pt +``` + +Evaluate BLEU: +```bash +# tokenized BLEU on newstest2017: +bash examples/backtranslation/tokenized_bleu.sh \ + wmt17 \ + en-de \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint.avg10.pt +# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152) +# compare to 32.35 in Table 1, which is also for tokenized BLEU + +# generally it's better to report (detokenized) sacrebleu: +bash examples/backtranslation/sacrebleu.sh \ + wmt17 \ + en-de \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint.avg10.pt +# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287) +``` + + ## Citation ```bibtex @inproceedings{edunov2018backtranslation, diff --git a/examples/backtranslation/deduplicate_lines.py b/examples/backtranslation/deduplicate_lines.py new file mode 100644 index 0000000000..35a407e556 --- /dev/null +++ b/examples/backtranslation/deduplicate_lines.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import fileinput +import hashlib +from multiprocessing import Pool +import sys + + +def get_hashes_and_lines(raw_line): + hash = hashlib.md5(raw_line).hexdigest() + return hash, raw_line + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--workers', type=int, default=10) + parser.add_argument('files', nargs='*', help='input files') + args = parser.parse_args() + + seen = set() + with fileinput.input(args.files, mode='rb') as h: + pool = Pool(args.workers) + results = pool.imap_unordered(get_hashes_and_lines, h, 1000) + for i, (hash, raw_line) in enumerate(results): + if hash not in seen: + seen.add(hash) + sys.stdout.buffer.write(raw_line) + if i % 1000000 == 0: + print(i, file=sys.stderr, end="", flush=True) + elif i % 100000 == 0: + print(".", file=sys.stderr, end="", flush=True) + print(file=sys.stderr, flush=True) + + +if __name__ == '__main__': + main() diff --git a/examples/backtranslation/extract_bt_data.py b/examples/backtranslation/extract_bt_data.py new file mode 100644 index 0000000000..26a46942c8 --- /dev/null +++ b/examples/backtranslation/extract_bt_data.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import fileinput + +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser(description=( + 'Extract back-translations from the stdout of fairseq-generate. ' + 'If there are multiply hypotheses for a source, we only keep the first one. 
' + )) + parser.add_argument('--output', required=True, help='output prefix') + parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)') + parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)') + parser.add_argument('--minlen', type=int, help='min length filter') + parser.add_argument('--maxlen', type=int, help='max length filter') + parser.add_argument('--ratio', type=float, help='ratio filter') + parser.add_argument('files', nargs='*', help='input files') + args = parser.parse_args() + + def validate(src, tgt): + srclen = len(src.split(' ')) if src != '' else 0 + tgtlen = len(tgt.split(' ')) if tgt != '' else 0 + if ( + (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen)) + or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen)) + or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)) + ): + return False + return True + + def safe_index(toks, index, default): + try: + return toks[index] + except IndexError: + return default + + with open(args.output + '.' + args.srclang, 'w') as src_h, \ + open(args.output + '.' + args.tgtlang, 'w') as tgt_h: + for line in tqdm(fileinput.input(args.files)): + if line.startswith('S-'): + tgt = safe_index(line.rstrip().split('\t'), 1, '') + elif line.startswith('H-'): + if tgt is not None: + src = safe_index(line.rstrip().split('\t'), 2, '') + if validate(src, tgt): + print(src, file=src_h) + print(tgt, file=tgt_h) + tgt = None + + +if __name__ == '__main__': + main() diff --git a/examples/backtranslation/prepare-de-monolingual.sh b/examples/backtranslation/prepare-de-monolingual.sh new file mode 100644 index 0000000000..5e67b2b3bc --- /dev/null +++ b/examples/backtranslation/prepare-de-monolingual.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl +BPEROOT=subword-nmt/subword_nmt + + +BPE_CODE=wmt18_en_de/code +SUBSAMPLE_SIZE=25000000 +LANG=de + + +OUTDIR=wmt18_${LANG}_mono +orig=orig +tmp=$OUTDIR/tmp +mkdir -p $OUTDIR $tmp + + +URLS=( + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz" + "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz" + "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz" + "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz" + "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz" +) +FILES=( + "news.2007.de.shuffled.gz" + "news.2008.de.shuffled.gz" + "news.2009.de.shuffled.gz" + "news.2010.de.shuffled.gz" + "news.2011.de.shuffled.gz" + "news.2012.de.shuffled.gz" + "news.2013.de.shuffled.gz" + "news.2014.de.shuffled.v2.gz" + "news.2015.de.shuffled.gz" + "news.2016.de.shuffled.gz" + 
"news.2017.de.shuffled.deduped.gz" +) + + +cd $orig +for ((i=0;i<${#URLS[@]};++i)); do + file=${FILES[i]} + if [ -f $file ]; then + echo "$file already exists, skipping download" + else + url=${URLS[i]} + wget "$url" + fi +done +cd .. + + +if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then + echo "found monolingual sample, skipping shuffle/sample/tokenize" +else + gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \ + | shuf -n $SUBSAMPLE_SIZE \ + | perl $NORM_PUNC $LANG \ + | perl $REM_NON_PRINT_CHAR \ + | perl $TOKENIZER -threads 8 -a -l $LANG \ + > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} +fi + + +if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then + echo "found BPE monolingual sample, skipping BPE step" +else + python $BPEROOT/apply_bpe.py -c $BPE_CODE \ + < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \ + > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} +fi + + +if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then + echo "found deduplicated monolingual sample, skipping deduplication step" +else + python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \ + > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} +fi + + +if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then + echo "found sharded data, skipping sharding step" +else + split --lines 1000000 --numeric-suffixes \ + --additional-suffix .${LANG} \ + $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \ + $OUTDIR/bpe.monolingual.dedup. +fi diff --git a/examples/backtranslation/prepare-wmt18en2de.sh b/examples/backtranslation/prepare-wmt18en2de.sh new file mode 100644 index 0000000000..f6fd275307 --- /dev/null +++ b/examples/backtranslation/prepare-wmt18en2de.sh @@ -0,0 +1,135 @@ +#!/bin/bash +# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh + +echo 'Cloning Moses github repository (for tokenization scripts)...' +git clone https://github.com/moses-smt/mosesdecoder.git + +echo 'Cloning Subword NMT repository (for BPE pre-processing)...' +git clone https://github.com/rsennrich/subword-nmt.git + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +CLEAN=$SCRIPTS/training/clean-corpus-n.perl +NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl +BPEROOT=subword-nmt/subword_nmt +BPE_TOKENS=32000 + +URLS=( + "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" + "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" + "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" + "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz" + "http://data.statmt.org/wmt17/translation-task/dev.tgz" + "http://statmt.org/wmt14/test-full.tgz" +) +FILES=( + "training-parallel-europarl-v7.tgz" + "training-parallel-commoncrawl.tgz" + "training-parallel-nc-v13.tgz" + "rapid2016.tgz" + "dev.tgz" + "test-full.tgz" +) +CORPORA=( + "training/europarl-v7.de-en" + "commoncrawl.de-en" + "training-parallel-nc-v13/news-commentary-v13.de-en" + "rapid2016.de-en" +) + +if [ ! -d "$SCRIPTS" ]; then + echo "Please set SCRIPTS variable correctly to point to Moses scripts." 
+    exit 1
+fi
+
+OUTDIR=wmt18_en_de
+
+src=en
+tgt=de
+lang=en-de
+prep=$OUTDIR
+tmp=$prep/tmp
+orig=orig
+
+mkdir -p $orig $tmp $prep
+
+cd $orig
+
+for ((i=0;i<${#URLS[@]};++i)); do
+    file=${FILES[i]}
+    if [ -f $file ]; then
+        echo "$file already exists, skipping download"
+    else
+        url=${URLS[i]}
+        wget "$url"
+        if [ -f $file ]; then
+            echo "$url successfully downloaded."
+        else
+            echo "$url not successfully downloaded."
+            exit 1
+        fi
+        if [ ${file: -4} == ".tgz" ]; then
+            tar zxvf $file
+        elif [ ${file: -4} == ".tar" ]; then
+            tar xvf $file
+        fi
+    fi
+done
+cd ..
+
+echo "pre-processing train data..."
+for l in $src $tgt; do
+    rm $tmp/train.tags.$lang.tok.$l
+    for f in "${CORPORA[@]}"; do
+        cat $orig/$f.$l | \
+            perl $NORM_PUNC $l | \
+            perl $REM_NON_PRINT_CHAR | \
+            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
+    done
+done
+
+echo "pre-processing test data..."
+for l in $src $tgt; do
+    if [ "$l" == "$src" ]; then
+        t="src"
+    else
+        t="ref"
+    fi
+    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
+        sed -e 's/<seg id="[0-9]*">\s*//g' | \
+        sed -e 's/\s*<\/seg>\s*//g' | \
+        sed -e "s/\’/\'/g" | \
+    perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
+    echo ""
+done
+
+echo "splitting train and valid..."
+for l in $src $tgt; do
+    awk '{if (NR%100 == 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
+    awk '{if (NR%100 != 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
+done
+
+TRAIN=$tmp/train.de-en
+BPE_CODE=$prep/code
+rm -f $TRAIN
+for l in $src $tgt; do
+    cat $tmp/train.$l >> $TRAIN
+done
+
+echo "learn_bpe.py on ${TRAIN}..."
+python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
+
+for L in $src $tgt; do
+    for f in train.$L valid.$L test.$L; do
+        echo "apply_bpe.py to ${f}..."
+        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
+    done
+done
+
+perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
+perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
+
+for L in $src $tgt; do
+    cp $tmp/bpe.test.$L $prep/test.$L
+done
diff --git a/examples/backtranslation/sacrebleu.sh b/examples/backtranslation/sacrebleu.sh
new file mode 100644
index 0000000000..a70da23f48
--- /dev/null
+++ b/examples/backtranslation/sacrebleu.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+if [ $# -ne 5 ]; then
+    echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
+    exit
+fi
+
+
+DATASET=$1
+LANGPAIR=$2
+DATABIN=$3
+BPECODE=$4
+MODEL=$5
+
+SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
+TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
+
+
+BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
+if [ ! -e $BPEROOT ]; then
+    BPEROOT=subword-nmt/subword_nmt
+    if [ ! -e $BPEROOT ]; then
+        echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
+ git clone https://github.com/rsennrich/subword-nmt.git + fi +fi + + +sacrebleu -t $DATASET -l $LANGPAIR --echo src \ +| sacremoses tokenize -a -l $SRCLANG -q \ +| python $BPEROOT/apply_bpe.py -c $BPECODE \ +| fairseq-interactive $DATABIN --path $MODEL \ + -s $SRCLANG -t $TGTLANG \ + --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \ +| grep ^H- | cut -f 3- \ +| sacremoses detokenize -l $TGTLANG -q \ +| sacrebleu -t $DATASET -l $LANGPAIR diff --git a/examples/backtranslation/tokenized_bleu.sh b/examples/backtranslation/tokenized_bleu.sh new file mode 100644 index 0000000000..1589da334a --- /dev/null +++ b/examples/backtranslation/tokenized_bleu.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +if [ $# -ne 5 ]; then + echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]" + exit +fi + + +DATASET=$1 +LANGPAIR=$2 +DATABIN=$3 +BPECODE=$4 +MODEL=$5 + +SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1) +TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2) + + +BPEROOT=examples/backtranslation/subword-nmt/subword_nmt +if [ ! -e $BPEROOT ]; then + BPEROOT=subword-nmt/subword_nmt + if [ ! -e $BPEROOT ]; then + echo 'Cloning Subword NMT repository (for BPE pre-processing)...' + git clone https://github.com/rsennrich/subword-nmt.git + fi +fi + + +TMP_REF=$(mktemp) + +sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \ +| sacremoses normalize -l $TGTLANG -q \ +| sacremoses tokenize -a -l $TGTLANG -q \ +> $TMP_REF + +sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \ +| sacremoses normalize -l $SRCLANG -q \ +| sacremoses tokenize -a -l $SRCLANG -q \ +| python $BPEROOT/apply_bpe.py -c $BPECODE \ +| python interactive.py $DATABIN --path $MODEL \ + -s $SRCLANG -t $TGTLANG \ + --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \ +| grep ^H- | cut -f 3- \ +| python score.py --ref $TMP_REF + +rm -f $TMP_REF diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index ada3ae780f..725514c447 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -73,9 +73,10 @@ def split_exists(split, src, tgt, lang, data_path): src_dict.eos(), ) src_datasets.append(src_dataset) - tgt_datasets.append( - data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) - ) + + tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) + if tgt_dataset is not None: + tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format( data_path, split_k, src, tgt, len(src_datasets[-1]) @@ -84,20 +85,25 @@ def split_exists(split, src, tgt, lang, data_path): if not combine: break - assert len(src_datasets) == len(tgt_datasets) + assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: - src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) - tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + if len(tgt_datasets) > 0: + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + else: + tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) - tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + if tgt_dataset is not None: + tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) align_dataset = None if 
load_alignments: @@ -105,9 +111,10 @@ def split_exists(split, src, tgt, lang, data_path): if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) + tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, - tgt_dataset, tgt_dataset.sizes, tgt_dict, + tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions,