diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..5592b2bdef --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,30 @@ +# Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch) +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +## Coding Style +We try to follow the PEP style guidelines and encourage you to as well. + +## License +By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/README.md b/README.md index d134144fbc..4feb0a4fdb 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Introduction FAIR Sequence-to-Sequence Toolkit (PyTorch) -This is a PyTorch version of [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence learning toolkit from Facebook AI Research. The original authors of this reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam Gross. The toolkit implements the fully convolutional model described in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122). The toolkit features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. We provide pre-trained models for English to French and English to German translation. +This is a PyTorch version of [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence learning toolkit from Facebook AI Research. The original authors of this reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam Gross. The toolkit implements the fully convolutional model described in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122) and features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. We provide pre-trained models for English to French and English to German translation. ![Model](fairseq.gif) @@ -27,8 +27,9 @@ If you use the code in your paper, then please cite it as: Currently fairseq-py requires PyTorch from the GitHub repository. There are multiple ways of installing it. We suggest using [Miniconda3](https://conda.io/miniconda.html) and the following instructions. -* Install Miniconda3 from https://conda.io/miniconda.html create and activate python 3 environment. +* Install Miniconda3 from https://conda.io/miniconda.html; create and activate a Python 3 environment. +* Install PyTorch: ``` conda install gcc numpy cudnn nccl conda install magma-cuda80 -c soumith @@ -44,15 +45,15 @@ pip install -r requirements.txt NO_DISTRIBUTED=1 python setup.py install ``` - -Install fairseq by cloning the GitHub repository and by running - +* Install fairseq-py by cloning the GitHub repository and running: ``` pip install -r requirements.txt python setup.py build python setup.py develop ``` +# Quick Start + The following command-line tools are available: * `python preprocess.py`: Data pre-processing: build vocabularies and binarize training data * `python train.py`: Train a new model on one or multiple GPUs @@ -60,9 +61,6 @@ The following command-line tools are available: * `python generate.py -i`: Translate raw text with a trained model * `python score.py`: BLEU scoring of generated translations against reference translations - -# Quick Start - ## Evaluating Pre-trained Models [TO BE ADAPTED] First, download a pre-trained model along with its vocabularies: ``` @@ -100,7 +98,7 @@ Check [below](#pre-trained-models) for a full list of pre-trained models availab ## Training a New Model ### Data Pre-processing -The fairseq source distribution contains an example pre-processing script for +The fairseq-py source distribution contains an example pre-processing script for the IWSLT 2014 German-English corpus. Pre-process and binarize the data as follows: ``` @@ -118,11 +116,10 @@ This will write binarized data that can be used for model training to `data-bin/ Use `python train.py` to train a new model. Here a few example settings that work well for the IWSLT 2014 dataset: ``` -$ mkdir -p trainings/fconv +$ mkdir -p checkpoints/fconv $ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \ --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ - --encoder-layers "[(256, 3)] * 4" --decoder-layers "[(256, 3)] * 3" \ - --encoder-embed-dim 256 --decoder-embed-dim 256 --save-dir trainings/fconv + --arch fconv_iwslt_de_en --save-dir checkpoints/fconv ``` By default, `python train.py` will use all available GPUs on your machine. @@ -135,7 +132,7 @@ You may need to use a smaller value depending on the available GPU memory on you Once your model is trained, you can generate translations using `python generate.py` **(for binarized data)** or `python generate.py -i` **(for raw text)**: ``` $ python generate.py data-bin/iwslt14.tokenized.de-en \ - --path trainings/fconv/checkpoint_best.pt \ + --path checkpoints/fconv/checkpoint_best.pt \ --batch-size 128 --beam 5 | [de] dictionary: 35475 types | [en] dictionary: 24739 types @@ -172,9 +169,12 @@ $ python generate.py data-bin/wmt14.en-fr.newstest2014 \ ... | Translated 3003 sentences (95451 tokens) in 136.3s (700.49 tokens/s) | Timings: setup 0.1s (0.1%), encoder 1.9s (1.4%), decoder 108.9s (79.9%), search_results 0.0s (0.0%), search_prune 12.5s (9.2%) +TODO: update scores (should be same as score.py) | BLEU4 = 43.43, 68.2/49.2/37.4/28.8 (BP=0.996, ratio=1.004, sys_len=92087, ref_len=92448) -# Word-level BLEU scoring: +# Scoring with score.py: +$ grep ^H /tmp/gen.out | cut -f3- | sed 's/@@ //g' > /tmp/gen.out.sys +$ grep ^T /tmp/gen.out | cut -f2- | sed 's/@@ //g' > /tmp/gen.out.ref $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref TODO: update scores BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_len=81194) @@ -186,6 +186,6 @@ BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_le * Google group: https://groups.google.com/forum/#!forum/fairseq-users # License -fairseq is BSD-licensed. +fairseq-py is BSD-licensed. The license applies to the pre-trained models as well. We also provide an additional patent grant. diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 405ec0011f..b90765a5af 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -6,9 +6,14 @@ # can be found in the PATENTS file in the same directory. # -from .fconv import * +from . import fconv -__all__ = [ - 'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', - 'fconv_wmt_en_fr', -] + +__all__ = ['fconv'] + +arch_model_map = {} +for model in __all__: + archs = locals()[model].get_archs() + for arch in archs: + assert arch not in arch_model_map, 'Duplicate model architecture detected: {}'.format(arch) + arch_model_map[arch] = model diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py index c422ed1168..4f0fdf90b0 100644 --- a/fairseq/models/fconv.py +++ b/fairseq/models/fconv.py @@ -430,56 +430,90 @@ def backward(ctx, grad): return grad * ctx.scale, None -def fconv_iwslt_de_en(dataset, dropout, **kwargs): - encoder_convs = [(256, 3)] * 4 - decoder_convs = [(256, 3)] * 3 - return fconv(dataset, dropout, 256, encoder_convs, 256, decoder_convs, **kwargs) - - -def fconv_wmt_en_ro(dataset, dropout, **kwargs): - convs = [(512, 3)] * 20 - return fconv(dataset, dropout, 512, convs, 512, convs, **kwargs) - - -def fconv_wmt_en_de(dataset, dropout, **kwargs): - convs = [(512, 3)] * 9 # first 10 layers have 512 units - convs += [(1024, 3)] * 4 # next 3 layers have 768 units - convs += [(2048, 1)] * 2 # final 2 layers are 1x1 - return fconv(dataset, dropout, 768, convs, 768, convs, - decoder_out_embed_dim=512, - **kwargs) - - -def fconv_wmt_en_fr(dataset, dropout, **kwargs): - convs = [(512, 3)] * 6 # first 5 layers have 512 units - convs += [(768, 3)] * 4 # next 4 layers have 768 units - convs += [(1024, 3)] * 3 # next 4 layers have 1024 units - convs += [(2048, 1)] * 1 # next 1 layer is 1x1 - convs += [(4096, 1)] * 1 # final 1 layer is 1x1 - return fconv(dataset, dropout, 768, convs, 768, convs, - decoder_out_embed_dim=512, - **kwargs) - - -def fconv(dataset, dropout, encoder_embed_dim, encoder_convolutions, - decoder_embed_dim, decoder_convolutions, attention=True, - decoder_out_embed_dim=256, max_positions=1024): +def get_archs(): + return [ + 'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr', + ] + + +def _check_arch(args): + """Check that the specified architecture is valid and not ambiguous.""" + if args.arch not in get_archs(): + raise ValueError('Unknown fconv model architecture: {}'.format(args.arch)) + if args.arch != 'fconv': + # check that architecture is not ambiguous + for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers', + 'decoder_out_embed_dim']: + if hasattr(args, a): + raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch)) + + +def parse_arch(args): + _check_arch(args) + + if args.arch == 'fconv_iwslt_de_en': + args.encoder_embed_dim = 256 + args.encoder_layers = '[(256, 3)] * 4' + args.decoder_embed_dim = 256 + args.decoder_layers = '[(256, 3)] * 3' + args.decoder_out_embed_dim = 256 + elif args.arch == 'fconv_wmt_en_ro': + args.encoder_embed_dim = 512 + args.encoder_layers = '[(512, 3)] * 20' + args.decoder_embed_dim = 512 + args.decoder_layers = '[(512, 3)] * 20' + args.decoder_out_embed_dim = 512 + elif args.arch == 'fconv_wmt_en_de': + convs = '[(512, 3)] * 9' # first 9 layers have 512 units + convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units + convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions + args.encoder_embed_dim = 768 + args.encoder_layers = convs + args.decoder_embed_dim = 768 + args.decoder_layers = convs + args.decoder_out_embed_dim = 512 + elif args.arch == 'fconv_wmt_en_fr': + convs = '[(512, 3)] * 6' # first 6 layers have 512 units + convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units + convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units + convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions + convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions + args.encoder_embed_dim = 768 + args.encoder_layers = convs + args.decoder_embed_dim = 768 + args.decoder_layers = convs + args.decoder_out_embed_dim = 512 + else: + assert args.arch == 'fconv' + + # default architecture + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20') + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20') + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) + args.decoder_attention = getattr(args, 'decoder_attention', 'True') + return args + + +def build_model(args, dataset): padding_idx = dataset.dst_dict.pad() - encoder = Encoder( len(dataset.src_dict), - embed_dim=encoder_embed_dim, - convolutions=encoder_convolutions, - dropout=dropout, + embed_dim=args.encoder_embed_dim, + convolutions=eval(args.encoder_layers), + dropout=args.dropout, padding_idx=padding_idx, - max_positions=max_positions) + max_positions=args.max_positions, + ) decoder = Decoder( len(dataset.dst_dict), - embed_dim=decoder_embed_dim, - convolutions=decoder_convolutions, - out_embed_dim=decoder_out_embed_dim, - attention=attention, - dropout=dropout, + embed_dim=args.decoder_embed_dim, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_out_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, padding_idx=padding_idx, - max_positions=max_positions) + max_positions=args.max_positions, + ) return FConvModel(encoder, decoder, padding_idx) diff --git a/fairseq/options.py b/fairseq/options.py index 326ba5d4c9..ad1c382852 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -109,22 +109,35 @@ def add_generation_args(parser): def add_model_args(parser): - group = parser.add_argument_group('Model configuration') - group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', - choices=models.__all__, - help='model architecture ({})'.format(', '.join(models.__all__))) - group.add_argument('--encoder-embed-dim', default=512, type=int, metavar='N', + group = parser.add_argument_group( + 'Model configuration', + # Only include attributes which are explicitly given as command-line + # arguments or which have model-independent default values. + argument_default=argparse.SUPPRESS, + ) + + # The model architecture can be specified in several ways. + # In increasing order of priority: + # 1) model defaults (lowest priority) + # 2) --arch argument + # 3) --encoder/decoder-* arguments (highest priority) + # Note: --arch cannot be combined with --encoder/decoder-* arguments. + group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', choices=models.arch_model_map.keys(), + help='model architecture ({})'.format(', '.join(models.arch_model_map.keys()))) + group.add_argument('--encoder-embed-dim', type=int, metavar='N', help='encoder embedding dimension') - group.add_argument('--encoder-layers', default='[(512, 3)] * 20', type=str, metavar='EXPR', + group.add_argument('--encoder-layers', type=str, metavar='EXPR', help='encoder layers [(dim, kernel_size), ...]') - group.add_argument('--decoder-embed-dim', default=512, type=int, metavar='N', + group.add_argument('--decoder-embed-dim', type=int, metavar='N', help='decoder embedding dimension') - group.add_argument('--decoder-layers', default='[(512, 3)] * 20', type=str, metavar='EXPR', + group.add_argument('--decoder-layers', type=str, metavar='EXPR', help='decoder layers [(dim, kernel_size), ...]') - group.add_argument('--decoder-attention', default='True', type=str, metavar='EXPR', - help='decoder attention [True, ...]') - group.add_argument('--decoder-out-embed-dim', default=256, type=int, metavar='N', + group.add_argument('--decoder-out-embed-dim', type=int, metavar='N', help='decoder output embedding dimension') + group.add_argument('--decoder-attention', type=str, metavar='EXPR', + help='decoder attention [True, ...]') + + # These arguments have default values independent of the model: group.add_argument('--dropout', default=0.1, type=float, metavar='D', help='dropout probability') group.add_argument('--label-smoothing', default=0, type=float, metavar='D', diff --git a/fairseq/utils.py b/fairseq/utils.py index 43df598370..ca32ba8a5c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -14,20 +14,16 @@ from fairseq import criterions, data, models +def parse_args_and_arch(parser): + args = parser.parse_args() + args.model = models.arch_model_map[args.arch] + args = getattr(models, args.model).parse_arch(args) + return args + + def build_model(args, dataset): - if args.arch == 'fconv': - encoder_layers = eval(args.encoder_layers) - decoder_layers = eval(args.decoder_layers) - decoder_attention = eval(args.decoder_attention) - model = models.fconv( - dataset, args.dropout, args.encoder_embed_dim, encoder_layers, - args.decoder_embed_dim, decoder_layers, decoder_attention, - decoder_out_embed_dim=args.decoder_out_embed_dim, - max_positions=args.max_positions) - else: - model = models.__dict__[args.arch](dataset, args.dropout, - max_positions=args.max_positions) - return model + assert hasattr(models, args.model), 'Missing model type' + return getattr(models, args.model).build_model(args, dataset) def build_criterion(args, dataset): @@ -95,14 +91,14 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None): return epoch, batch_offset -def load_ensemble_for_inference(models, data_path): +def load_ensemble_for_inference(filenames, data_path): # load model architectures and weights states = [] - for model in models: - if not os.path.exists(model): - raise IOError('Model file not found: ' + model) + for filename in filenames: + if not os.path.exists(filename): + raise IOError('Model file not found: {}'.format(filename)) states.append( - torch.load(model, map_location=lambda s, l: default_restore_location(s, 'cpu')) + torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu')) ) # load dataset @@ -110,13 +106,13 @@ def load_ensemble_for_inference(models, data_path): dataset = data.load(data_path, args.source_lang, args.target_lang) # build models - models = [] + ensemble = [] for state in states: model = build_model(args, dataset) model.load_state_dict(state['model']) - models.append(model) + ensemble.append(model) - return models, dataset + return ensemble, dataset def prepare_sample(sample, volatile=False, cuda_device=None): diff --git a/train.py b/train.py index d3f2bcfb04..01a8577196 100644 --- a/train.py +++ b/train.py @@ -36,7 +36,7 @@ def main(): options.add_checkpoint_args(parser) options.add_model_args(parser) - args = parser.parse_args() + args = utils.parse_args_and_arch(parser) print(args) if args.no_progress_bar: