Commit b2ddc1d

Merge branch 'main' into jz/systemMetricsMonitor
dakinggg authored Jun 6, 2024
2 parents 38d1db6 + 42c2d9a commit b2ddc1d
Showing 25 changed files with 1,019 additions and 290 deletions.
4 changes: 4 additions & 0 deletions .github/CODEOWNERS
@@ -2,6 +2,10 @@
 # This includes setup.py, the README, and the CODEOWNERS file itself!
 /* @mosaicml/composer-team-admins
 
+# Require team approval for code changes
+/llmfoundry/ @mosaicml/composer-team-eng
+/scripts/ @mosaicml/composer-team-eng
+
 # Require admin approval to change the CI build configuration
 # All CI Changes should be reviewed for security
 /.ci/ @mosaicml/composer-team-admins
2 changes: 1 addition & 1 deletion Dockerfile
@@ -13,7 +13,7 @@ ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py
 RUN rm setup.py
 
 # Install TransformerEngine
-RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@05eb6deb31c1b48e9f4380d18fe95f3c38e84335
+RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=3 MAX_JOBS=3 pip install git+https://github.com/cli99/TransformerEngine.git@6b21f606f2459d49c2113d69236d68d334edeb4c
 
 # Install and uninstall foundry to cache foundry requirements
 RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git
9 changes: 8 additions & 1 deletion llmfoundry/data/__init__.py
@@ -1,7 +1,12 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
+from llmfoundry.data.data import (
+    SUPPORTED_MDS_ENCODING_TYPES,
+    ConcatTokensDataset,
+    NoConcatDataset,
+    stream_remote_local_validate,
+)
 from llmfoundry.data.dataloader import build_dataloader
 from llmfoundry.data.finetuning import (
     Seq2SeqFinetuningCollator,
@@ -55,4 +60,6 @@
     'auto_packing_ratio',
     'profile_packing',
     'ConcatenatedSequenceCollatorWrapper',
+    'stream_remote_local_validate',
+    'SUPPORTED_MDS_ENCODING_TYPES',
 ]
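
Not part of the diff: the newly exported stream_remote_local_validate helper (defined in llmfoundry/data/data.py below) checks that a local dataset directory actually contains the requested split. A minimal sketch of how it might be called, with placeholder paths:

```python
from llmfoundry.data import stream_remote_local_validate

# Raises ValueError if /tmp/mds-data exists locally but has no 'train'
# subdirectory; passes silently when a distinct remote will supply the data.
stream_remote_local_validate(remote=None, local='/tmp/mds-data', split='train')
```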
109 changes: 83 additions & 26 deletions llmfoundry/data/data.py
@@ -4,16 +4,32 @@
"""Datasets for converting to MDS Shards."""
import os
import warnings
from typing import Dict, Iterable, Union
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

__all__ = [
'AbstractConcatTokensDataset',
'ConcatTokensDataset',
'NoConcatDataset',
'stream_remote_local_validate',
'SUPPORTED_MDS_ENCODING_TYPES',
]

SUPPORTED_MDS_ENCODING_TYPES = [
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
'uint32',
'uint64',
]


@@ -35,39 +51,20 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
         yield {'text': sample['text'].encode('utf-8')}
 
 
-class ConcatTokensDataset(IterableDataset):
-    """An IterableDataset that returns token samples for MDSWriter.
-    Returns dicts of {'tokens': bytes}
-    To use data created by this class and written to MDS format:
-    ```python
-    import torch
-    from streaming.base import StreamingDataset
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
-    ds = StreamingDataset(local='mds-data-folder', split='val')
-    # note, you need to copy the numpy array because the original is non-writeable
-    # and torch does not support non-writeable tensors, so you get a scary warning and
-    # if you do try to write to the tensor you get undefined behavior
-    tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
-    print(tokenizer.decode(tokens))
-    ```
+class AbstractConcatTokensDataset(ABC, IterableDataset):
+    """Abstract class for defining an IterableDataset that tokenizes and
+    concatenates text samples on the fly.
     """
 
     def __init__(
         self,
-        hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
         tokenizer: PreTrainedTokenizerBase,
         max_length: int,
         bos_text: str,
         eos_text: str,
         no_wrap: bool,
     ):
-        self.hf_dataset = hf_dataset
         self.tokenizer = tokenizer
         os.environ['TOKENIZERS_PARALLELISM'] = 'false'
         self.max_length = max_length
@@ -114,8 +111,47 @@ def __init__(
             'in duplicated special tokens. Please be sure this is what you intend.',
         )
 
-    def __iter__(self) -> Iterable[Dict[str, bytes]]:
+    @abstractmethod
+    def __iter__(self) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
+        pass
+
+
+class ConcatTokensDataset(AbstractConcatTokensDataset):
+    """An IterableDataset that returns token samples for MDSWriter.
+    Returns dicts of {'tokens': ndarray:int32}
+    To use data created by this class and written to MDS format:
+    ```python
+    import torch
+    from streaming.base import StreamingDataset
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
+    ds = StreamingDataset(local='mds-data-folder', split='val')
+    # note, you need to copy the numpy array because the original is non-writeable
+    # and torch does not support non-writeable tensors, so you get a scary warning and
+    # if you do try to write to the tensor you get undefined behavior
+    tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int32).copy())
+    print(tokenizer.decode(tokens))
+    ```
+    """
+
+    def __init__(
+        self,
+        hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
+        tokenizer: PreTrainedTokenizerBase,
+        max_length: int,
+        bos_text: str,
+        eos_text: str,
+        no_wrap: bool,
+    ):
+        self.hf_dataset = hf_dataset
+        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)
+
+    def __iter__(self) -> Iterable[Dict[str, NDArray]]:
         buffer = []
         for sample in self.hf_dataset:
             encoded = self.tokenizer(
@@ -129,6 +165,27 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
             concat_sample = buffer[:self.max_length]
             buffer = buffer[self.max_length:] if self.should_wrap else []
             yield {
-                # convert to bytes to store in MDS binary format
-                'tokens': np.asarray(concat_sample).tobytes(),
+                # convert to ndarray to store in MDS format
+                'tokens': np.asarray(concat_sample, dtype=np.int32),
             }
+
+
+def stream_remote_local_validate(
+    remote: Optional[str],
+    local: Optional[str],
+    split: Optional[str],
+):
+    """Check that, if needed, the local/split directory exists.
+
+    Args:
+        remote (Optional[str]): Remote path to the dataset.
+        local (Optional[str]): Local path to the dataset.
+        split (Optional[str]): Subdirectory specifying which dataset split to use, if any.
+    """
+    if remote is None or (local == remote):
+        if local is not None and os.path.isdir(local):
+            contents = set(os.listdir(local))
+            if split is not None and split not in contents:
+                raise ValueError(
+                    f'Local directory {local} does not contain split {split}',
+                )
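
Not part of the diff: the net effect of this file's changes is that ConcatTokensDataset now yields numpy int32 arrays instead of raw bytes, so MDSWriter can store them with a typed ndarray encoding. A minimal writing sketch under that assumption, with placeholder tokenizer, dataset, and output path:

```python
import datasets as hf_datasets
from streaming import MDSWriter
from transformers import AutoTokenizer

from llmfoundry.data import ConcatTokensDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
text_ds = hf_datasets.load_dataset('allenai/c4', 'en', split='train', streaming=True)

ds = ConcatTokensDataset(
    hf_dataset=text_ds,
    tokenizer=tokenizer,
    max_length=2048,
    bos_text='',
    eos_text='<|endoftext|>',
    no_wrap=False,
)

# 'ndarray:int32' matches the dtype the dataset now yields per the diff above.
with MDSWriter(columns={'tokens': 'ndarray:int32'}, out='mds-data-folder') as out:
    for sample in ds:
        out.write(sample)
```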
62 changes: 35 additions & 27 deletions llmfoundry/data/finetuning/tasks.py
@@ -59,6 +59,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from streaming import Stream, StreamingDataset
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.data import (
+    SUPPORTED_MDS_ENCODING_TYPES,
+    stream_remote_local_validate,
+)
 from llmfoundry.data.finetuning.collator import (
     _HF_IGNORE_INDEX,
     stitch_turns_decoder_only,
@@ -69,6 +73,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     ALLOWED_MESSAGES_KEYS,
     ALLOWED_PROMPT_KEYS,
     ALLOWED_RESPONSE_KEYS,
+    ChatTemplateError,
     ConsecutiveRepeatedChatRolesError,
     IncorrectMessageKeyQuantityError,
     InvalidContentTypeError,
@@ -245,15 +250,22 @@ def slice_out_last_turn(
     messages_through_current_turn: List[Dict[str, str]],
     conversation_through_previous_turn: str,
 ) -> Tuple[str, str]:
-    full_conversation = tokenizer.apply_chat_template(
-        messages_through_current_turn,
-        tokenize=False,
-    )
-    prompt_with_history = tokenizer.apply_chat_template(
-        messages_through_current_turn[:-1],
-        tokenize=False,
-        add_generation_prompt=True,
-    )
+    try:
+        full_conversation = tokenizer.apply_chat_template(
+            messages_through_current_turn,
+            tokenize=False,
+        )
+        prompt_with_history = tokenizer.apply_chat_template(
+            messages_through_current_turn[:-1],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+    except Exception as e:
+        raise ChatTemplateError(
+            tokenizer.chat_template,
+            sample=messages_through_current_turn,
+            inner_message=str(e),
+        )
     if conversation_through_previous_turn != full_conversation[:len(
         conversation_through_previous_turn,
     )]:
@@ -486,26 +498,15 @@ def is_valid_ift_example(
     return True
 
 
-def _stream_remote_local_validate(
-    remote: Optional[str],
-    local: Optional[str],
-    split: Optional[str],
-):
-    if remote is None or (local == remote):
-        if local is not None and os.path.isdir(local):
-            contents = set(os.listdir(local))
-            if split is not None and split not in contents:
-                raise ValueError(
-                    f'Local directory {local} does not contain split {split}',
-                )
-
-
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
     Args:
         tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
             tokenize samples.
+        token_encoding_type (str): The encoding type of the tokenized samples. This is only used
+            for legacy datasets that have been written directly as 'bytes' instead of numpy
+            arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'.
         streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
             which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
             ``remote``/``local``. Defaults to ``None``.
@@ -566,6 +567,7 @@ class StreamingFinetuningDataset(StreamingDataset):
     def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
+        token_encoding_type: str = 'int64',
         streams: Optional[Sequence[Stream]] = None,
         local: Optional[str] = None,
         remote: Optional[str] = None,
@@ -598,11 +600,17 @@ def __init__(
                 f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}',
             )
 
+        if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
+            raise ValueError(
+                f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
+            )
+        self.token_encoding_type = token_encoding_type
+
         if streams is None:
-            _stream_remote_local_validate(remote, local, split)
+            stream_remote_local_validate(remote, local, split)
         else:
             for stream in streams:
-                _stream_remote_local_validate(
+                stream_remote_local_validate(
                     stream.remote,
                     stream.local,
                     split,
@@ -648,11 +656,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
         if isinstance(sample['input_ids'], bytes):
             sample['input_ids'] = np.frombuffer(
                 sample['input_ids'],
-                dtype=np.int64,
+                dtype=getattr(np, self.token_encoding_type),
             )[:self.max_seq_len].tolist().copy()
             sample['labels'] = np.frombuffer(
                 sample['labels'],
-                dtype=np.int64,
+                dtype=getattr(np, self.token_encoding_type),
             )[:self.max_seq_len].tolist().copy()
         elif isinstance(sample['input_ids'], np.ndarray):
             sample['input_ids'] = sample['input_ids'][:self.max_seq_len
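
Not part of the diff: the new token_encoding_type argument only matters for legacy shards whose 'input_ids'/'labels' were written as raw bytes; numpy-array shards carry their own dtype. A minimal loading sketch, assuming a hypothetical int32-encoded legacy dataset and placeholder paths:

```python
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer

dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    local='/tmp/mds-finetuning',  # placeholder local path
    split='train',
    max_seq_len=2048,
    # Must be one of SUPPORTED_MDS_ENCODING_TYPES; anything else raises ValueError.
    token_encoding_type='int32',
)
```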
(Diffs for the remaining 20 changed files were not loaded.)