
Commit

Update config management
VHellendoorn committed Jul 11, 2020
1 parent 781fafa commit 67664c3
Showing 7 changed files with 87 additions and 83 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -8,10 +8,10 @@ The modeling code is written in Python (3.6+) and uses Tensorflow (recommended 2

To run training, first clone the [data repository](https://github.com/google-research-datasets/great) and note its location (let's call it `*data_dir*`). Then, from the main directory of this repository, run `python train.py *data_dir* vocab.txt config.yml` to train the model configuration specified in `config.yml`, periodically writing checkpoints (to `models/`) and evaluation results (to `log.txt`). Both output paths can optionally be set with `-m` and `-l` respectively.

To customize the model configuration, you can change both the hyper-parameters for the various model types available (transformer, GREAT, GGNN, RNN) in `config.yml`, and the overall model architecture itself under `training: model`. For instance, to train the RNN Sandwich architecture from our paper, set the RNN and GGNN layers to reasonable values (e.g. RNN to 2 layers and the GGNN's `time_steps` to \[3, 1\] layers as in the paper) and specify the model configuration: `rnn ggnn rnn ggnn rnn`.
To customize the model configuration, you can change both the hyper-parameters for the various model types available (transformer, GREAT, GGNN, RNN) in `config.yml`, and the overall model architecture itself under `model: configuration`. For instance, to train the RNN Sandwich architecture from our paper, set the RNN and GGNN layers to reasonable values (e.g. RNN to 2 layers and the GGNN's `time_steps` to \[3, 1\] layers as in the paper) and specify the model configuration: `rnn ggnn rnn ggnn rnn`.
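
For concreteness, here is a minimal sketch of that sandwich setup against the new `config.yml` layout introduced in this commit; how `train.py` consumes the parsed dict is assumed here, and PyYAML is used only as an illustrative way to load the file:

```python
import yaml  # PyYAML; assumed only for illustration

with open('config.yml') as f:
    config = yaml.safe_load(f)

# RNN Sandwich from the paper: alternate RNN and GGNN blocks.
config['model']['configuration'] = 'rnn ggnn rnn ggnn rnn'
config['model']['rnn']['num_layers'] = 2        # 2-layer RNNs
config['model']['ggnn']['time_steps'] = [3, 1]  # GGNN blocks as in the paper
```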

## Status (07/09/2020)
Update: as of July 9th, 2020, the data has been [released](https://github.com/google-research-datasets/great). I reconstructed the data loading & model running setup today (and fixed some bugs in the [models](#code)) and am currently running the various benchmarks from the paper. There are probably still a few small bugs in the code, but the general setup from the paper works: just modify the architecture(s) in config.yml, especially the model description under `training: model` to any configuration as desired (e.g. `great`, `rnn ggnn rnn ggnn`, etc.).
Update: as of July 9th, 2020, the data has been [released](https://github.com/google-research-datasets/great). I reconstructed the data loading & model running setup today (and fixed some bugs in the [models](#code)) and am currently running the various benchmarks from the paper. There are probably still a few small bugs in the code, but the general setup from the paper works: just modify the architecture(s) in config.yml, especially the model description under `model: configuration` to any configuration as desired (e.g. `great`, `rnn ggnn rnn ggnn`, etc.).

## Data
The data for this project consists of up to three bugs per function for every function in the re-releasable subset of the Py150 corpus, paired with the original, non-buggy code. This data is now publicly available from [https://github.com/google-research-datasets/great](https://github.com/google-research-datasets/great).
40 changes: 19 additions & 21 deletions config.yml
@@ -1,32 +1,30 @@
ggnn:
hidden_dim: 512
time_steps: [3, 1, 3, 1]
residuals:
1: [0]
3: [0, 1]
dropout_rate: 0.1
add_type_bias: true
transformer:
hidden_dim: 512
ff_dim: 2048
num_layers: 6
attention_dim: 512
num_heads: 8
dropout_rate: 0.1
rnn:
hidden_dim: 512
num_layers: 2
dropout_rate: 0.1
model:
configuration: "great"
base:
hidden_dim: 512
dropout_rate: 0.1
num_edge_types: 24
rnn:
num_layers: 2
ggnn:
time_steps: [3, 1, 3, 1]
residuals:
1: [0]
3: [0, 1]
add_type_bias: true
transformer:
ff_dim: 2048
num_layers: 6
attention_dim: 512
num_heads: 8
data:
max_batch_size: 12500
max_buffer_size: 50
max_sequence_length: 512
num_edge_types: 24
valid_interval: 250000
max_valid_samples: 25000
max_token_length: 10
training:
max_steps: 100
print_freq: 25
learning_rate: 0.0001
model: "great"
26 changes: 13 additions & 13 deletions models/ggnn.py
@@ -2,25 +2,25 @@


class GGNN(tf.keras.layers.Layer):
def __init__(self, model_config, num_edge_types, shared_embedding=None, vocab_dim=None):
def __init__(self, model_config, shared_embedding=None, vocab_dim=None):
super(GGNN, self).__init__()
self.num_edge_types = num_edge_types
# The main GGNN configuration is provided as a list of "time-steps", which describes how often each layer is repeated.
self.num_edge_types = model_config['num_edge_types']
# The main GGNN configuration is provided as a list of 'time-steps', which describes how often each layer is repeated.
# E.g., an 8-step GGNN with 4 distinct layers, repeated 3 and 1 times alternately, can be represented as [3, 1, 3, 1]
self.time_steps = model_config["time_steps"]
self.time_steps = model_config['time_steps']
self.num_layers = len(self.time_steps)
# The residuals index into the time-steps above, offset by one (index 0 refers to the node embeddings).
# They describe short-cuts formatted as receiving layer: [sending layer] entries, e.g., {1: [0], 3: [0, 1]}
self.residuals = {str(k):v for k, v in model_config["residuals"].items()} # Keys must be strings for TF checkpointing
self.hidden_dim = model_config["hidden_dim"]
self.add_type_bias = model_config["add_type_bias"]
self.dropout_rate = model_config["dropout_rate"]
self.residuals = {str(k):v for k, v in model_config['residuals'].items()} # Keys must be strings for TF checkpointing
self.hidden_dim = model_config['hidden_dim']
self.add_type_bias = model_config['add_type_bias']
self.dropout_rate = model_config['dropout_rate']

# Initialize embedding variable in constructor to allow reuse by other models
if shared_embedding is not None:
self.embed = shared_embedding
elif vocab_dim is None:
raise ValueError("Pass either a vocabulary dimension or an embedding Variable")
raise ValueError('Pass either a vocabulary dimension or an embedding Variable')
else:
random_init = tf.random_normal_initializer(stddev=self.hidden_dim ** -0.5)
self.embed = tf.Variable(random_init([vocab_dim, self.hidden_dim]), dtype=tf.float32)
@@ -34,8 +34,8 @@ def make_bias(name=None):
return tf.Variable(random_init([self.hidden_dim]), name=name)

# Set up type-transforms and GRUs
self.type_weights = [[make_weight("type-" + str(j) + "-" + str(i)) for i in range(self.num_edge_types)] for j in range(self.num_layers)]
self.type_biases = [[make_bias("bias-" + str(j) + "-" + str(i)) for i in range(self.num_edge_types)] for j in range(self.num_layers)]
self.type_weights = [[make_weight('type-' + str(j) + '-' + str(i)) for i in range(self.num_edge_types)] for j in range(self.num_layers)]
self.type_biases = [[make_bias('bias-' + str(j) + '-' + str(i)) for i in range(self.num_edge_types)] for j in range(self.num_layers)]
self.rnns = [tf.keras.layers.GRUCell(self.hidden_dim) for _ in range(self.num_layers)]
for ix, rnn in enumerate(self.rnns):
# Initialize the GRUs input dimension based on whether any residuals will be passed in.
@@ -45,7 +45,7 @@ def make_bias(name=None):
rnn.build(self.hidden_dim)

# Assume 'inputs' is an embedded batched sequence, 'edge_ids' is a sparse list of indices formatted as: [edge_type, batch_index, source_index, target_index].
#@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, None), dtype=tf.float32), tf.TensorSpec(shape=(None, 4), dtype=tf.int32), tf.TensorSpec(shape=(), dtype=tf.bool)])
@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, None), dtype=tf.float32), tf.TensorSpec(shape=(None, 4), dtype=tf.int32), tf.TensorSpec(shape=(), dtype=tf.bool)])
def call(self, states, edge_ids, training):
# Collect some basic details about the graphs in the batch.
edge_type_ids = tf.dynamic_partition(edge_ids[:, 1:], edge_ids[:, 0], self.num_edge_types)
@@ -96,5 +96,5 @@ def propagate(self, in_states, layer_no, edge_type_ids, message_sources, message
@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])
def embed_inputs(self, inputs):
states = tf.nn.embedding_lookup(self.embed, inputs)
states *= tf.math.sqrt(tf.cast(tf.shape(states)[-1], "float32"))
states *= tf.math.sqrt(tf.cast(tf.shape(states)[-1], 'float32'))
return states
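
A minimal usage sketch for the updated constructor, with config values mirroring the `model` section of the new `config.yml` (the import path and the vocabulary size are assumptions, and the forward call relies on the elided body of `call` matching the signature shown above):

```python
import tensorflow as tf
from models.ggnn import GGNN  # import path assumed from the file location

# Mirrors the 'model' section of the new config.yml, with 'base' merged in.
ggnn_config = {
    'num_edge_types': 24, 'hidden_dim': 512, 'dropout_rate': 0.1,
    'time_steps': [3, 1, 3, 1], 'residuals': {1: [0], 3: [0, 1]},
    'add_type_bias': True,
}
ggnn = GGNN(ggnn_config, vocab_dim=10000)  # vocabulary size is illustrative

tokens = tf.zeros((2, 50), dtype=tf.int32)     # [batch, sequence]
states = ggnn.embed_inputs(tokens)             # [batch, sequence, hidden_dim]
# One edge per row: [edge_type, batch_index, source_index, target_index].
edges = tf.constant([[0, 0, 1, 4], [3, 1, 7, 2]], dtype=tf.int32)
out = ggnn(states, edges, tf.constant(False))  # training=False: no dropout
```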
41 changes: 20 additions & 21 deletions models/great_transformer.py
@@ -4,8 +4,7 @@


class AttentionLayer(tf.keras.layers.Layer):
"""
Implementation of multi-headed attention with optional edge-bias.
"""Implementation of multi-headed attention with optional edge-bias.
This class supports self-attention and key-value attention, with (non-optional) masks. If bias_dim is not None, the attention computation(s) assumes that a (sparse) bias vector is provided, formatted like: (edge_type, batch_index, key_index, query_index). Bias edge types are embedded in the same dimension as each head's attention and projected to a scalar before being inserted into the attention computation as (q + b) * k.
"""
@@ -19,13 +18,13 @@ def __init__(self, attention_dim, num_heads=None, hidden_dim=None, bias_dim=None
self.bias_dim = bias_dim

def build(self, _):
self.attn_query = self.add_weight(name='q', shape=(self.hidden_dim, self.num_heads, self.attention_dim_per_head), initializer="glorot_uniform")
self.attn_keys = self.add_weight(name='k', shape=(self.hidden_dim, self.num_heads, self.attention_dim_per_head), initializer="glorot_uniform")
self.attn_values = self.add_weight(name='v', shape=(self.hidden_dim, self.num_heads, self.attention_dim_per_head), initializer="glorot_uniform")
self.weight_out = self.add_weight(name='o', shape=(self.num_heads, self.attention_dim_per_head, self.hidden_dim), initializer="glorot_uniform")
self.attn_query = self.add_weight(name='q', shape=(self.hidden_dim, self.num_heads, self.attention_dim_per_head), initializer='glorot_uniform')
self.attn_keys = self.add_weight(name='k', shape=(self.hidden_dim, self.num_heads, self.attention_dim_per_head), initializer='glorot_uniform')
self.attn_values = self.add_weight(name='v', shape=(self.hidden_dim, self.num_heads, self.attention_dim_per_head), initializer='glorot_uniform')
self.weight_out = self.add_weight(name='o', shape=(self.num_heads, self.attention_dim_per_head, self.hidden_dim), initializer='glorot_uniform')
if self.bias_dim is not None:
self.bias_embs = self.add_weight(name='e1', shape=(self.bias_dim, self.attention_dim_per_head), initializer="glorot_uniform")
self.bias_scalar = self.add_weight(name='e2', shape=(self.attention_dim_per_head, 1), initializer="glorot_uniform")
self.bias_embs = self.add_weight(name='e1', shape=(self.bias_dim, self.attention_dim_per_head), initializer='glorot_uniform')
self.bias_scalar = self.add_weight(name='e2', shape=(self.attention_dim_per_head, 1), initializer='glorot_uniform')

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, None), dtype=tf.float32), tf.TensorSpec(shape=(None, None, None), dtype=tf.float32), tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32), tf.TensorSpec(shape=(None, 4), dtype=tf.int32)])
def call(self, states, key_states, masks, attention_bias):
@@ -71,7 +70,7 @@ def get_attention_weights(self, query, keys, masks, attention_bias):
alpha += bias

# Scale and apply mask
alpha *= tf.math.rsqrt(tf.cast(self.attention_dim_per_head, "float32"))
alpha *= tf.math.rsqrt(tf.cast(self.attention_dim_per_head, 'float32'))
alpha = alpha * masks + (1.0 - tf.math.ceil(masks)) * tf.float32.min
alpha = tf.nn.softmax(alpha)
alpha *= masks
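
A small sketch of the sparse bias format this attention expects, matching the docstring above and the `NOOP_BIAS` placeholder defined further down (the concrete values are made up):

```python
import tensorflow as tf

# One row per edge, formatted (edge_type, batch_index, key_index, query_index).
attention_bias = tf.constant([
    [2, 0, 5, 3],   # edge of type 2 relating key position 5 and query position 3 in batch 0
    [7, 1, 0, 9],   # edge of type 7 relating key position 0 and query position 9 in batch 1
], dtype=tf.int32)

# Without edge information, an empty placeholder is passed instead (cf. NOOP_BIAS below).
no_bias = tf.zeros((0, 4), dtype=tf.int32)
```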
@@ -97,26 +96,26 @@ class Transformer(tf.keras.layers.Layer):
"""Transformer language model: converts indices into hidden states through layers of multi-headed attention and feed-forward dense layers.
Augments a generic Transformer with attentional bias, if bias_dim is provided. See documentation on AttentionLayer for more details.
To generate language from the resulting states, pass the states to the "predict" function. Note that it assumes that the input vocabulary is the output vocabulary (i.e., it reuses the model's embedding table).
To generate language from the resulting states, pass the states to the 'predict' function. Note that it assumes that the input vocabulary is the output vocabulary (i.e., it reuses the model's embedding table).
"""
NOOP_BIAS = tf.zeros((0, 4), 'int32')

def __init__(self, model_config, bias_dim=None, shared_embedding=None, vocab_dim=None, is_encoder_decoder=False):
def __init__(self, model_config, shared_embedding=None, vocab_dim=None, is_encoder_decoder=False):
super(Transformer, self).__init__()
self.bias_dim = bias_dim
self.is_encoder_decoder = is_encoder_decoder
self.hidden_dim = model_config["hidden_dim"]
self.ff_dim = model_config["ff_dim"]
self.attention_dim = model_config["attention_dim"]
self.num_layers = model_config["num_layers"]
self.num_heads = model_config["num_heads"]
self.dropout_rate = model_config["dropout_rate"]
self.bias_dim = model_config['num_edge_types']
self.hidden_dim = model_config['hidden_dim']
self.ff_dim = model_config['ff_dim']
self.attention_dim = model_config['attention_dim']
self.num_layers = model_config['num_layers']
self.num_heads = model_config['num_heads']
self.dropout_rate = model_config['dropout_rate']

# Initialize embedding variable in constructor to allow reuse by other models
if shared_embedding is not None:
self.embed = shared_embedding
elif vocab_dim is None:
raise ValueError("Pass either a vocabulary dimension or an embedding Variable")
raise ValueError('Pass either a vocabulary dimension or an embedding Variable')
else:
random_init = tf.random_normal_initializer(stddev=self.hidden_dim ** -0.5)
self.embed = tf.Variable(random_init([vocab_dim, self.hidden_dim]), dtype=tf.float32)
@@ -136,7 +135,7 @@ def build(self, _):
self.ln_out = LayerNormalization(self.hidden_dim)

# Two-layer feed-forward with wide layer in the middle
self.ff_1 = [tf.keras.layers.Dense(self.ff_dim, activation="relu") for _ in range(self.num_layers)]
self.ff_1 = [tf.keras.layers.Dense(self.ff_dim, activation='relu') for _ in range(self.num_layers)]
self.ff_2 = [tf.keras.layers.Dense(self.hidden_dim) for _ in range(self.num_layers)]

# Default 'call' applies standard self-attention, with dropout if training=True.
@@ -183,7 +182,7 @@ def enc_dec_attention(self, states, masks, key_states, key_masks, attention_bias
@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])
def embed_inputs(self, inputs):
states = tf.nn.embedding_lookup(self.embed, inputs)
states *= tf.math.sqrt(tf.cast(tf.shape(states)[-1], "float32"))
states *= tf.math.sqrt(tf.cast(tf.shape(states)[-1], 'float32'))
states += self.pos_enc[:tf.shape(states)[1]]
return states

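A minimal construction sketch for the updated `Transformer`, which after this commit takes the edge-bias dimension from `num_edge_types` in its config rather than a separate `bias_dim` argument (the import path and vocabulary size are assumptions):

```python
from models.great_transformer import Transformer  # import path assumed from the file location

# Mirrors the 'model' section of the new config.yml, with 'base' merged in.
transformer_config = {
    'num_edge_types': 24, 'hidden_dim': 512, 'dropout_rate': 0.1,
    'ff_dim': 2048, 'num_layers': 6, 'attention_dim': 512, 'num_heads': 8,
}
great = Transformer(transformer_config, vocab_dim=10000)  # vocabulary size is illustrative
```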
10 changes: 5 additions & 5 deletions models/rnn.py
@@ -4,15 +4,15 @@
class RNN(tf.keras.layers.Layer):
def __init__(self, model_config, shared_embedding=None, vocab_dim=None):
super(RNN, self).__init__()
self.hidden_dim = model_config["hidden_dim"]
self.num_layers = model_config["num_layers"]
self.dropout_rate = model_config["dropout_rate"]
self.hidden_dim = model_config['hidden_dim']
self.num_layers = model_config['num_layers']
self.dropout_rate = model_config['dropout_rate']

# Initialize embedding variable in constructor to allow reuse by other models
if shared_embedding is not None:
self.embed = shared_embedding
elif vocab_dim is None:
raise ValueError("Pass either a vocabulary dimension or an embedding Variable")
raise ValueError('Pass either a vocabulary dimension or an embedding Variable')
else:
random_init = tf.random_normal_initializer(stddev=self.hidden_dim ** -0.5)
self.embed = tf.Variable(random_init([vocab_dim, self.hidden_dim]), dtype=tf.float32)
@@ -37,5 +37,5 @@ def call(self, states, training):
@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])
def embed_inputs(self, inputs):
states = tf.nn.embedding_lookup(self.embed, inputs)
states *= tf.math.sqrt(tf.cast(tf.shape(states)[-1], "float32"))
states *= tf.math.sqrt(tf.cast(tf.shape(states)[-1], 'float32'))
return states
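
All constructors in this commit accept a `shared_embedding`, so sub-models in a sandwich configuration can reuse one embedding table; a minimal sketch of both options (import path, vocabulary size, and the sharing pattern are assumptions):

```python
import tensorflow as tf
from models.rnn import RNN  # import path assumed from the file location

rnn_config = {'hidden_dim': 512, 'num_layers': 2, 'dropout_rate': 0.1}

# Either let the layer create its own embedding table from a vocabulary size...
rnn = RNN(rnn_config, vocab_dim=10000)

# ...or share an existing table between sub-models, as the constructor comment suggests.
shared = tf.Variable(tf.random.normal([10000, 512], stddev=512 ** -0.5))
rnn_shared = RNN(rnn_config, shared_embedding=shared)
```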
