A simple yet strong implementation of neural machine translation in pytorch

pcyin/pytorch_basic_nmt

A Basic PyTorch Implementation of Attentional Neural Machine Translation

This is a basic implementation of attentional neural machine translation (Bahdanau et al., 2015; Luong et al., 2015) in PyTorch. It implements the model described in Luong et al., 2015, and supports label smoothing, beam-search decoding and random sampling. With a 256-dimensional LSTM hidden size, it achieves a BLEU score of 28.13 on the IWSLT 2014 German-English dataset (Ranzato et al., 2015).
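
To make the label smoothing mentioned above concrete, here is a minimal pure-Python sketch. It is not taken from nmt.py: the function names, the default epsilon of 0.1, and the choice to spread the smoothing mass evenly over the non-gold words are all assumptions made for illustration.

```python
import math


def smoothed_targets(gold_index, vocab_size, epsilon=0.1):
    """Build a smoothed target distribution over the vocabulary.

    Assumption: the gold word keeps probability 1 - epsilon and the
    remaining epsilon is shared evenly by the other vocab_size - 1 words.
    """
    off_value = epsilon / (vocab_size - 1)
    dist = [off_value] * vocab_size
    dist[gold_index] = 1.0 - epsilon
    return dist


def label_smoothed_loss(log_probs, gold_index, epsilon=0.1):
    """Cross-entropy between the smoothed targets and the model log-probabilities."""
    targets = smoothed_targets(gold_index, len(log_probs), epsilon)
    return -sum(t * lp for t, lp in zip(targets, log_probs))


# Toy example: a three-word vocabulary where the model favours word 0.
toy_log_probs = [math.log(p) for p in (0.7, 0.2, 0.1)]
toy_loss = label_smoothed_loss(toy_log_probs, gold_index=0, epsilon=0.1)
```

With epsilon set to 0 the loss reduces to the usual negative log-likelihood of the gold word, so the smoothing term can be read as a penalty on over-confident predictions.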

This codebase is used for instructional purposes in Stanford CS224N Natural Language Processing with Deep Learning and CMU 11-731 Machine Translation and Sequence-to-Sequence Models.

File Structure

  • nmt.py: contains the neural machine translation model and training/testing code.
  • vocab.py: a script that extracts vocabulary from training data
  • util.py: contains utility/helper functions

Example Dataset

We provide a preprocessed version of the IWSLT 2014 German-English translation task used in (Ranzato et al., 2015) [script]. To download the dataset:

wget http://www.cs.cmu.edu/~pengchey/iwslt2014_ende.zip
unzip iwslt2014_ende.zip

Running these commands extracts a data/ folder, which contains a copy of the public release of the IWSLT 2014 dataset. The dataset has 150K German-English training sentences. Files with the suffix *.wmixerprep are pre-processed versions of the dataset from Ranzato et al., 2015, with long sentences chopped and rare words replaced by a special <unk> token. You can use the pre-processed training files for training and development (or come up with your own pre-processing strategy), but for testing you have to use the original version of the test files, i.e., test.de-en.(de|en).

Environment

The code is written in Python 3.6 and uses some supporting third-party libraries. We provide a conda environment file that installs Python 3.6 with the required libraries. Simply run

conda env create -f environment.yml

Usage

Each runnable script (nmt.py, vocab.py) is annotated using docopt. Please refer to the source files for complete usage.

First, we extract a vocabulary file from the training data using the command:

python vocab.py \
    --train-src=data/train.de-en.de.wmixerprep \
    --train-tgt=data/train.de-en.en.wmixerprep \
    data/vocab.json

This generates a vocabulary file data/vocab.json. The script also has options to control the cutoff frequency and the size of generated vocabulary, which you may play with.
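
As an illustration of how a cutoff frequency and a size limit interact, here is a minimal standard-library sketch. It is not the code in vocab.py: the function names, the default values, and the special tokens (<pad>, <s>, </s>, <unk>) and their ids are assumptions.

```python
from collections import Counter


def build_vocab(sentences, freq_cutoff=2, max_size=50000):
    """Map words to integer ids, keeping only frequent words.

    Words seen fewer than freq_cutoff times are dropped, and at most
    max_size of the remaining words are kept, most frequent first.
    """
    counts = Counter(word for sent in sentences for word in sent)
    kept = [w for w, c in counts.most_common() if c >= freq_cutoff][:max_size]
    # Hypothetical special tokens; the real script may use different ones.
    word2id = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3}
    for word in kept:
        word2id[word] = len(word2id)
    return word2id


def words_to_ids(sentence, word2id):
    """Convert a tokenized sentence to ids, mapping unknown words to <unk>."""
    unk_id = word2id["<unk>"]
    return [word2id.get(word, unk_id) for word in sentence]


toy_corpus = [["das", "ist", "gut"], ["das", "ist", "ein", "test"]]
toy_vocab = build_vocab(toy_corpus, freq_cutoff=2, max_size=10)
```

In this toy corpus only "das" and "ist" occur twice, so every other word falls below the cutoff and is mapped to <unk> at lookup time.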

To start training and evaluation, simply run scripts/train.sh. After training and decoding, we call the official evaluation script multi-bleu.perl to compute the corpus-level BLEU score of the decoding results against the gold-standard references.
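
For readers unfamiliar with what "corpus-level" means here, the following is a rough pure-Python sketch of the computation. It is not a replacement for multi-bleu.perl: it assumes one reference per sentence and pre-tokenized input, and the function names are hypothetical. Scores used for reporting should come from the official script.

```python
import math
from collections import Counter


def ngram_counts(tokens, n):
    """Count the n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(references, hypotheses, max_n=4):
    """Corpus-level BLEU on a 0-100 scale, one reference per hypothesis.

    Clipped n-gram matches and lengths are summed over the whole corpus
    before the precisions and the brevity penalty are computed.
    """
    matches = [0] * max_n
    totals = [0] * max_n
    ref_len = hyp_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_len += len(ref)
        hyp_len += len(hyp)
        for n in range(1, max_n + 1):
            hyp_counts = ngram_counts(hyp, n)
            ref_counts = ngram_counts(ref, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())
    if hyp_len == 0 or min(matches) == 0:
        return 0.0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Penalize output that is shorter than the references overall.
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1.0 - ref_len / hyp_len)
    return 100.0 * brevity_penalty * math.exp(log_precision)


toy_refs = ["the cat sat on the mat".split()]
toy_hyps = ["the cat sat on the".split()]
toy_score = corpus_bleu(toy_refs, toy_hyps)
```

Because counts are pooled before the geometric mean is taken, the result generally differs from the average of per-sentence BLEU scores.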

License

This work is licensed under a Creative Commons Attribution 4.0 International License.
