Masked language modeling example (#2434)
tbennun authored Jul 6, 2024
1 parent 93e04a8 commit 49931d5
Showing 9 changed files with 1,028 additions and 28 deletions.
121 changes: 121 additions & 0 deletions applications/nlp/transformer/datasets/thepile_mlm.py
@@ -0,0 +1,121 @@
"""
The Pile dataset, stored as pre-tokenized binary files for optimized processing.
"""
import os
import os.path

import numpy as np
# ----------------------------------------------
# Options
# ----------------------------------------------

sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))
mlm_probability = float(os.getenv('THE_PILE_MASK_PROB', default='0.15'))

# ----------------------------------------------
# Setup
# ----------------------------------------------

# Load the datasets
data_dir = os.getenv('THE_PILE_DATA_DIR',
'/p/vast1/data/datasets/the-pile-pretokenized')
dataset_train = np.memmap(os.path.join(data_dir, 'train.bin'),
dtype=np.uint16,
mode='r')
sample_lengths_train = np.fromfile(os.path.join(data_dir, 'train-seqlen.bin'),
dtype=np.uint32).astype(np.uint64)
sample_offsets_train = np.zeros_like(sample_lengths_train)
sample_offsets_train[1:] = np.cumsum(sample_lengths_train)[:-1]
dataset_val = np.memmap(os.path.join(data_dir, 'val.bin'),
dtype=np.uint16,
mode='r')
sample_lengths_val = np.fromfile(os.path.join(data_dir, 'val-seqlen.bin'),
dtype=np.uint32).astype(np.uint64)
sample_offsets_val = np.zeros_like(sample_lengths_val)
sample_offsets_val[1:] = np.cumsum(sample_lengths_val)[:-1]

# Uses the definition from the GPT-NeoX-20B tokenizer
pad_index = 1 # '<|padding|>'
mask_index = 0  # token ID 0 is reused as the mask token
_vocab_size = 50277

# ----------------------------------------------
# Sample access functions
# ----------------------------------------------


def make_mask(random: bool = True) -> np.ndarray:
# 0 = masked, 1 = not masked
if random:
return np.random.binomial(1, 1 - mlm_probability, size=sequence_length)

# All masked:
#return np.full((sequence_length, ), 0)
# Nothing masked:
return np.full((sequence_length, ), 1)

def trim_and_pad(sample, random: bool):
# Trim long sequences
if len(sample) > sequence_length:
if random:
pos = np.random.rand()
offset = (len(sample) - sequence_length + 1) * pos
offset = int(np.floor(offset))
sample = sample[offset:offset + sequence_length]
else:
sample = sample[0:sequence_length]

# Left-pad short sequences
if len(sample) < sequence_length:
sample_pad = np.full(sequence_length, pad_index, dtype=np.int32)
if len(sample) > 0:
sample_pad[-len(sample):] = sample
return sample_pad

return sample


def concat(*args):
return np.concatenate(tuple(a.flat for a in args))


def get_train_sample(index: int):
sample = np.copy(
dataset_train[sample_offsets_train[index]:sample_offsets_train[index] +
sample_lengths_train[index]]).astype(np.int32)
return concat(trim_and_pad(sample, True), make_mask())


def get_val_sample(index):
sample = np.copy(
dataset_val[sample_offsets_val[index]:sample_offsets_val[index] +
sample_lengths_val[index]]).astype(np.int32)
return concat(trim_and_pad(sample, False), make_mask())


def num_train_samples():
return sample_lengths_train.shape[0]


def num_val_samples():
return sample_lengths_val.shape[0]


def sample_dims():
    return (sequence_length + sequence_length, )  # tokens followed by the mask


def vocab_size():
return _vocab_size


if __name__ == '__main__':
print('Training samples:', num_train_samples())
print('Validation samples:', num_val_samples())
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file(
os.path.join(data_dir, '20B_tokenizer.json'))
    # Decode only the token portion of each sample; the trailing
    # sequence_length entries are the 0/1 mask, not token IDs
    print('Training sample 101:')
    print(tokenizer.decode(get_train_sample(101)[:sequence_length].tolist()))
    print('Validation sample 233:')
    print(tokenizer.decode(get_val_sample(233)[:sequence_length].tolist()))
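
For reference, a minimal sketch (not part of this commit) of how the flat samples above can be consumed: the first sequence_length entries are token IDs, the remaining sequence_length entries are the 0/1 keep-mask, and applying the mask mirrors the lbann.Select step in modeling.py. The module name thepile_mlm, the import path, and the availability of the pretokenized data files are assumptions.

import numpy as np
import thepile_mlm as mlm  # assumes the datasets/ directory is on sys.path

sample = mlm.get_train_sample(0)
tokens = np.asarray(sample[:mlm.sequence_length], dtype=np.int64)
keep = np.asarray(sample[mlm.sequence_length:], dtype=bool)  # 0 = masked, 1 = kept

# NumPy equivalent of the lbann.Select in modeling.py: masked positions are
# replaced with the mask token before being embedded by the model
masked_tokens = np.where(keep, tokens, mlm.mask_index)
print('Fraction of positions masked:', 1.0 - keep.mean())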
157 changes: 157 additions & 0 deletions applications/nlp/transformer/modeling.py
@@ -234,6 +234,144 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
return result


def create_masked_language_modeling_transformer(
dataset, embed_dim: int, num_encoders: int, num_decoders: int, num_heads: int,
dropout: float, input_dropout: float, attn_dropout: float,
num_epochs: int, args: argparse.Namespace):
"""
Creates a flexible transformer for masked language modeling tasks.
"""
sequence_length = dataset.sequence_length
vocab_size = dataset.vocab_size()

# Embedding weights
var = 2 / (embed_dim + vocab_size) # Glorot initialization
embedding_weights = lbann.Weights(
name='embeddings',
initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
)

    # Input is a sequence of token IDs followed by a mask sequence
all_inputs = lbann.Input(data_field='samples')
slice_points = [
0,
sequence_length, # Original sequence
2 * sequence_length, # Mask
]
if args.attn_mask:
# Attention matrix mask
slice_points.append(2 * sequence_length + sequence_length * sequence_length)

slc = lbann.Slice(
all_inputs,
slice_points=slice_points,
)
input_tokens = lbann.Identity(slc)
mask = lbann.Identity(slc)
if args.attn_mask:
attn = lbann.Reshape(lbann.Identity(slc),
dims=[sequence_length, sequence_length])
else:
attn = None

    # Replace positions where the mask is 0 with the mask token; positions
    # where the mask is 1 keep their original token ID
    masked_input = lbann.Select(mask,
                                input_tokens,
                                value=1,
                                if_false=dataset.mask_index)

# Get sequences of embedding vectors
embeddings = lbann.Embedding(
masked_input,
weights=embedding_weights,
num_embeddings=vocab_size,
embedding_dim=embed_dim,
padding_idx=dataset.pad_index,
)
decoder_input = lbann.WeightedSum(
embeddings,
scaling_factors=math.sqrt(embed_dim),
)

petype = InputEncoding[args.positional_encoding.upper()]

# Apply input encoding
_, decoder_input, posenc = _add_input_encoding(None, decoder_input, petype,
embed_dim, input_dropout, 0,
sequence_length, num_heads)

    # Add the transformer model (encoder and decoder layer counts are configurable)
transformer = Transformer(hidden_size=embed_dim,
num_heads=num_heads,
dropout=dropout,
attn_dropout=attn_dropout,
num_encoder_layers=num_encoders,
num_decoder_layers=num_decoders,
pre_layernorm=True,
activation=lbann.Gelu,
positional_encoding=posenc,
name='transformer')

    # Tessellate the attention mask across all heads (note that this is memory-intensive)
if attn is not None and not transformer.separate_heads:
# TODO(later): Use broadcasting semantics to save memory
attn = lbann.Reshape(attn, dims=[1, sequence_length, sequence_length])
attn = lbann.Tessellate(
attn, dims=[num_heads, sequence_length, sequence_length])

# Apply parallelism techniques
transformer, extra_model_kwargs = parallelism.apply_subgraph_parallelism(
transformer, args)
parallelism.apply_ffn_model_parallelism(transformer, args)
parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
parallelism.apply_layer_parallelism(transformer, args)

# Run through transformer with the same sequence
result = transformer(decoder_input,
decoder_input,
sequence_length,
target_mask=attn)

# Apply layer normalization on the outputs
    norm_final = LayerNorm(embed_dim, name='final_layernorm')
result = norm_final(result)

# Apply language modeling head on results
lm_head = lbann.ChannelwiseFullyConnected(result,
weights=embedding_weights,
output_channel_dims=[vocab_size],
bias=False,
transpose=True,
name="prediction_layer")
preds = lbann.ChannelwiseSoftmax(lm_head)
preds = lbann.TensorPermute(preds, axes=[1, 0])

# Compute loss
loss = _add_mlm_loss(preds, input_tokens, sequence_length, vocab_size,
dataset.pad_index)

parallelism.apply_lm_head_model_parallelism(lm_head, args)

# Construct model
metrics = []
callbacks = [
lbann.CallbackPrint(),
lbann.CallbackTimer(),
lbann.CallbackGPUMemoryUsage()
]
result = lbann.Model(
num_epochs,
layers=lbann.traverse_layer_graph(input_tokens),
objective_function=loss,
metrics=metrics,
callbacks=callbacks,
**extra_model_kwargs,
)

parallelism.apply_fsdp_allweights(result, args)
parallelism.apply_layer_parallelism_postamble(result, args)
return result


def _add_input_encoding(
encoder_input: lbann.Layer, decoder_input: lbann.Layer,
encoding_kind: InputEncoding, embed_dim: int, input_dropout: float,
@@ -323,6 +461,25 @@ def _add_autoregressive_loss(preds, input_tokens, sequence_length, vocab_size,
return lbann.Scale(ce, constant=1 / (sequence_length - 1))


def _add_mlm_loss(preds, input_tokens, sequence_length, vocab_size, pad_index):
# Compute cross-entropy loss between preds and the original tokens from a
# masked input

# Flatten labels
flat_labels = lbann.Reshape(input_tokens, dims=[1, sequence_length])

    # Exclude padded positions from the cross-entropy by remapping their labels
    # to an out-of-vocabulary index, which never contributes to the loss
flat_labels = lbann.Select(flat_labels,
lbann.Identity(flat_labels),
value=pad_index,
if_true=(vocab_size + 1))

# Compute mean cross-entropy over the sequence
ce = lbann.CrossEntropy(preds, flat_labels, use_labels=True)
return lbann.Scale(ce, constant=1 / sequence_length)


# Command-line arguments
def add_transformer_architecture_arguments(args: argparse.Namespace):
"""
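For clarity, a rough NumPy analogue (not part of this commit) of the loss computed by _add_mlm_loss, under the assumption that labels remapped to the out-of-vocabulary index contribute zero loss in LBANN's CrossEntropy. The function name and signature are illustrative only.

import numpy as np

def mlm_loss_reference(probs, labels, pad_index, sequence_length):
    # probs:  (sequence_length, vocab_size) softmax outputs of the LM head
    # labels: (sequence_length,) original (unmasked) token IDs
    valid = labels != pad_index
    positions = np.arange(sequence_length)[valid]
    # Negative log-likelihood of the original tokens at non-padded positions
    nll = -np.log(probs[positions, labels[valid]])
    # Match the 1 / sequence_length scaling in the LBANN graph
    return nll.sum() / sequence_length
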
3 changes: 2 additions & 1 deletion applications/nlp/transformer/pretrain_gpt.py
@@ -55,7 +55,8 @@ def main():
lbann.contrib.args.add_scheduler_arguments(parser, 'lbann_gpt')
lbann.contrib.args.add_profiling_arguments(parser)
lbann.contrib.args.add_training_arguments(parser,
default_minibatch_size=32)
default_minibatch_size=32,
default_epochs=1)
lbann.contrib.args.add_amp_arguments(parser)
parallelism.add_transformer_parallelism_arguments(parser)
trainer.add_training_arguments(parser)
(The remaining changed files are not shown here.)