From 4fa2dd88e2064e833c7c8e4f64734f0ef8d22b48 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 16 Oct 2023 18:23:24 -0700 Subject: [PATCH] Convert to DataSpec and add token counts that include padding (#676) --- llmfoundry/data/denoising.py | 18 ++- llmfoundry/data/finetuning/dataloader.py | 16 +- llmfoundry/data/packing.py | 2 +- llmfoundry/data/text_data.py | 61 +++++++- tests/test_dataloader.py | 189 ++++++++++++++++++++++- 5 files changed, 269 insertions(+), 17 deletions(-) diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index d685d0077d..bc41945076 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -10,13 +10,15 @@ import numpy as np import torch +from composer.core.data_spec import DataSpec from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase from llmfoundry.data.packing import BinPackWrapper -from llmfoundry.data.text_data import StreamingTextDataset +from llmfoundry.data.text_data import (StreamingTextDataset, + get_tokens_per_batch_func) from llmfoundry.models import utils __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader'] @@ -353,7 +355,7 @@ def build_text_denoising_dataloader( cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, -) -> DataLoader[Dict]: +) -> DataSpec: """Constructor function for a Mixture of Denoisers dataloader. This function constructs a dataloader that can be used to train an @@ -506,7 +508,7 @@ def build_text_denoising_dataloader( 'but cfg.dataset.packing_ratio has not been set. Please set ' +\ 'the latter to turn on packing or remove the former from the config.') - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=device_batch_size, @@ -518,6 +520,12 @@ def build_text_denoising_dataloader( timeout=cfg.get('timeout', 0), ) + token_counting_func = get_tokens_per_batch_func( + pad_token_id=tokenizer.pad_token_id, + decoder_only=cfg.mixture_of_denoisers.decoder_only_format) + + return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) + def noise_token_sequence( example: Union[torch.Tensor, Mapping[str, Any]], @@ -869,7 +877,9 @@ def _format_tokens_for_decoder_only( tokenizer = build_tokenizer(tokenizer_name=tokenizer_name, tokenizer_kwargs=tokenizer_kwargs) - loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size) + loader = build_text_denoising_dataloader(cfg, tokenizer, + device_batch_size).dataloader + assert isinstance(loader, DataLoader) assert isinstance(loader.dataset, StreamingTextDataset) print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n') diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index ebb7991dde..2dde563ac6 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -6,6 +6,7 @@ import datasets as hf_datasets import torch +from composer.core.data_spec import DataSpec from composer.utils import dist, get_file, parse_uri from omegaconf import DictConfig from torch.utils.data import DataLoader @@ -14,6 +15,7 @@ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.tasks import dataset_constructor from llmfoundry.data.packing import BinPackWrapper +from llmfoundry.data.text_data import get_tokens_per_batch_func log = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def 
build_finetuning_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int) -> DataLoader: + device_batch_size: int) -> DataSpec: """Builds a finetuning dataloader for training or evaluating. The underlying dataset can be built through one of two code paths: @@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig, collate_fn, dataloader_batch_size = _build_collate_fn( cfg.dataset, tokenizer, device_batch_size) - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, @@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) assert dataset is not None - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, @@ -208,6 +210,11 @@ def build_finetuning_dataloader(cfg: DictConfig, timeout=cfg.get('timeout', 0), ) + token_counting_func = get_tokens_per_batch_func( + pad_token_id=tokenizer.pad_token_id) + + return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) + def _validate_config(dataset_cfg: DictConfig) -> None: """Validates the dataset configuration. @@ -442,7 +449,8 @@ def _build_collate_fn( tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) device_batch_size = 2 - dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + dataloader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader packing = cfg.dataset.get('packing_ratio') is not None diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index d0a73be801..1532de276e 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, dataloader_cfg.dataset.packing_ratio = None dataloader_cfg.dataset.max_leftovers_to_keep = None train_dataloader = build_dataloader(dataloader_cfg, tokenizer, - max(raw_batch_sizes) * 100) + max(raw_batch_sizes) * 100).dataloader # Get a bunch of raw examples big_batch = next(iter(train_dataloader)) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index afdd243adf..93af2f63ed 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -11,6 +11,8 @@ import numpy as np import torch import transformers +from composer.core.data_spec import DataSpec +from composer.core.types import Batch from omegaconf import DictConfig from omegaconf import OmegaConf as om from streaming import Stream, StreamingDataset @@ -237,7 +239,7 @@ def build_text_dataloader( cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, -) -> DataLoader: +) -> DataSpec: assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' if cfg.dataset.get('group_method', None) is not None: raise NotImplementedError( @@ -281,7 +283,7 @@ def build_text_dataloader( eos_token_id=eos_token_id, bos_token_id=bos_token_id) - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=device_batch_size, @@ -293,6 +295,58 @@ def build_text_dataloader( timeout=cfg.get('timeout', 0), ) + # If we pretokenized, we may not have padding, in which case the + # tokenizer may not have a pad_token_id. In this case, we can + # just use the default token counting function. This is correct + # because we do not support training on pretokenized data with padding, + # and if tokenizing on the fly, we require that the tokenizer has a pad token. 
+ token_counting_func = None + if tokenizer.pad_token_id is not None: + token_counting_func = get_tokens_per_batch_func( + pad_token_id=tokenizer.pad_token_id) + + return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) + + +def get_tokens_per_batch_func(pad_token_id: int, + decoder_only: bool = True + ) -> Callable[[Batch], int]: + """Returns a callable that counts the number of tokens in a batch. + + Args: + pad_token_id (int): The id of the padding token. + decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only) + or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``. + + Returns: + Callable[[Batch], int]: A callable that counts the number of tokens in a batch. + """ + + def get_num_samples_in_batch(batch: Batch) -> int: + if not isinstance(batch, Mapping) or 'input_ids' not in batch: + raise ValueError( + 'get_tokens_per_batch_func() requires a batch with an input_ids key' + ) + + if not decoder_only and 'decoder_input_ids' not in batch: + raise ValueError( + 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key' + ) + + # Count number of non padding tokens in batch + input_ids_tokens = int( + torch.sum(batch['input_ids'] != pad_token_id).item()) + + # For encoder decoder models only + decoder_input_ids_tokens = 0 + if not decoder_only: + decoder_input_ids_tokens = int( + torch.sum(batch['decoder_input_ids'] != pad_token_id).item()) + + return input_ids_tokens + decoder_input_ids_tokens + + return get_num_samples_in_batch + # Helpful to test if your dataloader is working locally # Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out @@ -353,7 +407,8 @@ def build_text_dataloader( tokenizer_kwargs = {'model_max_length': args.max_seq_len} tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - loader = build_text_dataloader(cfg, tokenizer, device_batch_size) + loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader + assert isinstance(loader, DataLoader) assert isinstance(loader.dataset, StreamingTextDataset) tokenizer = loader.dataset.tokenizer diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 6495eccf65..656b6d52a6 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -3,22 +3,27 @@ import contextlib import os import pathlib +import random import shutil import sys import tempfile from argparse import Namespace from typing import Optional +from unittest.mock import MagicMock import pytest import torch +import transformers from composer.utils import dist, using_torch_2 +from omegaconf import DictConfig from omegaconf import OmegaConf as om from streaming import MDSWriter from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, - build_text_dataloader) + build_text_dataloader, + get_tokens_per_batch_func) from llmfoundry.utils.builders import build_tokenizer # Add repo root to path so we can import scripts and test it @@ -137,7 +142,7 @@ def test_correct_padding(tokenizer_name: str, test_cfg.eval_loader, tokenizer, batch_size, - ) + ).dataloader batch = next(iter(eval_loader)) assert batch['input_ids'].shape == torch.Size([batch_size, 2048]) @@ -228,7 +233,7 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool, tokenizer_kwargs={'model_max_length': max_seq_len}) loader = 
build_text_denoising_dataloader(cfg, tokenizer, - device_batch_size) + device_batch_size).dataloader batch_ix = 0 for batch in loader: for k in expected_keys: @@ -287,7 +292,8 @@ def test_finetuning_dataloader(decoder_only_format: bool, else: expected_keys += ['decoder_attention_mask', 'decoder_input_ids'] - loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + loader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader batch_ix = 0 for batch in loader: for k in expected_keys: @@ -541,7 +547,8 @@ def test_malformed_data( match='Unable to tokenize example') with error_context: - dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + dl = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader if not add_bad_data_error: # +5 because we added samples with just bos/eos in each of prompt/response @@ -552,3 +559,175 @@ def test_malformed_data( actual_num_batches += 1 assert actual_num_batches == expected_num_batches + + +@pytest.mark.parametrize('pad_token_id', [0, 100, 1000]) +@pytest.mark.parametrize('batch_size', [1, 8, 16]) +@pytest.mark.parametrize('model_max_length', [1024, 2048]) +@pytest.mark.parametrize('padding_side', ['left', 'right']) +@pytest.mark.parametrize('add_decoder_input_ids', [True, False]) +def test_token_counting_func(pad_token_id: int, batch_size: int, + model_max_length: int, padding_side: str, + add_decoder_input_ids: bool): + gptt = transformers.AutoTokenizer.from_pretrained('gpt2') + gptt.pad_token_id = pad_token_id + gptt.model_max_length = model_max_length + gptt.padding_side = padding_side + + batch_strings = [] + expected_token_count = 0 + for _ in range(batch_size): + sample_length = random.randint(1, model_max_length) + batch_strings.append(' '.join(['hello'] * sample_length)) + expected_token_count += sample_length + + batch_tokenized = gptt(batch_strings, padding=True, return_tensors='pt') + + if add_decoder_input_ids: + decoder_batch_strings = [] + decoder_expected_token_count = 0 + for _ in range(batch_size): + sample_length = random.randint(1, model_max_length) + decoder_batch_strings.append(' '.join(['hello'] * sample_length)) + decoder_expected_token_count += sample_length + expected_token_count += sample_length + batch_tokenized['decoder_input_ids'] = gptt( + decoder_batch_strings, padding=True, + return_tensors='pt')['input_ids'] + + token_counting_func = get_tokens_per_batch_func( + pad_token_id, decoder_only=not add_decoder_input_ids) + + actual_token_count = token_counting_func(batch_tokenized) + + assert actual_token_count == expected_token_count + + +@pytest.mark.parametrize( + 'dataloader_type', + ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text']) +@pytest.mark.parametrize('pad_token_id', [100, None]) +@pytest.mark.parametrize('batch_size', [1, 8]) +@pytest.mark.parametrize('model_max_length', [1024]) +@pytest.mark.parametrize('padding_side', ['left']) +def test_token_counting_func_dataloader_setting( + dataloader_type: str, pad_token_id: Optional[int], batch_size: int, + model_max_length: int, padding_side: str, + monkeypatch: pytest.MonkeyPatch): + gptt = transformers.AutoTokenizer.from_pretrained('gpt2') + gptt.pad_token_id = pad_token_id + gptt.model_max_length = model_max_length + gptt.padding_side = padding_side + + batch_strings = [] + expected_token_count = 0 + for _ in range(batch_size): + sample_length = random.randint( + 1, + model_max_length) if pad_token_id is not None else model_max_length + batch_strings.append(' '.join(['hello'] * 
sample_length)) + expected_token_count += sample_length + + batch_tokenized = gptt(batch_strings, + padding=True if pad_token_id is not None else False, + return_tensors='pt') + + if dataloader_type == 'denoising': + batch_tokenized['decoder_input_ids'] = batch_tokenized[ + 'input_ids'].clone() + expected_token_count *= 2 + + common_args = { + 'drop_last': False, + 'num_workers': 0, + 'prefetch_factor': None if using_torch_2() else 2, + 'pin_memory': False, + 'persistent_workers': False, + 'timeout': 0 + } + + if dataloader_type == 'finetuning-hf': + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'hf_name': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + **common_args + }) + monkeypatch.setattr( + 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', + lambda *args, **kwargs: []) + dl = build_finetuning_dataloader(cfg, gptt, batch_size) + elif dataloader_type == 'finetuning-streaming': + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'remote': 'dummy-path', + 'local': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + **common_args + }) + monkeypatch.setattr( + 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_streaming', + lambda *args, **kwargs: []) + dl = build_finetuning_dataloader(cfg, gptt, batch_size) + elif dataloader_type == 'text': + cfg = DictConfig({ + 'name': 'text', + 'dataset': { + 'local': 'dummy-path', + 'remote': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'shuffle': True, + 'shuffle_seed': 0, + }, + **common_args + }) + monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + dl = build_text_dataloader(cfg, gptt, batch_size) + elif dataloader_type == 'denoising': + cfg = DictConfig({ + 'name': 'text_denoising', + 'dataset': { + 'local': 'dummy-path', + 'remote': 'dummy-path', + 'split': 'val_xsmall', + 'shuffle': False, + 'max_seq_len': model_max_length, + 'packing_ratio': None, + 'predownload': 1000, + 'keep_zip': False, + 'num_workers': None + }, + 'mixture_of_denoisers': { + 'decoder_only_format': False, + 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]], + 'sequence_mask_ratios': 0.25, + }, + **common_args + }) + monkeypatch.setattr('llmfoundry.data.denoising.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + dl = build_text_denoising_dataloader(cfg, gptt, batch_size) + else: + raise NotImplementedError() + + cfg = om.create(cfg) + + actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized) + + assert actual_token_count == expected_token_count
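
For a quick sanity check of what this patch wires up, below is a minimal standalone sketch of the token-counting behavior: it mirrors the logic of get_tokens_per_batch_func from llmfoundry/data/text_data.py without importing llm-foundry or Composer, so it runs with torch alone. The helper name tokens_per_batch_func, the pad id PAD, and the toy batch are illustrative stand-ins, not part of the patch.

# Standalone sketch of the token-counting logic added in this patch.
# It mirrors get_tokens_per_batch_func (llmfoundry/data/text_data.py) but
# avoids llm-foundry/Composer imports; the pad id and toy batch are made up.
from typing import Callable, Mapping

import torch

Batch = Mapping[str, torch.Tensor]  # stand-in for composer.core.types.Batch


def tokens_per_batch_func(pad_token_id: int,
                          decoder_only: bool = True
                         ) -> Callable[[Batch], int]:
    """Return a callable that counts non-padding tokens in a batch."""

    def get_num_tokens_in_batch(batch: Batch) -> int:
        if 'input_ids' not in batch:
            raise ValueError('batch must contain an input_ids key')
        if not decoder_only and 'decoder_input_ids' not in batch:
            raise ValueError(
                'encoder-decoder batches must contain decoder_input_ids')

        # Count every position that is not the pad token.
        num_tokens = int(torch.sum(batch['input_ids'] != pad_token_id).item())
        if not decoder_only:
            num_tokens += int(
                torch.sum(batch['decoder_input_ids'] != pad_token_id).item())
        return num_tokens

    return get_num_tokens_in_batch


if __name__ == '__main__':
    PAD = 0  # hypothetical pad_token_id
    # Two sequences of real lengths 3 and 5, right-padded to length 5.
    batch = {
        'input_ids':
            torch.tensor([[11, 12, 13, PAD, PAD], [21, 22, 23, 24, 25]])
    }
    count_tokens = tokens_per_batch_func(pad_token_id=PAD)
    assert count_tokens(batch) == 8  # 3 + 5 non-padding tokens
    print(count_tokens(batch))

In the patch itself, the equivalent callable is attached via DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func), which is how the trainer obtains per-batch token counts based on the non-padding positions of input_ids (plus decoder_input_ids for encoder-decoder batches).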