From 83d36a5a7cbbca60aabb70620daf63c5c7a136e6 Mon Sep 17 00:00:00 2001
From: Aiden Grossman
Date: Fri, 15 Sep 2023 13:39:54 -0700
Subject: [PATCH 1/4] Add support for pre-tokenized streaming dataset finetuning

---
 llmfoundry/data/finetuning/tasks.py | 24 ++++++++++-
 tests/test_dataloader.py            | 62 +++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 0a2b386048..9b0939e73a 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import datasets as hf_datasets
+import numpy as np
+import torch
 from omegaconf import DictConfig
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
@@ -67,6 +69,19 @@ def _tokenize_formatted_example(
     return tokenizer(text=example['prompt'], text_target=example['response'])
 
 
+def _read_binary_tokenized_sample(sample: Dict[str, Any]):
+    example = {
+        'input_ids':
+            torch.from_numpy(
+                np.frombuffer(sample['tokens'], dtype=np.int64).copy()),
+        'labels':
+            torch.from_numpy(
+                np.frombuffer(sample['labels'], dtype=np.int64).copy()),
+    }
+    example['attention_mask'] = torch.ones(example['input_ids'].size())
+    return example
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
@@ -185,7 +200,14 @@ def __init__(self,
     # How to process a sample
     def __getitem__(self, idx: int) -> Dict[str, Any]:
         sample = super().__getitem__(idx)
-        return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
+        if 'prompt' in sample and 'response' in sample:
+            return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
+        elif 'tokens' in sample and 'labels' in sample:
+            return _read_binary_tokenized_sample(sample)
+        else:
+            raise RuntimeError(
+                'FineTuningDataset needs samples to have prompt/response columns '
+                'or tokens/labels columns')
 
 
 class DatasetConstructor:
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 72bfac1d08..0e6066440f 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -11,6 +11,7 @@
 
 import pytest
 import torch
+import numpy
 from composer.utils import dist, using_torch_2
 from omegaconf import OmegaConf as om
 from streaming import MDSWriter
@@ -62,6 +63,25 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
             output_writer.write(sample)
 
 
+def build_mock_tokenized_ft_streaming_dataset(data_path, split):
+    columns = {'tokens': 'bytes', 'labels': 'bytes'}
+
+    dataset = [{
+        'tokens': numpy.asarray([1, 2, 3, 4]).tobytes(),
+        'labels': numpy.asarray([2, 3, 4, 5]).tobytes()
+    }, {
+        'tokens': numpy.asarray([2, 3, 4, 5]).tobytes(),
+        'labels': numpy.asarray([3, 4, 5, 6]).tobytes()
+    }]
+
+    output_path = os.path.join(data_path, split)
+
+    with MDSWriter(columns=columns, out=output_path,
+                   compression=None) as output_writer:
+        for sample in dataset:
+            output_writer.write(sample)
+
+
 @pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
 @pytest.mark.parametrize('pretokenize', [False, True])
 def test_correct_padding(tokenizer_name: str,
@@ -472,6 +492,48 @@ def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
     _ = build_finetuning_dataloader(cfg, tokenizer, 4)
 
 
+def test_finetuning_dataloader_streaming_tokenized(tmp_path: pathlib.Path):
+    remote_path = os.path.join(tmp_path, 'remote')
+    local_path = os.path.join(tmp_path, 'local')
+
+    build_mock_tokenized_ft_streaming_dataset(remote_path, 'train')
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'remote': remote_path,
+            'local': local_path,
+            'split': 'train',
+            'max_seq_len': 2048,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name='gpt2',
+        tokenizer_kwargs={'model_max_length': 2048},
+    )
+
+    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 4)
+
+    expected_keys = ['input_ids', 'attention_mask', 'labels']
+
+    for batch in ft_dataloader:
+        for k in expected_keys:
+            assert k in batch
+
+
 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
 @pytest.mark.parametrize('add_bad_data_error', [True, False])
 def test_malformed_data(
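
The heart of this first patch is the bytes round-trip between the mock MDS writer and _read_binary_tokenized_sample. A minimal standalone sketch of that round-trip follows; the token ids are illustrative, not taken from the patch:

import numpy as np
import torch

token_ids = [1, 2, 3, 4]

# Writer side, as in build_mock_tokenized_ft_streaming_dataset: int64 ids
# are serialized to raw bytes before being handed to MDSWriter.
raw = np.asarray(token_ids, dtype=np.int64).tobytes()

# Reader side, as in _read_binary_tokenized_sample: np.frombuffer recovers
# the ids but returns a read-only view, so the .copy() is needed before
# torch.from_numpy can wrap the data in a writable tensor.
recovered = torch.from_numpy(np.frombuffer(raw, dtype=np.int64).copy())
assert recovered.tolist() == token_ids

# The attention mask is all ones because a pre-tokenized sample carries no
# padding yet; padding is added later by the collator.
attention_mask = torch.ones(recovered.size())

Note that the mock writer calls numpy.asarray without an explicit dtype, so it relies on numpy defaulting to int64 on 64-bit Linux; passing dtype=np.int64 on both sides, as in the sketch, avoids a silent mismatch on platforms where the default integer type differs.
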
From 1831916b7fba79fb86ccdfce9abe5ef1fe53c7fb Mon Sep 17 00:00:00 2001
From: Aiden Grossman
Date: Wed, 20 Sep 2023 14:46:06 -0700
Subject: [PATCH 2/4] Fix pyright warnings/errors

This patch fixes a pyright implicit string concatenation warning and adds
type annotations where they were missing.
---
 llmfoundry/data/finetuning/tasks.py | 2 +-
 tests/test_dataloader.py            | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 9b0939e73a..0df52c6f51 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -206,7 +206,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
             return _read_binary_tokenized_sample(sample)
         else:
             raise RuntimeError(
-                'FineTuningDataset needs samples to have prompt/response columns '
+                'FineTuningDataset needs samples to have prompt/response columns ' +\
                 'or tokens/labels columns')
 
 
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 0e6066440f..0fb0870f97 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -63,7 +63,7 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
             output_writer.write(sample)
 
 
-def build_mock_tokenized_ft_streaming_dataset(data_path, split):
+def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
     columns = {'tokens': 'bytes', 'labels': 'bytes'}
 
     dataset = [{
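
For context on the concatenation change: adjacent string literals inside an argument list concatenate implicitly, which pyright can flag because the pattern often hides a missing comma between two separate arguments. A small sketch showing the two forms are equivalent at runtime:

# Implicit concatenation: easy to mistake for two separate arguments.
implicit = ('FineTuningDataset needs samples to have prompt/response columns '
            'or tokens/labels columns')

# Explicit concatenation, as the patch writes it.
explicit = ('FineTuningDataset needs samples to have prompt/response columns ' +
            'or tokens/labels columns')

assert implicit == explicit
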
From 8f0e308a4a7f13216947fe692c10d0812f221bb1 Mon Sep 17 00:00:00 2001
From: Aiden Grossman
Date: Mon, 25 Sep 2023 15:23:49 -0700
Subject: [PATCH 3/4] Address reviewer feedback

---
 llmfoundry/data/finetuning/tasks.py | 53 +++++++++++++----------------
 tests/test_dataloader.py            | 12 +++----
 2 files changed, 30 insertions(+), 35 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 0df52c6f51..0ecfdce3ce 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -49,27 +49,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 __all__ = ['dataset_constructor']
 
 
-def _tokenize_formatted_example(
-    example: Dict[str, Any],
-    tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
-    if ('prompt' not in example) or ('response' not in example):
-        raise KeyError(
-            'Unable to tokenize example because it has not been properly formatted. ' +\
-            '"prompt" and "response" are required keys but at least one was missing ' +\
-            f'from {example=}.'
-        )
-    if not isinstance(example['prompt'], str):
-        raise TypeError(
-            f'Unable to tokenize example because "prompt" was not a string. {example=}'
-        )
-    if not isinstance(example['response'], str):
-        raise TypeError(
-            f'Unable to tokenize example because "response" was not a string. {example=}'
-        )
-    return tokenizer(text=example['prompt'], text_target=example['response'])
-
-
-def _read_binary_tokenized_sample(sample: Dict[str, Any]):
+def _read_binary_tokenized_sample(
+        sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
     example = {
         'input_ids':
             torch.from_numpy(
@@ -82,6 +63,27 @@ def _read_binary_tokenized_sample(sample: Dict[str, Any]):
     return example
 
 
+def _tokenize_formatted_example(
+    example: Dict[str, Any], tokenizer: PreTrainedTokenizerBase
+) -> Union[Dict[str, List[int]], Dict[str, torch.Tensor]]:
+    if ('prompt' not in example) or ('response' not in example):
+        raise KeyError(
+            'Unable to tokenize example because it has not been properly formatted. ' +\
+            '"prompt" and "response" are required keys but at least one was missing ' +\
+            f'from {example=}.'
+        )
+    if isinstance(example['prompt'], str) and isinstance(
+            example['response'], str):
+        return tokenizer(text=example['prompt'],
+                         text_target=example['response'])
+    elif isinstance(example['prompt'], bytes) and isinstance(
+            example['response'], bytes):
+        return _read_binary_tokenized_sample(example)
+    else:
+        raise TypeError('Unable to process example. Both "prompt" and "response" ' +\
+            'need to be strings or byte arrays')
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
@@ -200,14 +202,7 @@ def __init__(self,
     # How to process a sample
     def __getitem__(self, idx: int) -> Dict[str, Any]:
         sample = super().__getitem__(idx)
-        if 'prompt' in sample and 'response' in sample:
-            return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
-        elif 'tokens' in sample and 'labels' in sample:
-            return _read_binary_tokenized_sample(sample)
-        else:
-            raise RuntimeError(
-                'FineTuningDataset needs samples to have prompt/response columns ' +\
-                'or tokens/labels columns')
+        return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
 
 
 class DatasetConstructor:
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 61f1cb8233..9844cc1431 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -9,9 +9,9 @@
 from argparse import Namespace
 from typing import Optional
 
+import numpy
 import pytest
 import torch
-import numpy
 from composer.utils import dist, using_torch_2
 from omegaconf import OmegaConf as om
 from streaming import MDSWriter
@@ -64,14 +64,14 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
 
 
 def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
-    columns = {'tokens': 'bytes', 'labels': 'bytes'}
+    columns = {'prompt': 'bytes', 'response': 'bytes'}
 
     dataset = [{
-        'tokens': numpy.asarray([1, 2, 3, 4]).tobytes(),
-        'labels': numpy.asarray([2, 3, 4, 5]).tobytes()
+        'prompt': numpy.asarray([1, 2, 3, 4]).tobytes(),
+        'response': numpy.asarray([2, 3, 4, 5]).tobytes()
     }, {
-        'tokens': numpy.asarray([2, 3, 4, 5]).tobytes(),
-        'labels': numpy.asarray([3, 4, 5, 6]).tobytes()
+        'prompt': numpy.asarray([2, 3, 4, 5]).tobytes(),
+        'response': numpy.asarray([3, 4, 5, 6]).tobytes()
     }]
 
     output_path = os.path.join(data_path, split)
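
After this patch, pre-tokenized and plain-text samples share the same prompt/response keys, and _tokenize_formatted_example dispatches on the value type alone. A standalone sketch of that dispatch rule, using a hypothetical helper name (needs_tokenization is not part of the patch):

from typing import Any, Dict

def needs_tokenization(example: Dict[str, Any]) -> bool:
    # str values mean raw text that still has to go through the tokenizer;
    # bytes values mean pre-tokenized int64 ids ready for np.frombuffer.
    prompt, response = example['prompt'], example['response']
    if isinstance(prompt, str) and isinstance(response, str):
        return True
    if isinstance(prompt, bytes) and isinstance(response, bytes):
        return False
    raise TypeError(
        'Both "prompt" and "response" need to be strings or byte arrays')

assert needs_tokenization({'prompt': 'Hello', 'response': 'world'})
assert not needs_tokenization({'prompt': b'\x01\x00', 'response': b'\x02\x00'})
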
From 9cceb706493d47fe738b70b710a5e54de33ccbd0 Mon Sep 17 00:00:00 2001
From: Aiden Grossman
Date: Thu, 28 Sep 2023 23:38:25 -0700
Subject: [PATCH 4/4] Fix finetuning dataset loading after key name change,
 make test more detailed

---
 llmfoundry/data/finetuning/tasks.py |  5 ++--
 tests/test_dataloader.py            | 40 +++++++++++++++++++++++------
 2 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 6d51f20db4..baa3052cad 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -48,16 +48,15 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 __all__ = ['dataset_constructor']
 
-
 def _read_binary_tokenized_sample(
         sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
     example = {
         'input_ids':
             torch.from_numpy(
-                np.frombuffer(sample['tokens'], dtype=np.int64).copy()),
+                np.frombuffer(sample['prompt'], dtype=np.int64).copy()),
         'labels':
             torch.from_numpy(
-                np.frombuffer(sample['labels'], dtype=np.int64).copy()),
+                np.frombuffer(sample['response'], dtype=np.int64).copy()),
     }
     example['attention_mask'] = torch.ones(example['input_ids'].size())
     return example
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 7580f88ecf..85ae62e303 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -66,13 +66,13 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
 def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
     columns = {'prompt': 'bytes', 'response': 'bytes'}
 
-    dataset = [{
-        'prompt': numpy.asarray([1, 2, 3, 4]).tobytes(),
-        'response': numpy.asarray([2, 3, 4, 5]).tobytes()
-    }, {
-        'prompt': numpy.asarray([2, 3, 4, 5]).tobytes(),
-        'response': numpy.asarray([3, 4, 5, 6]).tobytes()
-    }]
+    dataset = []
+
+    for i in range(0, 64):
+        dataset.append({
+            'prompt': numpy.asarray([i, i, i, i]).tobytes(),
+            'response': numpy.asarray([i + 1, i + 1, i + 1, i + 1]).tobytes()
+        })
 
     output_path = os.path.join(data_path, split)
 
@@ -527,13 +527,37 @@ def test_finetuning_dataloader_streaming_tokenized(tmp_path: pathlib.Path):
         tokenizer_kwargs={'model_max_length': 2048},
     )
 
-    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 4)
+    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 32)
 
     expected_keys = ['input_ids', 'attention_mask', 'labels']
 
+    batch_idx = 0
     for batch in ft_dataloader:
         for k in expected_keys:
             assert k in batch
+            t = batch[k]
+            if batch_idx == 0:
+                if k == 'input_ids':
+                    for i in range(0, 32):
+                        bi = batch_idx * 32 + i
+                        # Only check the first four elements. The rest will be
+                        # padding added by the collator, out to the maximum
+                        # sequence length.
+                        assert torch.equal(t[i][:4],
+                                           torch.tensor([bi, bi, bi, bi]))
+                if k == 'labels':
+                    for i in range(0, 32):
+                        bi = batch_idx * 32 + i + 1
+                        # Look at indices 4:8, as the collator pads the labels
+                        # and the actual labels end up in these positions.
+                        assert torch.equal(t[i][4:8],
+                                           torch.tensor([bi, bi, bi, bi]))
+                if k == 'attention_mask':
+                    for i in range(0, 32):
+                        # Each sample only has four real tokens, so the
+                        # attention mask should have 1s in the first four positions.
+                        assert torch.equal(t[i][:4], torch.ones(4))
+        batch_idx += 1
 
 
 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
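
Taken together, the series lets finetuning run on shards that were tokenized offline. A hedged sketch of the writer side as a user might run it — the output path and token ids are illustrative, and the reader side is the finetuning dataloader config exercised in the test above:

import numpy as np
from streaming import MDSWriter

# Pre-tokenized prompt/response pairs are stored as raw int64 bytes; the
# dtype must be int64 to match np.frombuffer in _read_binary_tokenized_sample.
columns = {'prompt': 'bytes', 'response': 'bytes'}

with MDSWriter(columns=columns, out='/tmp/pretokenized/train',
               compression=None) as writer:
    writer.write({
        'prompt': np.asarray([101, 2023, 2003], dtype=np.int64).tobytes(),
        'response': np.asarray([1037, 3231, 102], dtype=np.int64).tobytes(),
    })
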