From f93f3f81cdc13435dd3e85766852d00ff3e00ab5 Mon Sep 17 00:00:00 2001 From: Kevin Clark Date: Mon, 16 Nov 2020 08:00:57 -0800 Subject: [PATCH] release Electric code --- README.md | 21 +++- configure_pretraining.py | 7 +- model/modeling.py | 14 ++- pretrain/pretrain_helpers.py | 23 +++- run_pretraining.py | 232 ++++++++++++++++++++++++----------- 5 files changed, 214 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 7faf222..f22eeea 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,14 @@ ## Introduction -**ELECTRA** is a new method for self-supervised language representation learning. It can be used to pre-train transformer networks using relatively little compute. ELECTRA models are trained to distinguish "real" input tokens vs "fake" input tokens generated by another neural network, similar to the discriminator of a [GAN](https://arxiv.org/pdf/1406.2661.pdf). At small scale, ELECTRA achieves strong results even when trained on a single GPU. At large scale, ELECTRA achieves state-of-the-art results on the [SQuAD 2.0](https://rajpurkar.github.io/SQuAD-explorer/) dataset. +**ELECTRA** is a method for self-supervised language representation learning. It can be used to pre-train transformer networks using relatively little compute. ELECTRA models are trained to distinguish "real" input tokens vs "fake" input tokens generated by another neural network, similar to the discriminator of a [GAN](https://arxiv.org/pdf/1406.2661.pdf). At small scale, ELECTRA achieves strong results even when trained on a single GPU. At large scale, ELECTRA achieves state-of-the-art results on the [SQuAD 2.0](https://rajpurkar.github.io/SQuAD-explorer/) dataset. -For a detailed description and experimental results, please refer to our paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/pdf?id=r1xMH1BtvB). +For a detailed description and experimental results, please refer to our ICLR 2020 paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/pdf?id=r1xMH1BtvB). This repository contains code to pre-train ELECTRA, including small ELECTRA models on a single GPU. It also supports fine-tuning ELECTRA on downstream tasks including classification tasks (e.g,. [GLUE](https://gluebenchmark.com/)), QA tasks (e.g., [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)), and sequence tagging tasks (e.g., [text chunking](https://www.clips.uantwerpen.be/conll2000/chunking/)). +This repository also contains code for **Electric**, a version of ELECTRA inspired by [energy-based models](http://yann.lecun.com/exdb/publis/pdf/lecun-06.pdf). Electric provides a more principled view of ELECTRA as a "negative sampling" [cloze model](https://en.wikipedia.org/wiki/Cloze_test). It can also efficiently produce [pseudo-likelihood scores](https://arxiv.org/pdf/1910.14659.pdf) for text, which can be used to re-rank the outputs of speech recognition or machine translation systems. For details on Electric, please refer to our EMNLP 2020 paper [Pre-Training Transformers as Energy-Based Cloze Models](https://www.aclweb.org/anthology/2020.emnlp-main.20.pdf). + ## Released Models @@ -161,6 +163,10 @@ Here are expected results for ELECTRA on various tasks (test set for chunking, d See [here](https://github.com/google-research/electra/issues/3) for losses / training curves of the models during pre-training. 
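As described in the Electric section below, Electric is trained with the same `run_pretraining.py` script as ELECTRA and is selected purely through hyperparameters. A minimal sketch of the corresponding `configure_pretraining.PretrainingConfig` overrides, assuming a placeholder model name and data directory; `two_tower_generator`, which Electric requires, is also enabled explicitly here since keyword overrides are applied after the `electric_objective` check shown further down:

```
import configure_pretraining

# Placeholder model name and data directory; adjust to your setup.
config = configure_pretraining.PretrainingConfig(
    model_name="electric_small",
    data_dir="data/",
    # Select the Electric objective instead of ELECTRA's.
    electra_objective=False,
    electric_objective=True,
    # Electric is trained with a two-tower cloze model as its generator.
    two_tower_generator=True,
)
```

Equivalently, these settings can be passed to `run_pretraining.py` through its `--hparams` JSON argument.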
+## Electric + +To train [Electric](https://www.aclweb.org/anthology/2020.emnlp-main.20.pdf), use the same pre-training script and command as ELECTRA. Pass `"electra_objective": false` and `"electric_objective": true` to the hyperparameters. We plan to release pre-trained Electric models soon! + ## Citation If you use this code for your publication, please cite the original paper: ``` @@ -173,6 +179,17 @@ If you use this code for your publication, please cite the original paper: } ``` +If you use the code for Electric, please cite the Electric paper: +``` +@inproceedings{clark2020electric, + title = {Pre-Training Transformers as Energy-Based Cloze Models}, + author = {Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning}, + booktitle = {EMNLP}, + year = {2020}, + url = {https://www.aclweb.org/anthology/2020.emnlp-main.20.pdf} +} +``` + ## Contact Info For help or issues using ELECTRA, please submit a GitHub issue. diff --git a/configure_pretraining.py b/configure_pretraining.py index f576563..1a09649 100644 --- a/configure_pretraining.py +++ b/configure_pretraining.py @@ -32,7 +32,9 @@ def __init__(self, model_name, data_dir, **kwargs): self.do_eval = False # evaluate generator/discriminator on unlabeled data # loss functions - self.electra_objective = True # if False, use the BERT objective instead + # train ELECTRA or Electric? if both are false, trains a masked LM like BERT + self.electra_objective = True + self.electric_objective = False self.gen_weight = 1.0 # masked language modeling / generator loss self.disc_weight = 50.0 # discriminator loss self.mask_prob = 0.15 # percent of input tokens to mask out / replace @@ -65,6 +67,7 @@ def __init__(self, model_name, data_dir, **kwargs): # generator settings self.uniform_generator = False # generator is uniform at random + self.two_tower_generator = False # generator is a two-tower cloze model self.untied_generator_embeddings = False # tie generator/discriminator # token embeddings? self.untied_generator = True # tie all generator/discriminator weights? @@ -127,6 +130,8 @@ def __init__(self, model_name, data_dir, **kwargs): # self.embedding_size = 1024 # self.mask_prob = 0.25 # self.train_batch_size = 2048 + if self.electric_objective: + self.two_tower_generator = True # electric requires a two-tower generator # passed-in-arguments override (for example) debug-mode defaults self.update(kwargs) diff --git a/model/modeling.py b/model/modeling.py index f2a2030..b5f6704 100644 --- a/model/modeling.py +++ b/model/modeling.py @@ -146,7 +146,9 @@ def __init__(self, input_embeddings=None, input_reprs=None, update_embeddings=True, - untied_embeddings=False): + untied_embeddings=False, + ltr=False, + rtl=False): """Constructor for BertModel. Args: @@ -232,6 +234,16 @@ def __init__(self, attention_mask = create_attention_mask_from_input_mask( token_type_ids, input_mask) + # Add causal masking to the attention for running the transformer + # left-to-right or right-to-left + if ltr or rtl: + causal_mask = tf.ones((seq_length, seq_length)) + if ltr: + causal_mask = tf.matrix_band_part(causal_mask, -1, 0) + else: + causal_mask = tf.matrix_band_part(causal_mask, 0, -1) + attention_mask *= tf.expand_dims(causal_mask, 0) + # Run the stacked transformer. 
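The `ltr`/`rtl` arguments above overlay a causal mask on the usual padding-based attention mask so the transformer can be run strictly left-to-right or right-to-left; Electric's two-tower generator builds one transformer of each kind. A minimal sketch of what the two `tf.matrix_band_part` calls produce, assuming the TF 1.x graph-mode API the repository targets and a toy sequence length of 4:

```
import tensorflow.compat.v1 as tf

seq_length = 4
ones = tf.ones((seq_length, seq_length))
# Keep the lower triangle (including the diagonal): position i attends to j <= i.
ltr_mask = tf.matrix_band_part(ones, -1, 0)
# Keep the upper triangle (including the diagonal): position i attends to j >= i.
rtl_mask = tf.matrix_band_part(ones, 0, -1)

with tf.Session() as sess:
  print(sess.run(ltr_mask))
  # [[1. 0. 0. 0.]
  #  [1. 1. 0. 0.]
  #  [1. 1. 1. 0.]
  #  [1. 1. 1. 1.]]
  print(sess.run(rtl_mask))
  # [[1. 1. 1. 1.]
  #  [0. 1. 1. 1.]
  #  [0. 0. 1. 1.]
  #  [0. 0. 0. 1.]]
```

In `run_pretraining.py` further down, `TwoTowerClozeTransformer` concatenates the two towers' outputs after shifting the left-to-right states one position to the right and the right-to-left states one position to the left (via `roll`), so the representation used to predict position t is built from the states at positions t-1 and t+1 rather than from the token at position t itself.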
Output shapes # sequence_output: [batch_size, seq_length, hidden_size] # pooled_output: [batch_size, hidden_size] diff --git a/pretrain/pretrain_helpers.py b/pretrain/pretrain_helpers.py index 7329bec..74c6de3 100644 --- a/pretrain/pretrain_helpers.py +++ b/pretrain/pretrain_helpers.py @@ -115,9 +115,23 @@ def scatter_update(sequence, updates, positions): return updated_sequence, updates_mask -def _get_candidates_mask(inputs: pretrain_data.Inputs, vocab, - disallow_from_mask=None): +VOCAB_MAPPING = {} + + +def get_vocab(config: configure_pretraining.PretrainingConfig): + """Memoized load of the vocab file.""" + if config.vocab_file not in VOCAB_MAPPING: + vocab = tokenization.FullTokenizer( + config.vocab_file, do_lower_case=True).vocab + VOCAB_MAPPING[config.vocab_file] = vocab + return VOCAB_MAPPING[config.vocab_file] + + +def get_candidates_mask(config: configure_pretraining.PretrainingConfig, + inputs: pretrain_data.Inputs, + disallow_from_mask=None): """Returns a mask tensor of positions in the input that can be masked out.""" + vocab = get_vocab(config) ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"]] candidates_mask = tf.ones_like(inputs.input_ids, tf.bool) for ignore_id in ignore_ids: @@ -152,9 +166,8 @@ def mask(config: configure_pretraining.PretrainingConfig, B, L = modeling.get_shape_list(inputs.input_ids) # Find indices where masking out a token is allowed - vocab = tokenization.FullTokenizer( - config.vocab_file, do_lower_case=config.do_lower_case).vocab - candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask) + vocab = get_vocab(config) + candidates_mask = get_candidates_mask(config, inputs, disallow_from_mask) # Set the number of tokens to mask out per example num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32) diff --git a/run_pretraining.py b/run_pretraining.py index 599c9f9..8859458 100644 --- a/run_pretraining.py +++ b/run_pretraining.py @@ -49,40 +49,62 @@ def __init__(self, config: configure_pretraining.PretrainingConfig, self._bert_config.num_attention_heads = 4 # Mask the input + unmasked_inputs = pretrain_data.features_to_inputs(features) masked_inputs = pretrain_helpers.mask( - config, pretrain_data.features_to_inputs(features), config.mask_prob) + config, unmasked_inputs, config.mask_prob) # Generator embedding_size = ( self._bert_config.hidden_size if config.embedding_size is None else config.embedding_size) + cloze_output = None if config.uniform_generator: + # simple generator sampling fakes uniformly at random mlm_output = self._get_masked_lm_output(masked_inputs, None) - elif config.electra_objective and config.untied_generator: - generator = self._build_transformer( - masked_inputs, is_training, - bert_config=get_generator_config(config, self._bert_config), - embedding_size=(None if config.untied_generator_embeddings - else embedding_size), - untied_embeddings=config.untied_generator_embeddings, - name="generator") - mlm_output = self._get_masked_lm_output(masked_inputs, generator) + elif ((config.electra_objective or config.electric_objective) + and config.untied_generator): + generator_config = get_generator_config(config, self._bert_config) + if config.two_tower_generator: + # two-tower cloze model generator used for electric + generator = TwoTowerClozeTransformer( + config, generator_config, unmasked_inputs, is_training, + embedding_size) + cloze_output = self._get_cloze_outputs(unmasked_inputs, generator) + mlm_output = get_softmax_output( + pretrain_helpers.gather_positions( + cloze_output.logits, 
masked_inputs.masked_lm_positions), + masked_inputs.masked_lm_ids, masked_inputs.masked_lm_weights, + self._bert_config.vocab_size) + else: + # small masked language model generator + generator = build_transformer( + config, masked_inputs, is_training, generator_config, + embedding_size=(None if config.untied_generator_embeddings + else embedding_size), + untied_embeddings=config.untied_generator_embeddings, + scope="generator") + mlm_output = self._get_masked_lm_output(masked_inputs, generator) else: - generator = self._build_transformer( - masked_inputs, is_training, embedding_size=embedding_size) + # full-sized masked language model generator if using BERT objective or if + # the generator and discriminator have tied weights + generator = build_transformer( + config, masked_inputs, is_training, self._bert_config, + embedding_size=embedding_size) mlm_output = self._get_masked_lm_output(masked_inputs, generator) fake_data = self._get_fake_data(masked_inputs, mlm_output.logits) self.mlm_output = mlm_output - self.total_loss = config.gen_weight * mlm_output.loss + self.total_loss = config.gen_weight * ( + cloze_output.loss if config.two_tower_generator else mlm_output.loss) # Discriminator disc_output = None - if config.electra_objective: - discriminator = self._build_transformer( - fake_data.inputs, is_training, reuse=not config.untied_generator, - embedding_size=embedding_size) + if config.electra_objective or config.electric_objective: + discriminator = build_transformer( + config, fake_data.inputs, is_training, self._bert_config, + reuse=not config.untied_generator, embedding_size=embedding_size) disc_output = self._get_discriminator_output( - fake_data.inputs, discriminator, fake_data.is_fake_tokens) + fake_data.inputs, discriminator, fake_data.is_fake_tokens, + cloze_output) self.total_loss += config.disc_weight * disc_output.loss # Evaluation @@ -94,7 +116,7 @@ def __init__(self, config: configure_pretraining.PretrainingConfig, "masked_lm_weights": masked_inputs.masked_lm_weights, "input_mask": masked_inputs.input_mask } - if config.electra_objective: + if config.electra_objective or config.electric_objective: eval_fn_inputs.update({ "disc_loss": disc_output.per_example_loss, "disc_labels": disc_output.labels, @@ -117,7 +139,7 @@ def metric_fn(*args): metrics["masked_lm_loss"] = tf.metrics.mean( values=tf.reshape(d["mlm_loss"], [-1]), weights=tf.reshape(d["masked_lm_weights"], [-1])) - if config.electra_objective: + if config.electra_objective or config.electric_objective: metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy( labels=tf.reshape(d["masked_lm_ids"], [-1]), predictions=tf.reshape(d["sampled_tokids"], [-1]), @@ -141,7 +163,6 @@ def metric_fn(*args): def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model): """Masked language modeling softmax layer.""" - masked_lm_weights = inputs.masked_lm_weights with tf.variable_scope("generator_predictions"): if self._config.uniform_generator: logits = tf.zeros(self._bert_config.vocab_size) @@ -151,43 +172,16 @@ def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model): logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size]) logits = logits_tiled else: - relevant_hidden = pretrain_helpers.gather_positions( + relevant_reprs = pretrain_helpers.gather_positions( model.get_sequence_output(), inputs.masked_lm_positions) - hidden = tf.layers.dense( - relevant_hidden, - units=modeling.get_shape_list(model.get_embedding_table())[-1], - 
activation=modeling.get_activation(self._bert_config.hidden_act), - kernel_initializer=modeling.create_initializer( - self._bert_config.initializer_range)) - hidden = modeling.layer_norm(hidden) - output_bias = tf.get_variable( - "output_bias", - shape=[self._bert_config.vocab_size], - initializer=tf.zeros_initializer()) - logits = tf.matmul(hidden, model.get_embedding_table(), - transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - - oh_labels = tf.one_hot( - inputs.masked_lm_ids, depth=self._bert_config.vocab_size, - dtype=tf.float32) - - probs = tf.nn.softmax(logits) - log_probs = tf.nn.log_softmax(logits) - label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1) - - numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs) - denominator = tf.reduce_sum(masked_lm_weights) + 1e-6 - loss = numerator / denominator - preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32) - - MLMOutput = collections.namedtuple( - "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"]) - return MLMOutput( - logits=logits, probs=probs, per_example_loss=label_log_probs, - loss=loss, preds=preds) - - def _get_discriminator_output(self, inputs, discriminator, labels): + logits = get_token_logits( + relevant_reprs, model.get_embedding_table(), self._bert_config) + return get_softmax_output( + logits, inputs.masked_lm_ids, inputs.masked_lm_weights, + self._bert_config.vocab_size) + + def _get_discriminator_output( + self, inputs, discriminator, labels, cloze_output=None): """Discriminator binary classifier.""" with tf.variable_scope("discriminator_predictions"): hidden = tf.layers.dense( @@ -197,6 +191,15 @@ def _get_discriminator_output(self, inputs, discriminator, labels): kernel_initializer=modeling.create_initializer( self._bert_config.initializer_range)) logits = tf.squeeze(tf.layers.dense(hidden, units=1), -1) + if self._config.electric_objective: + log_q = tf.reduce_sum( + tf.nn.log_softmax(cloze_output.logits) * tf.one_hot( + inputs.input_ids, depth=self._bert_config.vocab_size, + dtype=tf.float32), -1) + log_q = tf.stop_gradient(log_q) + logits += log_q + logits += tf.log(self._config.mask_prob / (1 - self._config.mask_prob)) + weights = tf.cast(inputs.input_mask, tf.float32) labelsf = tf.cast(labels, tf.float32) losses = tf.nn.sigmoid_cross_entropy_with_logits( @@ -225,8 +228,11 @@ def _get_fake_data(self, inputs, mlm_logits): sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32) updated_input_ids, masked = pretrain_helpers.scatter_update( inputs.input_ids, sampled_tokids, inputs.masked_lm_positions) - labels = masked * (1 - tf.cast( - tf.equal(updated_input_ids, inputs.input_ids), tf.int32)) + if self._config.electric_objective: + labels = masked + else: + labels = masked * (1 - tf.cast( + tf.equal(updated_input_ids, inputs.input_ids), tf.int32)) updated_inputs = pretrain_data.get_updated_inputs( inputs, input_ids=updated_input_ids) FakedData = collections.namedtuple("FakedData", [ @@ -234,21 +240,99 @@ def _get_fake_data(self, inputs, mlm_logits): return FakedData(inputs=updated_inputs, is_fake_tokens=labels, sampled_tokens=sampled_tokens) - def _build_transformer(self, inputs: pretrain_data.Inputs, is_training, - bert_config=None, name="electra", reuse=False, **kwargs): - """Build a transformer encoder network.""" - if bert_config is None: - bert_config = self._bert_config - with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): - return modeling.BertModel( - bert_config=bert_config, - is_training=is_training, - 
input_ids=inputs.input_ids, - input_mask=inputs.input_mask, - token_type_ids=inputs.segment_ids, - use_one_hot_embeddings=self._config.use_tpu, - scope=name, - **kwargs) + def _get_cloze_outputs(self, inputs: pretrain_data.Inputs, model): + """Cloze model softmax layer.""" + weights = tf.cast(pretrain_helpers.get_candidates_mask( + self._config, inputs), tf.float32) + with tf.variable_scope("cloze_predictions"): + logits = get_token_logits(model.get_sequence_output(), + model.get_embedding_table(), self._bert_config) + return get_softmax_output(logits, inputs.input_ids, weights, + self._bert_config.vocab_size) + + +def get_token_logits(input_reprs, embedding_table, bert_config): + hidden = tf.layers.dense( + input_reprs, + units=modeling.get_shape_list(embedding_table)[-1], + activation=modeling.get_activation(bert_config.hidden_act), + kernel_initializer=modeling.create_initializer( + bert_config.initializer_range)) + hidden = modeling.layer_norm(hidden) + output_bias = tf.get_variable( + "output_bias", + shape=[bert_config.vocab_size], + initializer=tf.zeros_initializer()) + logits = tf.matmul(hidden, embedding_table, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + return logits + + +def get_softmax_output(logits, targets, weights, vocab_size): + oh_labels = tf.one_hot(targets, depth=vocab_size, dtype=tf.float32) + preds = tf.argmax(logits, axis=-1, output_type=tf.int32) + probs = tf.nn.softmax(logits) + log_probs = tf.nn.log_softmax(logits) + label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1) + numerator = tf.reduce_sum(weights * label_log_probs) + denominator = tf.reduce_sum(weights) + 1e-6 + loss = numerator / denominator + SoftmaxOutput = collections.namedtuple( + "SoftmaxOutput", ["logits", "probs", "loss", "per_example_loss", "preds", + "weights"]) + return SoftmaxOutput( + logits=logits, probs=probs, per_example_loss=label_log_probs, + loss=loss, preds=preds, weights=weights) + + +class TwoTowerClozeTransformer(object): + """Build a two-tower Transformer used as Electric's generator.""" + + def __init__(self, config, bert_config, inputs: pretrain_data.Inputs, + is_training, embedding_size): + ltr = build_transformer( + config, inputs, is_training, bert_config, + untied_embeddings=config.untied_generator_embeddings, + embedding_size=(None if config.untied_generator_embeddings + else embedding_size), + scope="generator_ltr", ltr=True) + rtl = build_transformer( + config, inputs, is_training, bert_config, + untied_embeddings=config.untied_generator_embeddings, + embedding_size=(None if config.untied_generator_embeddings + else embedding_size), + scope="generator_rtl", rtl=True) + ltr_reprs = ltr.get_sequence_output() + rtl_reprs = rtl.get_sequence_output() + self._sequence_output = tf.concat([roll(ltr_reprs, -1), + roll(rtl_reprs, 1)], -1) + self._embedding_table = ltr.embedding_table + + def get_sequence_output(self): + return self._sequence_output + + def get_embedding_table(self): + return self._embedding_table + + +def build_transformer(config: configure_pretraining.PretrainingConfig, + inputs: pretrain_data.Inputs, is_training, + bert_config, reuse=False, **kwargs): + """Build a transformer encoder network.""" + with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): + return modeling.BertModel( + bert_config=bert_config, + is_training=is_training, + input_ids=inputs.input_ids, + input_mask=inputs.input_mask, + token_type_ids=inputs.segment_ids, + use_one_hot_embeddings=config.use_tpu, + **kwargs) + + +def roll(arr, direction): + """Shifts 
embeddings in a [batch, seq_len, dim] tensor to the right/left.""" + return tf.concat([arr[:, direction:, :], arr[:, :direction, :]], axis=1) def get_generator_config(config: configure_pretraining.PretrainingConfig,