Skip to content

Commit

Permalink
Make autopacking faster
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Aug 8, 2024
1 parent f006d07 commit 58790d8
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 177 deletions.
67 changes: 46 additions & 21 deletions llmfoundry/data/finetuning/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ class Seq2SeqFinetuningCollator:
sizes. Default: ``False`` ensures that all sequences are max_seq_len.
batch_metadata (dict, optional): A dictionary of metadata which will be added
to the batch.
pad_to_longest (bool, optional): Whether to pad to the longest sequence,
which may result in smaller but inconsistent batch sizes. This is
primarily used to profile packing.
Default: ``False`` ensures that all sequences are max_seq_len.
"""

def __init__(
Expand All @@ -235,6 +239,7 @@ def __init__(
target_prompts: str = 'none',
allow_pad_trimming: bool = False,
batch_metadata: Optional[Dict[str, Any]] = None,
pad_to_longest: bool = False,
):
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
Expand All @@ -247,6 +252,8 @@ def __init__(
self._allow_pad_trimming = allow_pad_trimming
self._seen_first_batch = False

self._pad_to_longest = pad_to_longest

illegal_keys = [
'input_ids',
'labels',
Expand Down Expand Up @@ -320,24 +327,33 @@ def _process_and_batch_decoder_only(
) -> Dict[str, torch.Tensor]:
# Steps explained in comments
processed_examples = []
for example in examples:
input_ids, labels = stitch_turns_decoder_only(
input_ids_and_labels = [
stitch_turns_decoder_only(
example_turns=example['turns'],
target_prompts=self.target_prompts,
target_responses=self.target_responses,
eos_token_id=self.tokenizer.eos_token_id,
)
) for example in examples
]

if self._pad_to_longest:
max_seq_len = max([
len(input_ids) for input_ids, _ in input_ids_and_labels
])
else:
max_seq_len = self.max_seq_len

for input_ids, labels in input_ids_and_labels:
orig_size = len(input_ids)
# We may need to truncate the input_ids / labels in order to maintain max_seq_len
if orig_size > self.max_seq_len:
input_ids = input_ids[:self.max_seq_len]
labels = labels[:self.max_seq_len]
if orig_size > max_seq_len:
input_ids = input_ids[:max_seq_len]
labels = labels[:max_seq_len]

# Check to make sure there are still loss-generating tokens. Error if not.
if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
raise ValueError(
f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' +\
f'Truncating to max_seq_len={max_seq_len} has removed all loss-generating tokens. ' +\
f'Pre-truncation sequence length was {orig_size}. ' +\
'This sample should have been filtered out before reaching the collator. If using ' +\
'pre-tokenized streaming data, this may have resulted from using different ' +\
Expand All @@ -348,7 +364,7 @@ def _process_and_batch_decoder_only(
# Still issue a warning when truncating
if not self._warned_truncated:
warnings.warn(
f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' +\
f'Truncating sequence of length={orig_size} to fit max_seq_len={max_seq_len}. ' +\
f'If truncation is a problem, consider increasing max_seq_len.',
)
self._warned_truncated = True
Expand All @@ -358,7 +374,7 @@ def _process_and_batch_decoder_only(
# Annoyingly, we need to pad everything but input_ids
# and attention_mask ourselves
n_total = len(input_ids)
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - n_total)
if self.tokenizer.padding_side == 'left':
labels = i_pad + labels
else:
Expand All @@ -376,7 +392,7 @@ def _process_and_batch_decoder_only(
batch = self.tokenizer.pad(
processed_examples,
padding='max_length',
max_length=self.max_seq_len,
max_length=max_seq_len,
return_tensors='pt',
)

Expand Down Expand Up @@ -410,35 +426,44 @@ def _process_and_batch_encoder_decoder(
# The encoder-decoder case is has some gotchas.
# Steps are explained in comments.
processed_examples = []
for example in examples:
context, target = stitch_turns_encoder_decoder(
contexts_and_targets = [
stitch_turns_encoder_decoder(
example_turns=example['turns'],
eos_token_id=self.tokenizer.eos_token_id,
)
) for example in examples
]

if self._pad_to_longest:
max_seq_len = 0
for context, target in contexts_and_targets:
max_seq_len = max(max_seq_len, len(context), len(target))
else:
max_seq_len = self.max_seq_len

for context, target in contexts_and_targets:
# We need to pad labels ourselves. Because HF.
if len(target) < self.max_seq_len:
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
if len(target) < max_seq_len:
i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - len(target))
target = target + i_pad
else:
if not self._warned_target:
warnings.warn(
f'Truncating TARGET sequence of length={len(target)} ' +\
f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
f'to max_seq_len={max_seq_len}. If truncation is ' +\
f'a problem, consider increasing max_seq_len.')
self._warned_target = True
target = target[:self.max_seq_len -
target = target[:max_seq_len -
1] + [self.tokenizer.eos_token_id]

# We might need to truncate the context. Preserve the beginning.
if len(context) > self.max_seq_len:
if len(context) > max_seq_len:
if not self._warned_context:
warnings.warn(
f'Truncating CONTEXT sequence of length={len(context)} ' +\
f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
f'to max_seq_len={max_seq_len}. If truncation is ' +\
f'a problem, consider increasing max_seq_len.')
self._warned_context = True
context = context[:self.max_seq_len -
context = context[:max_seq_len -
1] + [self.tokenizer.eos_token_id]

# Back into the example
Expand All @@ -454,7 +479,7 @@ def _process_and_batch_encoder_decoder(
batch = self.tokenizer.pad(
processed_examples,
padding='max_length',
max_length=self.max_seq_len,
max_length=max_seq_len,
return_tensors='pt',
)
# We're still missing decoder_input_ids and decoder_attention_mask
Expand Down
3 changes: 3 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
'seq_parallel_replication',
'auto_packing_replication',
'max_leftover_bins_to_keep',
'pad_to_longest',
}


Expand Down Expand Up @@ -630,6 +631,7 @@ def build_collate_fn(
max_seq_len = dataset_cfg['max_seq_len']
decoder_only_format = dataset_cfg['decoder_only_format']
allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False)
pad_to_longest = dataset_cfg.get('pad_to_longest', False)

collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
Expand All @@ -638,6 +640,7 @@ def build_collate_fn(
target_responses=target_responses,
target_prompts=target_prompts,
allow_pad_trimming=allow_pad_trimming,
pad_to_longest=pad_to_longest,
)

packing_ratio = dataset_cfg.get('packing_ratio')
Expand Down
Loading

0 comments on commit 58790d8

Please sign in to comment.