Convert to DataSpec and add token counts that include padding (#676)
dakinggg authored Oct 17, 2023
1 parent aecadc9 · commit 4fa2dd8
Showing 5 changed files with 269 additions and 17 deletions.
18 changes: 14 additions & 4 deletions llmfoundry/data/denoising.py
@@ -10,13 +10,15 @@
 
 import numpy as np
 import torch
+from composer.core.data_spec import DataSpec
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.data.packing import BinPackWrapper
-from llmfoundry.data.text_data import StreamingTextDataset
+from llmfoundry.data.text_data import (StreamingTextDataset,
+                                       get_tokens_per_batch_func)
 from llmfoundry.models import utils
 
 __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -353,7 +355,7 @@ def build_text_denoising_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
     device_batch_size: int,
-) -> DataLoader[Dict]:
+) -> DataSpec:
     """Constructor function for a Mixture of Denoisers dataloader.
 
     This function constructs a dataloader that can be used to train an
@@ -506,7 +508,7 @@ def build_text_denoising_dataloader(
             'but cfg.dataset.packing_ratio has not been set. Please set ' +\
             'the latter to turn on packing or remove the former from the config.')
 
-    return DataLoader(
+    dl = DataLoader(
         dataset,
         collate_fn=collate_fn,
         batch_size=device_batch_size,
@@ -518,6 +520,12 @@
         timeout=cfg.get('timeout', 0),
     )
 
+    token_counting_func = get_tokens_per_batch_func(
+        pad_token_id=tokenizer.pad_token_id,
+        decoder_only=cfg.mixture_of_denoisers.decoder_only_format)
+
+    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
 
 def noise_token_sequence(
     example: Union[torch.Tensor, Mapping[str, Any]],
@@ -869,7 +877,9 @@ def _format_tokens_for_decoder_only(
     tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
                                 tokenizer_kwargs=tokenizer_kwargs)
 
-    loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
+    loader = build_text_denoising_dataloader(cfg, tokenizer,
+                                             device_batch_size).dataloader
+    assert isinstance(loader, DataLoader)
     assert isinstance(loader.dataset, StreamingTextDataset)
 
     print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
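
For reference, a minimal sketch of what the new counting behavior looks like on the denoising (encoder-decoder) path, where decoder_only=False makes the counter also look at decoder_input_ids. The batch tensors and the pad id of 0 below are invented for illustration; only get_tokens_per_batch_func and its arguments come from the commit.

    import torch

    from llmfoundry.data.text_data import get_tokens_per_batch_func

    pad_id = 0  # assumed pad id for this toy example
    batch = {
        'input_ids': torch.tensor([[5, 6, 7, pad_id, pad_id],
                                   [8, 9, pad_id, pad_id, pad_id]]),
        'decoder_input_ids': torch.tensor([[3, 4, pad_id],
                                           [2, pad_id, pad_id]]),
    }
    count_tokens = get_tokens_per_batch_func(pad_token_id=pad_id,
                                             decoder_only=False)
    # 5 encoder + 3 decoder non-padding tokens = 8
    print(count_tokens(batch))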
16 changes: 12 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -6,6 +6,7 @@
 
 import datasets as hf_datasets
 import torch
+from composer.core.data_spec import DataSpec
 from composer.utils import dist, get_file, parse_uri
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
@@ -14,6 +15,7 @@
 from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
 from llmfoundry.data.finetuning.tasks import dataset_constructor
 from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.text_data import get_tokens_per_batch_func
 
 log = logging.getLogger(__name__)
 
@@ -23,7 +25,7 @@
 
 def build_finetuning_dataloader(cfg: DictConfig,
                                 tokenizer: PreTrainedTokenizerBase,
-                                device_batch_size: int) -> DataLoader:
+                                device_batch_size: int) -> DataSpec:
     """Builds a finetuning dataloader for training or evaluating.
 
     The underlying dataset can be built through one of two code paths:
@@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
         collate_fn, dataloader_batch_size = _build_collate_fn(
             cfg.dataset, tokenizer, device_batch_size)
 
-        return DataLoader(
+        dl = DataLoader(
             dataset,
             collate_fn=collate_fn,
             batch_size=dataloader_batch_size,
@@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
         )
 
         assert dataset is not None
-        return DataLoader(
+        dl = DataLoader(
             dataset,
             collate_fn=collate_fn,
             batch_size=dataloader_batch_size,
@@ -208,6 +210,11 @@
             timeout=cfg.get('timeout', 0),
         )
 
+    token_counting_func = get_tokens_per_batch_func(
+        pad_token_id=tokenizer.pad_token_id)
+
+    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
 
 def _validate_config(dataset_cfg: DictConfig) -> None:
     """Validates the dataset configuration.
@@ -442,7 +449,8 @@ def _build_collate_fn(
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     device_batch_size = 2
-    dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+    dataloader = build_finetuning_dataloader(cfg, tokenizer,
+                                             device_batch_size).dataloader
 
     packing = cfg.dataset.get('packing_ratio') is not None
 
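
The finetuning builder always passes the tokenizer's pad token id, so the counter reports non-padding tokens, which is generally smaller than batch_size * sequence_length for a padded batch. A toy decoder-only sketch of that difference; the token ids and the pad id value are made up for illustration:

    import torch

    from llmfoundry.data.text_data import get_tokens_per_batch_func

    pad_id = 50256  # hypothetical; the real value comes from the tokenizer
    batch = {
        'input_ids': torch.tensor([[11, 12, 13, 14, pad_id, pad_id],
                                   [21, 22, pad_id, pad_id, pad_id, pad_id]]),
    }
    count_tokens = get_tokens_per_batch_func(pad_token_id=pad_id)
    print(count_tokens(batch))         # 6 non-padding tokens
    print(batch['input_ids'].numel())  # 12 positions, padding included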
2 changes: 1 addition & 1 deletion llmfoundry/data/packing.py
@@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     dataloader_cfg.dataset.packing_ratio = None
     dataloader_cfg.dataset.max_leftovers_to_keep = None
     train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
-                                        max(raw_batch_sizes) * 100)
+                                        max(raw_batch_sizes) * 100).dataloader
 
     # Get a bunch of raw examples
     big_batch = next(iter(train_dataloader))
61 changes: 58 additions & 3 deletions llmfoundry/data/text_data.py
@@ -11,6 +11,8 @@
 import numpy as np
 import torch
 import transformers
+from composer.core.data_spec import DataSpec
+from composer.core.types import Batch
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from streaming import Stream, StreamingDataset
@@ -237,7 +239,7 @@ def build_text_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
     device_batch_size: int,
-) -> DataLoader:
+) -> DataSpec:
     assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
     if cfg.dataset.get('group_method', None) is not None:
         raise NotImplementedError(
@@ -281,7 +283,7 @@ def build_text_dataloader(
         eos_token_id=eos_token_id,
         bos_token_id=bos_token_id)
 
-    return DataLoader(
+    dl = DataLoader(
         dataset,
         collate_fn=collate_fn,
         batch_size=device_batch_size,
@@ -293,6 +295,58 @@
         timeout=cfg.get('timeout', 0),
     )
 
+    # If we pretokenized, we may not have padding, in which case the
+    # tokenizer may not have a pad_token_id. In this case, we can
+    # just use the default token counting function. This is correct
+    # because we do not support training on pretokenized data with padding,
+    # and if tokenizing on the fly, we require that the tokenizer has a pad token.
+    token_counting_func = None
+    if tokenizer.pad_token_id is not None:
+        token_counting_func = get_tokens_per_batch_func(
+            pad_token_id=tokenizer.pad_token_id)
+
+    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
+
+def get_tokens_per_batch_func(pad_token_id: int,
+                              decoder_only: bool = True
+                             ) -> Callable[[Batch], int]:
+    """Returns a callable that counts the number of tokens in a batch.
+
+    Args:
+        pad_token_id (int): The id of the padding token.
+        decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only)
+            or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``.
+
+    Returns:
+        Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
+    """
+
+    def get_num_samples_in_batch(batch: Batch) -> int:
+        if not isinstance(batch, Mapping) or 'input_ids' not in batch:
+            raise ValueError(
+                'get_tokens_per_batch_func() requires a batch with an input_ids key'
+            )
+
+        if not decoder_only and 'decoder_input_ids' not in batch:
+            raise ValueError(
+                'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key'
+            )
+
+        # Count number of non padding tokens in batch
+        input_ids_tokens = int(
+            torch.sum(batch['input_ids'] != pad_token_id).item())
+
+        # For encoder decoder models only
+        decoder_input_ids_tokens = 0
+        if not decoder_only:
+            decoder_input_ids_tokens = int(
+                torch.sum(batch['decoder_input_ids'] != pad_token_id).item())
+
+        return input_ids_tokens + decoder_input_ids_tokens
+
+    return get_num_samples_in_batch
+
 
 # Helpful to test if your dataloader is working locally
 # Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
@@ -353,7 +407,8 @@ def build_text_dataloader(
     tokenizer_kwargs = {'model_max_length': args.max_seq_len}
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
-    loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
+    loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    assert isinstance(loader, DataLoader)
     assert isinstance(loader.dataset, StreamingTextDataset)
     tokenizer = loader.dataset.tokenizer
 
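
To see how the pieces fit together, here is a self-contained sketch that wraps a toy DataLoader in a DataSpec the same way build_text_dataloader now does, so Composer can query a token count per batch. The dataset, collate function, and pad id are stand-ins invented for the example; only DataSpec, the .dataloader attribute, and get_tokens_per_batch_func come from the commit.

    import torch
    from composer.core.data_spec import DataSpec
    from torch.utils.data import DataLoader, TensorDataset

    from llmfoundry.data.text_data import get_tokens_per_batch_func

    pad_id = 0  # stand-in pad id
    # Pre-padded toy sequences standing in for a StreamingTextDataset.
    input_ids = torch.tensor([[5, 6, 7, pad_id],
                              [8, pad_id, pad_id, pad_id],
                              [1, 2, 3, 4],
                              [9, 9, pad_id, pad_id]])
    dataset = TensorDataset(input_ids)

    def collate(examples):
        # Mimic the text collator's output: a mapping with an 'input_ids' key.
        return {'input_ids': torch.stack([ex[0] for ex in examples])}

    dl = DataLoader(dataset, batch_size=2, collate_fn=collate)
    count_tokens = get_tokens_per_batch_func(pad_token_id=pad_id)
    spec = DataSpec(dataloader=dl, get_num_tokens_in_batch=count_tokens)

    for batch in spec.dataloader:
        # First batch: 3 + 1 = 4 non-padding tokens; second batch: 4 + 2 = 6.
        print(count_tokens(batch))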