Add support for pre-tokenized streaming dataset finetuning #601

Closed
wants to merge 11 commits
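For context, a minimal sketch of how a pre-tokenized dataset could be produced for this feature: each MDS sample stores the prompt and response token ids as raw int64 bytes under 'prompt'/'response' columns, mirroring the mock writer added in the tests below. The tokenizer name and output path here are placeholders, not part of this PR.

import numpy as np
from streaming import MDSWriter
from transformers import AutoTokenizer

# Placeholder tokenizer and output location.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
columns = {'prompt': 'bytes', 'response': 'bytes'}
samples = [{'prompt': 'What is 2 + 2?', 'response': ' 4'}]

with MDSWriter(columns=columns, out='/tmp/pretokenized/train',
               compression=None) as writer:
    for sample in samples:
        prompt_ids = tokenizer(sample['prompt'])['input_ids']
        response_ids = tokenizer(sample['response'])['input_ids']
        writer.write({
            # Serialize as int64 so the reader's np.frombuffer(...,
            # dtype=np.int64) decodes the ids correctly.
            'prompt': np.asarray(prompt_ids, dtype=np.int64).tobytes(),
            'response': np.asarray(response_ids, dtype=np.int64).tobytes(),
        })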
38 changes: 27 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from typing import Any, Callable, Dict, List, Optional, Union

import datasets as hf_datasets
import numpy as np
import torch
from omegaconf import DictConfig
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase
@@ -46,25 +48,39 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

__all__ = ['dataset_constructor']

def _read_binary_tokenized_sample(
        sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    # The sample stores pre-tokenized prompt/response token ids as raw
    # int64 bytes; decode them back into tensors.
    example = {
        'input_ids':
            torch.from_numpy(
                np.frombuffer(sample['prompt'], dtype=np.int64).copy()),
        'labels':
            torch.from_numpy(
                np.frombuffer(sample['response'], dtype=np.int64).copy()),
    }
    example['attention_mask'] = torch.ones(example['input_ids'].size())
    return example


def _tokenize_formatted_example(
        example: Dict[str, Any],
        tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
        example: Dict[str, Any], tokenizer: PreTrainedTokenizerBase
) -> Union[Dict[str, List[int]], Dict[str, torch.Tensor]]:
    if ('prompt' not in example) or ('response' not in example):
        raise KeyError(
            'Unable to tokenize example because it has not been properly formatted. ' +\
            '"prompt" and "response" are required keys but at least one was missing ' +\
            f'from {example=}.'
        )
    if not isinstance(example['prompt'], str):
        raise TypeError(
            f'Unable to tokenize example because "prompt" was not a string. {example=}'
        )
    if not isinstance(example['response'], str):
        raise TypeError(
            f'Unable to tokenize example because "response" was not a string. {example=}'
        )
    return tokenizer(text=example['prompt'], text_target=example['response'])
    if isinstance(example['prompt'], str) and isinstance(
            example['response'], str):
        return tokenizer(text=example['prompt'],
                         text_target=example['response'])
    elif isinstance(example['prompt'], bytes) and isinstance(
            example['response'], bytes):
        return _read_binary_tokenized_sample(example)
    else:
        raise TypeError('Unable to process example. Both "prompt" and "response" ' +\
                        'must either be strings or byte arrays.')


class StreamingFinetuningDataset(StreamingDataset):
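To illustrate the decode path added above, a small sketch of what _read_binary_tokenized_sample returns for a hand-built sample (values are illustrative; the attention mask is sized from the prompt ids, as in the function above):

import numpy as np
import torch

from llmfoundry.data.finetuning.tasks import _read_binary_tokenized_sample

# A sample in the expected binary format: int64 token ids as raw bytes.
sample = {
    'prompt': np.asarray([10, 11, 12], dtype=np.int64).tobytes(),
    'response': np.asarray([13, 14], dtype=np.int64).tobytes(),
}

example = _read_binary_tokenized_sample(sample)
assert torch.equal(example['input_ids'], torch.tensor([10, 11, 12]))
assert torch.equal(example['labels'], torch.tensor([13, 14]))
# The attention mask covers the prompt ids only.
assert torch.equal(example['attention_mask'], torch.ones(3))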
86 changes: 86 additions & 0 deletions tests/test_dataloader.py
@@ -11,6 +11,7 @@
from typing import Optional
from unittest.mock import MagicMock

import numpy
import pytest
import torch
import transformers
@@ -67,6 +68,25 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
            output_writer.write(sample)


def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
    columns = {'prompt': 'bytes', 'response': 'bytes'}

    dataset = []

    for i in range(0, 64):
        dataset.append({
            'prompt': numpy.asarray([i, i, i, i]).tobytes(),
            'response': numpy.asarray([i + 1, i + 1, i + 1, i + 1]).tobytes()
        })

    output_path = os.path.join(data_path, split)

    with MDSWriter(columns=columns, out=output_path,
                   compression=None) as output_writer:
        for sample in dataset:
            output_writer.write(sample)


@pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@pytest.mark.parametrize('pretokenize', [False, True])
def test_correct_padding(tokenizer_name: str,
@@ -480,6 +500,72 @@ def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
    _ = build_finetuning_dataloader(cfg, tokenizer, 4)


def test_finetuning_dataloader_streaming_tokenized(tmp_path: pathlib.Path):
    remote_path = os.path.join(tmp_path, 'remote')
    local_path = os.path.join(tmp_path, 'local')

    build_mock_tokenized_ft_streaming_dataset(remote_path, 'train')

    cfg = {
        'name': 'finetuning',
        'dataset': {
            'remote': remote_path,
            'local': local_path,
            'split': 'train',
            'max_seq_len': 2048,
            'decoder_only_format': True,
            'allow_pad_trimming': False,
            'packing_ratio': None,
            'shuffle': True,
        },
        'drop_last': False,
        'num_workers': 4,
        'pin_memory': False,
        'prefetch_factor': 2,
        'persistent_workers': False,
        'timeout': 0
    }

    cfg = om.create(cfg)

    tokenizer = build_tokenizer(
        tokenizer_name='gpt2',
        tokenizer_kwargs={'model_max_length': 2048},
    )

    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 32)

    expected_keys = ['input_ids', 'attention_mask', 'labels']

    batch_idx = 0
    for batch in ft_dataloader:
        for k in expected_keys:
            assert k in batch
            t = batch[k]
            if batch_idx == 0:
                if k == 'input_ids':
                    for i in range(0, 32):
                        bi = batch_idx * 32 + i
                        # Only check the first four elements. The rest is
                        # padding added by the collator up to the maximum
                        # sequence length.
                        assert torch.equal(t[i][:4],
                                           torch.tensor([bi, bi, bi, bi]))
                if k == 'labels':
                    for i in range(0, 32):
                        bi = batch_idx * 32 + i + 1
                        # Look at indices 4-8 as the collator pads the labels
                        # and the actual labels end up in these positions.
                        assert torch.equal(t[i][4:8],
                                           torch.tensor([bi, bi, bi, bi]))
                if k == 'attention_mask':
                    for i in range(0, 32):
                        # Each prompt only has four tokens, so the attention
                        # mask should have 1s in the first four positions.
                        assert torch.equal(t[i][:4], torch.ones(4))
        batch_idx += 1


@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
@pytest.mark.parametrize('add_bad_data_error', [True, False])
def test_malformed_data(
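A condensed usage sketch based on the new test above: point the finetuning dataloader at a pre-tokenized streaming dataset and build it with the same entry points the test uses (build_tokenizer, build_finetuning_dataloader). The paths and device batch size are placeholders.

from omegaconf import OmegaConf as om

from llmfoundry.data import build_finetuning_dataloader
from llmfoundry.utils.builders import build_tokenizer

# Placeholder paths; the dataset layout matches the mock writer in the test.
cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'remote': '/tmp/pretokenized',
        'local': '/tmp/pretokenized-local',
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'allow_pad_trimming': False,
        'packing_ratio': None,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 4,
    'pin_memory': False,
    'prefetch_factor': 2,
    'persistent_workers': False,
    'timeout': 0,
})

tokenizer = build_tokenizer(tokenizer_name='gpt2',
                            tokenizer_kwargs={'model_max_length': 2048})
dataloader = build_finetuning_dataloader(cfg, tokenizer, 8)

for batch in dataloader:
    # Batches already contain input_ids, labels, and attention_mask.
    break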