Commit f93f3f8: release Electric code
clarkkev committed Nov 16, 2020 (1 parent: 7911132)
Showing 5 changed files with 214 additions and 83 deletions.
21 changes: 19 additions & 2 deletions README.md
@@ -2,12 +2,14 @@

## Introduction

**ELECTRA** is a new method for self-supervised language representation learning. It can be used to pre-train transformer networks using relatively little compute. ELECTRA models are trained to distinguish "real" input tokens vs "fake" input tokens generated by another neural network, similar to the discriminator of a [GAN](https://arxiv.org/pdf/1406.2661.pdf). At small scale, ELECTRA achieves strong results even when trained on a single GPU. At large scale, ELECTRA achieves state-of-the-art results on the [SQuAD 2.0](https://rajpurkar.github.io/SQuAD-explorer/) dataset.
**ELECTRA** is a method for self-supervised language representation learning. It can be used to pre-train transformer networks using relatively little compute. ELECTRA models are trained to distinguish "real" input tokens vs "fake" input tokens generated by another neural network, similar to the discriminator of a [GAN](https://arxiv.org/pdf/1406.2661.pdf). At small scale, ELECTRA achieves strong results even when trained on a single GPU. At large scale, ELECTRA achieves state-of-the-art results on the [SQuAD 2.0](https://rajpurkar.github.io/SQuAD-explorer/) dataset.

For a detailed description and experimental results, please refer to our paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/pdf?id=r1xMH1BtvB).
For a detailed description and experimental results, please refer to our ICLR 2020 paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/pdf?id=r1xMH1BtvB).

This repository contains code to pre-train ELECTRA, including small ELECTRA models on a single GPU. It also supports fine-tuning ELECTRA on downstream tasks including classification tasks (e.g., [GLUE](https://gluebenchmark.com/)), QA tasks (e.g., [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)), and sequence tagging tasks (e.g., [text chunking](https://www.clips.uantwerpen.be/conll2000/chunking/)).

This repository also contains code for **Electric**, a version of ELECTRA inspired by [energy-based models](http://yann.lecun.com/exdb/publis/pdf/lecun-06.pdf). Electric provides a more principled view of ELECTRA as a "negative sampling" [cloze model](https://en.wikipedia.org/wiki/Cloze_test). It can also efficiently produce [pseudo-likelihood scores](https://arxiv.org/pdf/1910.14659.pdf) for text, which can be used to re-rank the outputs of speech recognition or machine translation systems. For details on Electric, please refer to our EMNLP 2020 paper [Pre-Training Transformers as Energy-Based Cloze Models](https://www.aclweb.org/anthology/2020.emnlp-main.20.pdf).
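
To make the re-ranking use concrete, here is a minimal sketch (not code from this repository; `pseudo_log_likelihood` and the interpolation weight are placeholders for whatever scorer and tuning is used):

```
# Hypothetical n-best re-ranking: combine each hypothesis's base system score
# with a language-model pseudo-log-likelihood such as the one Electric provides.
def rerank(nbest, pseudo_log_likelihood, lm_weight=0.5):
  """nbest: list of (text, base_score) pairs; returns a best-first list."""
  rescored = [(text, base + lm_weight * pseudo_log_likelihood(text))
              for text, base in nbest]
  return sorted(rescored, key=lambda pair: pair[1], reverse=True)
```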



## Released Models
@@ -161,6 +163,10 @@ Here are expected results for ELECTRA on various tasks (test set for chunking, d

See [here](https://github.com/google-research/electra/issues/3) for losses / training curves of the models during pre-training.

## Electric

To train [Electric](https://www.aclweb.org/anthology/2020.emnlp-main.20.pdf), use the same pre-training script and command as ELECTRA. Pass `"electra_objective": false` and `"electric_objective": true` to the hyperparameters. We plan to release pre-trained Electric models soon!
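
For example, a minimal sketch of the equivalent override in Python (the model name and data directory are placeholders, and this assumes hyperparameter overrides reach `configure_pretraining.PretrainingConfig` as keyword arguments, as in the `self.update(kwargs)` call shown in the configure_pretraining.py diff below):

```
# Illustrative only: "electric_small_owt" and "data/" are placeholder names,
# not released assets; the overrides mirror the README instructions above.
import configure_pretraining

config = configure_pretraining.PretrainingConfig(
    model_name="electric_small_owt",
    data_dir="data/",
    electra_objective=False,    # turn off the ELECTRA discriminator objective
    electric_objective=True,    # train the energy-based Electric model instead
    two_tower_generator=True,   # Electric uses a two-tower cloze-model generator
)
```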

## Citation
If you use this code for your publication, please cite the original paper:
```
@@ -173,6 +179,17 @@ If you use this code for your publication, please cite the original paper:
}
```

If you use the code for Electric, please cite the Electric paper:
```
@inproceedings{clark2020electric,
title = {Pre-Training Transformers as Energy-Based Cloze Models},
author = {Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning},
booktitle = {EMNLP},
year = {2020},
url = {https://www.aclweb.org/anthology/2020.emnlp-main.20.pdf}
}
```

## Contact Info
For help or issues using ELECTRA, please submit a GitHub issue.

7 changes: 6 additions & 1 deletion configure_pretraining.py
@@ -32,7 +32,9 @@ def __init__(self, model_name, data_dir, **kwargs):
self.do_eval = False # evaluate generator/discriminator on unlabeled data

# loss functions
self.electra_objective = True # if False, use the BERT objective instead
# train ELECTRA or Electric? if both are false, trains a masked LM like BERT
self.electra_objective = True
self.electric_objective = False
self.gen_weight = 1.0 # masked language modeling / generator loss
self.disc_weight = 50.0 # discriminator loss
self.mask_prob = 0.15 # percent of input tokens to mask out / replace
@@ -65,6 +67,7 @@ def __init__(self, model_name, data_dir, **kwargs):

# generator settings
self.uniform_generator = False # generator is uniform at random
self.two_tower_generator = False # generator is a two-tower cloze model
self.untied_generator_embeddings = False # tie generator/discriminator
# token embeddings?
self.untied_generator = True # tie all generator/discriminator weights?
@@ -127,6 +130,8 @@ def __init__(self, model_name, data_dir, **kwargs):
# self.embedding_size = 1024
# self.mask_prob = 0.25
# self.train_batch_size = 2048
if self.electric_objective:
self.two_tower_generator = True # electric requires a two-tower generator

# passed-in-arguments override (for example) debug-mode defaults
self.update(kwargs)
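
As a reading aid for the `gen_weight` / `disc_weight` settings above, here is a schematic sketch (not code from this repository; `mlm_loss` and `disc_loss` stand in for the per-batch scalar losses the model computes):

```
# Schematic: the pre-training objective weights the generator's masked-LM loss
# and the discriminator loss using gen_weight and disc_weight from the config.
def total_pretraining_loss(mlm_loss, disc_loss, gen_weight=1.0, disc_weight=50.0):
  return gen_weight * mlm_loss + disc_weight * disc_loss
```
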
14 changes: 13 additions & 1 deletion model/modeling.py
@@ -146,7 +146,9 @@ def __init__(self,
input_embeddings=None,
input_reprs=None,
update_embeddings=True,
untied_embeddings=False):
untied_embeddings=False,
ltr=False,
rtl=False):
"""Constructor for BertModel.
Args:
@@ -232,6 +234,16 @@
attention_mask = create_attention_mask_from_input_mask(
token_type_ids, input_mask)

# Add causal masking to the attention for running the transformer
# left-to-right or right-to-left
if ltr or rtl:
causal_mask = tf.ones((seq_length, seq_length))
if ltr:
causal_mask = tf.matrix_band_part(causal_mask, -1, 0)
else:
causal_mask = tf.matrix_band_part(causal_mask, 0, -1)
attention_mask *= tf.expand_dims(causal_mask, 0)

# Run the stacked transformer. Output shapes
# sequence_output: [batch_size, seq_length, hidden_size]
# pooled_output: [batch_size, hidden_size]
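
The two `tf.matrix_band_part` calls above build the standard triangular causal masks used to run the transformer in one direction. A standalone NumPy sketch of the same masks (illustration only, not repository code):

```
# band_part(ones, -1, 0) keeps the lower triangle -> left-to-right causal mask;
# band_part(ones, 0, -1) keeps the upper triangle -> right-to-left causal mask.
import numpy as np

seq_length = 4
ltr_mask = np.tril(np.ones((seq_length, seq_length)))  # attend to positions <= i
rtl_mask = np.triu(np.ones((seq_length, seq_length)))  # attend to positions >= i
print(ltr_mask)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```
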
23 changes: 18 additions & 5 deletions pretrain/pretrain_helpers.py
@@ -115,9 +115,23 @@ def scatter_update(sequence, updates, positions):
return updated_sequence, updates_mask


def _get_candidates_mask(inputs: pretrain_data.Inputs, vocab,
disallow_from_mask=None):
VOCAB_MAPPING = {}


def get_vocab(config: configure_pretraining.PretrainingConfig):
"""Memoized load of the vocab file."""
if config.vocab_file not in VOCAB_MAPPING:
vocab = tokenization.FullTokenizer(
config.vocab_file, do_lower_case=True).vocab
VOCAB_MAPPING[config.vocab_file] = vocab
return VOCAB_MAPPING[config.vocab_file]


def get_candidates_mask(config: configure_pretraining.PretrainingConfig,
inputs: pretrain_data.Inputs,
disallow_from_mask=None):
"""Returns a mask tensor of positions in the input that can be masked out."""
vocab = get_vocab(config)
ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"]]
candidates_mask = tf.ones_like(inputs.input_ids, tf.bool)
for ignore_id in ignore_ids:
@@ -152,9 +166,8 @@ def mask(config: configure_pretraining.PretrainingConfig,
B, L = modeling.get_shape_list(inputs.input_ids)

# Find indices where masking out a token is allowed
vocab = tokenization.FullTokenizer(
config.vocab_file, do_lower_case=config.do_lower_case).vocab
candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)
vocab = get_vocab(config)
candidates_mask = get_candidates_mask(config, inputs, disallow_from_mask)

# Set the number of tokens to mask out per example
num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
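
The refactor above memoizes the vocab load (`get_vocab` reads each vocab file once and caches the resulting dict) and threads the config through `get_candidates_mask`. A standalone NumPy sketch of the candidate-mask idea (illustration only; 101/102/103 are the usual BERT vocab ids for the special tokens):

```
# Positions holding [CLS], [SEP], or [MASK] are excluded from masking.
import numpy as np

input_ids = np.array([[101, 7592, 2088, 103, 102]])  # toy wordpiece ids
ignore_ids = [102, 101, 103]  # [SEP], [CLS], [MASK], as in get_candidates_mask

candidates_mask = np.ones_like(input_ids, dtype=bool)
for ignore_id in ignore_ids:
  candidates_mask &= (input_ids != ignore_id)
print(candidates_mask)  # [[False  True  True False False]]
```
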
