Merge branch 'main' into milo/default-eval-interval
irenedea authored Aug 9, 2024
2 parents 56bbd55 + 44b09f0 commit df47dfa
Showing 27 changed files with 492 additions and 323 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docker.yaml
@@ -19,10 +19,10 @@ jobs:
         include:
           - name: "2.3.1_cu121"
             base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
-            dep_groups: "[gpu]"
+            dep_groups: "[all]"
           - name: "2.3.1_cu121_aws"
             base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws
-            dep_groups: "[gpu]"
+            dep_groups: "[all]"
     steps:

       - name: Checkout
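For context on the dep_groups change: the value feeds the image build and is typically consumed as a pip extras selector (something like pip install ".[all]"), so "[all]" pulls in every optional dependency group instead of only the GPU extras. A generic, hypothetical sketch of how such an "all" extras group is often composed in a setup.py; the package name and pins below are invented and this is not llm-foundry's actual setup.py:

    # Hypothetical setup.py showing the common pattern behind dep_groups="[all]":
    # an "all" extras group built as the union of the other optional groups.
    from setuptools import setup

    extra_deps = {
        'gpu': ['flash-attn>=2.0'],        # invented example pins
        'databricks': ['databricks-sdk'],  # invented example pins
    }
    # "all" installs every optional group, which is what the images now request.
    extra_deps['all'] = sorted({dep for deps in extra_deps.values() for dep in deps})

    setup(
        name='example-package',
        version='0.0.1',
        install_requires=['torch'],
        extras_require=extra_deps,
    )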
2 changes: 2 additions & 0 deletions Dockerfile
@@ -7,6 +7,8 @@ FROM $BASE_IMAGE
 ARG BRANCH_NAME
 ARG DEP_GROUPS

+ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0"
+
 # Check for changes in setup.py.
 # If there are changes, the docker cache is invalidated and a fresh pip installation is triggered.
 ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py setup.py
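For context on the new ENV line: TORCH_CUDA_ARCH_LIST tells PyTorch's CUDA extension builder which compute capabilities to compile kernels for (8.0 covers A100, 8.6 RTX 30xx/A10, 8.7 Jetson Orin, 8.9 L4/RTX 40xx, 9.0 H100) rather than auto-detecting from whatever GPU, if any, is present on the build machine. Below is a small, hypothetical runtime check, not part of this commit, that the baked-in list covers the device the container actually runs on:

    # Sketch: verify that the arch list baked into the image covers the current GPU.
    import os

    import torch

    def arch_list_covers_current_gpu() -> bool:
        """Return True if the current GPU's compute capability is in TORCH_CUDA_ARCH_LIST."""
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability()
        current = f'{major}.{minor}'
        # "8.0 8.6 8.7 8.9 9.0" -> {"8.0", "8.6", ...}; entries like "8.6+PTX" are trimmed.
        arch_list = {
            entry.split('+')[0]
            for entry in os.environ.get('TORCH_CUDA_ARCH_LIST', '').split()
        }
        return current in arch_list

    if __name__ == '__main__':
        print('covered:', arch_list_covers_current_gpu())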
7 changes: 7 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_text_to_mds.py
@@ -1,6 +1,7 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

+import json
 import logging
 import math
 import os
@@ -28,6 +29,7 @@
     merge_shard_groups,
 )
 from llmfoundry.utils.exceptions import (
+    DatasetTooSmallError,
     InputFolderMissingDataError,
     OutputFolderNotEmptyError,
 )
@@ -468,6 +470,11 @@ def convert_text_to_mds(
         trust_remote_code,
     )

+    index_path = os.path.join(local_output_folder, 'index.json')
+    with open(index_path, 'r') as index_file:
+        if not json.load(index_file)['shards']:
+            raise DatasetTooSmallError()
+
     # Write a done file with the args and object names
     write_done_file(local_output_folder, args_str, object_names)

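The added guard reads the index.json that the MDS writer leaves in the output folder and raises DatasetTooSmallError when it lists no shards, i.e. when the input text produced zero samples. A minimal standalone sketch of the same check, assuming the usual Streaming index layout of {"version": 2, "shards": [...]}; the exception class and helper below are stand-ins for illustration, not the llm-foundry implementations:

    # Sketch of the guard this hunk adds, as a standalone helper.
    import json
    import os

    class DatasetTooSmallError(ValueError):
        """Stand-in for llmfoundry.utils.exceptions.DatasetTooSmallError."""

    def assert_nonempty_mds_output(output_folder: str) -> None:
        # MDSWriter writes an index.json of the form {"version": 2, "shards": [...]}.
        index_path = os.path.join(output_folder, 'index.json')
        with open(index_path, 'r') as index_file:
            index = json.load(index_file)
        if not index.get('shards'):
            raise DatasetTooSmallError(
                f'No shards were written to {output_folder}; '
                'the input dataset is too small to produce a single sample.',
            )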
68 changes: 47 additions & 21 deletions llmfoundry/data/finetuning/collator.py
@@ -224,6 +224,10 @@ class Seq2SeqFinetuningCollator:
             sizes. Default: ``False`` ensures that all sequences are max_seq_len.
         batch_metadata (dict, optional): A dictionary of metadata which will be added
             to the batch.
+        pad_to_longest (bool, optional): Whether to pad to the longest sequence,
+            which may result in smaller but inconsistent batch sizes. This is
+            primarily used to profile packing.
+            Default: ``False`` ensures that all sequences are max_seq_len.
     """

     def __init__(
@@ -235,6 +239,7 @@ def __init__(
         target_prompts: str = 'none',
         allow_pad_trimming: bool = False,
         batch_metadata: Optional[Dict[str, Any]] = None,
+        pad_to_longest: bool = False,
     ):
         self.tokenizer = tokenizer
         self.max_seq_len = max_seq_len
@@ -247,6 +252,8 @@ def __init__(
         self._allow_pad_trimming = allow_pad_trimming
         self._seen_first_batch = False

+        self._pad_to_longest = pad_to_longest
+
         illegal_keys = [
             'input_ids',
             'labels',
@@ -320,24 +327,34 @@ def _process_and_batch_decoder_only(
     ) -> Dict[str, torch.Tensor]:
         # Steps explained in comments
         processed_examples = []
-        for example in examples:
-            input_ids, labels = stitch_turns_decoder_only(
+        input_ids_and_labels = [
+            stitch_turns_decoder_only(
                 example_turns=example['turns'],
                 target_prompts=self.target_prompts,
                 target_responses=self.target_responses,
                 eos_token_id=self.tokenizer.eos_token_id,
-            )
+            ) for example in examples
+        ]
+
+        if self._pad_to_longest:
+            max_seq_len = max([
+                len(input_ids) for input_ids, _ in input_ids_and_labels
+            ])
+            max_seq_len = min(max_seq_len, self.max_seq_len)
+        else:
+            max_seq_len = self.max_seq_len
+
+        for input_ids, labels in input_ids_and_labels:
             orig_size = len(input_ids)
             # We may need to truncate the input_ids / labels in order to maintain max_seq_len
-            if orig_size > self.max_seq_len:
-                input_ids = input_ids[:self.max_seq_len]
-                labels = labels[:self.max_seq_len]
+            if orig_size > max_seq_len:
+                input_ids = input_ids[:max_seq_len]
+                labels = labels[:max_seq_len]

             # Check to make sure there are still loss-generating tokens. Error if not.
             if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
                 raise ValueError(
-                    f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' +\
+                    f'Truncating to max_seq_len={max_seq_len} has removed all loss-generating tokens. ' +\
                     f'Pre-truncation sequence length was {orig_size}. ' +\
                     'This sample should have been filtered out before reaching the collator. If using ' +\
                     'pre-tokenized streaming data, this may have resulted from using different ' +\
@@ -348,7 +365,7 @@ def _process_and_batch_decoder_only(
             # Still issue a warning when truncating
             if not self._warned_truncated:
                 warnings.warn(
-                    f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' +\
+                    f'Truncating sequence of length={orig_size} to fit max_seq_len={max_seq_len}. ' +\
                     f'If truncation is a problem, consider increasing max_seq_len.',
                 )
                 self._warned_truncated = True
@@ -358,7 +375,7 @@ def _process_and_batch_decoder_only(
             # Annoyingly, we need to pad everything but input_ids
             # and attention_mask ourselves
             n_total = len(input_ids)
-            i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
+            i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - n_total)
             if self.tokenizer.padding_side == 'left':
                 labels = i_pad + labels
             else:
@@ -376,7 +393,7 @@ def _process_and_batch_decoder_only(
         batch = self.tokenizer.pad(
             processed_examples,
             padding='max_length',
-            max_length=self.max_seq_len,
+            max_length=max_seq_len,
             return_tensors='pt',
         )

@@ -410,35 +427,44 @@ def _process_and_batch_encoder_decoder(
         # The encoder-decoder case is has some gotchas.
         # Steps are explained in comments.
         processed_examples = []
-        for example in examples:
-            context, target = stitch_turns_encoder_decoder(
+        contexts_and_targets = [
+            stitch_turns_encoder_decoder(
                 example_turns=example['turns'],
                 eos_token_id=self.tokenizer.eos_token_id,
-            )
+            ) for example in examples
+        ]
+
+        if self._pad_to_longest:
+            max_seq_len = 0
+            for context, target in contexts_and_targets:
+                max_seq_len = max(max_seq_len, len(context), len(target))
+        else:
+            max_seq_len = self.max_seq_len
+
+        for context, target in contexts_and_targets:
             # We need to pad labels ourselves. Because HF.
-            if len(target) < self.max_seq_len:
-                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
+            if len(target) < max_seq_len:
+                i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - len(target))
                 target = target + i_pad
             else:
                 if not self._warned_target:
                     warnings.warn(
                         f'Truncating TARGET sequence of length={len(target)} ' +\
-                        f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
+                        f'to max_seq_len={max_seq_len}. If truncation is ' +\
                         f'a problem, consider increasing max_seq_len.')
                     self._warned_target = True
-                target = target[:self.max_seq_len -
+                target = target[:max_seq_len -
                                 1] + [self.tokenizer.eos_token_id]

             # We might need to truncate the context. Preserve the beginning.
-            if len(context) > self.max_seq_len:
+            if len(context) > max_seq_len:
                 if not self._warned_context:
                     warnings.warn(
                         f'Truncating CONTEXT sequence of length={len(context)} ' +\
-                        f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
+                        f'to max_seq_len={max_seq_len}. If truncation is ' +\
                         f'a problem, consider increasing max_seq_len.')
                     self._warned_context = True
-                context = context[:self.max_seq_len -
+                context = context[:max_seq_len -
                                   1] + [self.tokenizer.eos_token_id]

             # Back into the example
@@ -454,7 +480,7 @@ def _process_and_batch_encoder_decoder(
         batch = self.tokenizer.pad(
             processed_examples,
             padding='max_length',
-            max_length=self.max_seq_len,
+            max_length=max_seq_len,
             return_tensors='pt',
         )
         # We're still missing decoder_input_ids and decoder_attention_mask
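Taken together, the collator changes make padding dynamic when pad_to_longest=True: each batch is padded only to its longest stitched sequence (capped at max_seq_len in the decoder-only path) instead of always to max_seq_len, which is mainly useful for profiling packing. A toy, self-contained sketch of that effect; pad_batch and PAD_TOKEN_ID are invented for illustration and are not the llm-foundry collator:

    # Toy illustration: fixed-length padding vs. pad-to-longest padding.
    from typing import List

    PAD_TOKEN_ID = 0  # hypothetical pad id

    def pad_batch(
        sequences: List[List[int]],
        max_seq_len: int,
        pad_to_longest: bool = False,
    ) -> List[List[int]]:
        if pad_to_longest:
            # Pad only to the longest sequence in this batch, never past the cap.
            target_len = min(max(len(seq) for seq in sequences), max_seq_len)
        else:
            # Always pad to the cap, so every batch has the same shape.
            target_len = max_seq_len
        padded = []
        for seq in sequences:
            seq = seq[:target_len]  # truncate anything over the target length
            padded.append(seq + [PAD_TOKEN_ID] * (target_len - len(seq)))
        return padded

    batch = [[5, 6, 7], [8, 9]]
    print(pad_batch(batch, max_seq_len=8))                       # rows of length 8
    print(pad_batch(batch, max_seq_len=8, pad_to_longest=True))  # rows of length 3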
