From ca8e6b5cbb5da78d688ca1862e69f4dc948d866f Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sat, 4 Nov 2023 19:40:53 -0700 Subject: [PATCH] Add support for auto packing ratio (#683) --- llmfoundry/data/__init__.py | 2 + llmfoundry/data/dataloader.py | 44 +++ llmfoundry/data/denoising.py | 16 +- llmfoundry/data/finetuning/dataloader.py | 50 ++-- llmfoundry/data/packing.py | 277 ++++++++++++------ mcli/mcli-llama2-finetune.yaml | 5 +- scripts/misc/profile_packing.py | 100 +++++++ .../mpt-7b-arc-easy--gpu.yaml | 5 +- scripts/train/train.py | 29 +- .../yamls/finetune/1b_local_data_sft.yaml | 5 +- .../train/yamls/finetune/7b_dolly_sft.yaml | 5 +- .../yamls/finetune/mpt-7b_dolly_sft.yaml | 5 +- tests/test_dataloader.py | 7 +- tests/test_packing.py | 191 ++++++++++++ 14 files changed, 587 insertions(+), 154 deletions(-) create mode 100644 llmfoundry/data/dataloader.py create mode 100644 scripts/misc/profile_packing.py create mode 100644 tests/test_packing.py diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py index c997c865dd..8da436b9b1 100644 --- a/llmfoundry/data/__init__.py +++ b/llmfoundry/data/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.data.denoising import (MixtureOfDenoisersCollator, build_text_denoising_dataloader) from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator, @@ -18,4 +19,5 @@ 'build_text_dataloader', 'NoConcatDataset', 'ConcatTokensDataset', + 'build_dataloader', ] diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py new file mode 100644 index 0000000000..12741717be --- /dev/null +++ b/llmfoundry/data/dataloader.py @@ -0,0 +1,44 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Dataloader builder utilities.""" + +from composer import DataSpec +from omegaconf import DictConfig +from transformers import PreTrainedTokenizerBase + +from llmfoundry.data.denoising import build_text_denoising_dataloader +from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.text_data import build_text_dataloader + + +def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + device_batch_size: int) -> DataSpec: + """Builds a dataloader from a config. + + Args: + cfg (DictConfig): An omegaconf dictionary used to configure the loader. + tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use. + device_batch_size (int): The size of the batches (number of examples) + that the dataloader will produce. + """ + if cfg.name == 'text': + return build_text_dataloader( + cfg, + tokenizer, + device_batch_size, + ) + elif cfg.name == 'text_denoising': + return build_text_denoising_dataloader( + cfg, + tokenizer, + device_batch_size, + ) + elif cfg.name == 'finetuning': + return build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ) + else: + raise ValueError(f'Not sure how to build dataloader with config: {cfg}') diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index bc41945076..7d497b4efd 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from llmfoundry.data.packing import BinPackWrapper +from llmfoundry.data.packing import BinPackCollator from llmfoundry.data.text_data import (StreamingTextDataset, get_tokens_per_batch_func) from llmfoundry.models import utils @@ -375,19 +375,25 @@ def build_text_denoising_dataloader( cfg.dataset.max_seq_len (int): The maximum length of sequences in the batch. See :class:`MixtureOfDenoisersCollator` docstring for details. - cfg.dataset.packing_ratio (float, optional): If provided, this invokes + cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes a collator wrapper that packs device_batch_size*packing_ratio raw examples into device_batch_size packed examples. This helps minimize padding while preserving sequence integrity. This adds `sequence_id` to the batch, which indicates which unique sequence each token belongs to. + + If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with + zero waste is selected. + In practice, this may result in > 0 waste because profiling is done on only a portion + of the dataset. + Note: Using this feature will not change device_batch_size but it will determine the number of raw examples consumed by the dataloader per batch. Some examples may be discarded if they do not fit when packing. Select packing_ratio **carefully** based on the dataset statistics, max_seq_len, and tolerance for discarding samples! - The packing code in `./packing.py` provides a script that can help + The script `scripts/misc/profile_packing.py` can help you choose the best packing_ratio. See :class:`StreamingTextDataset` for info on other standard config options within `cfg.dataset`. @@ -419,7 +425,7 @@ def build_text_denoising_dataloader( that the dataloader will produce. Note: - You can run the script inside `./packing.py` to quickly test the + You can use the script `scripts/misc/profile_packing.py` to quickly test the padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. """ @@ -492,7 +498,7 @@ def build_text_denoising_dataloader( raise NotImplementedError( 'On-the-fly packing is currently only supported for decoder-only formats.' ) - collate_fn = BinPackWrapper( + collate_fn = BinPackCollator( collator=collate_fn, target_batch_size=device_batch_size, max_seq_len=cfg.dataset.max_seq_len, diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 2dde563ac6..6e988ac149 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -14,7 +14,7 @@ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.tasks import dataset_constructor -from llmfoundry.data.packing import BinPackWrapper +from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import get_tokens_per_batch_func log = logging.getLogger(__name__) @@ -74,20 +74,26 @@ def build_finetuning_dataloader(cfg: DictConfig, cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow the collator to trim padding. See :class:`Seq2SeqFinetuningCollator` docstring for details. Default: ``False``. - cfg.dataset.packing_ratio (float, optional): If provided, this invokes - a collator wrapper that packs `device_batch_size*packing_ratio` - raw examples into `device_batch_size` packed examples. This helps + cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes + a collator wrapper that packs device_batch_size*packing_ratio + raw examples into device_batch_size packed examples. This helps minimize padding while preserving sequence integrity. This adds `sequence_id` to the batch, which indicates which unique sequence each token belongs to. + + If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with + zero waste is selected. + In practice, this may result in > 0 waste because profiling is done on only a portion + of the dataset. + Note: Using this feature will not change device_batch_size but it will determine the number of raw examples consumed by the dataloader per batch. Some examples may be discarded if they do not fit when packing. - Select `packing_ratio` **carefully** based on the dataset - statistics, `max_seq_len`, and tolerance for discarding samples! - The packing code in `../packing.py` provides a script that can help - you choose the best `packing_ratio`. + Select packing_ratio **carefully** based on the dataset + statistics, max_seq_len, and tolerance for discarding samples! + The script `scripts/misc/profile_packing.py` can help + you choose the best packing_ratio. cfg.dataset.shuffle (bool): Whether to shuffle the dataset. ___ See :class:`StreamingFinetuningDataset` for info on other standard config @@ -106,7 +112,7 @@ def build_finetuning_dataloader(cfg: DictConfig, A pytorch dataloader Note: - You can run the script inside `../packing.py` to quickly test the + You can run the script inside `scripts/misc/profile_packing.py` to quickly test the padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. """ @@ -143,7 +149,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) collate_fn, dataloader_batch_size = _build_collate_fn( - cfg.dataset, tokenizer, device_batch_size) + cfg, tokenizer, device_batch_size) dl = DataLoader( dataset, @@ -174,7 +180,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) collate_fn, dataloader_batch_size = _build_collate_fn( - cfg.dataset, tokenizer, device_batch_size) + cfg, tokenizer, device_batch_size) if cfg.drop_last: world_size = dist.get_world_size() @@ -367,25 +373,33 @@ def _build_hf_dataset_from_remote( def _build_collate_fn( - dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int -) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]: +) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]: + dataset_cfg = dataloader_cfg.dataset + max_seq_len = dataset_cfg.max_seq_len + collate_fn = Seq2SeqFinetuningCollator( tokenizer=tokenizer, - max_seq_len=dataset_cfg.max_seq_len, + max_seq_len=max_seq_len, decoder_only_format=dataset_cfg.decoder_only_format, allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False), ) packing_ratio = dataset_cfg.get('packing_ratio') + max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep') if packing_ratio is None: - if dataset_cfg.get('max_leftover_bins_to_keep') is not None: + if max_leftover_bins_to_keep is not None: raise ValueError( 'dataset.max_leftover_bins_to_keep has been defined, ' +\ 'but dataset.packing_ratio has not been set. Please set ' +\ 'the latter to turn on packing or remove the former from the config.') return collate_fn, device_batch_size + if packing_ratio == 'auto': + packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, + device_batch_size) + if packing_ratio == 1.0: return collate_fn, device_batch_size elif packing_ratio < 1.0: @@ -396,13 +410,13 @@ def _build_collate_fn( 'On-the-fly packing is currently only supported for decoder-only formats.' ) - collate_fn = BinPackWrapper( + collate_fn = BinPackCollator( collator=collate_fn, target_batch_size=device_batch_size, - max_seq_len=dataset_cfg.max_seq_len, + max_seq_len=max_seq_len, pad_token_id=tokenizer.pad_token_id, padding_side=tokenizer.padding_side, - max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'), + max_leftover_bins_to_keep=max_leftover_bins_to_keep, ) n_examples_to_pack = int(device_batch_size * packing_ratio) return collate_fn, n_examples_to_pack diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 1532de276e..1ae9efcce5 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -1,8 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import os -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple import numpy as np import torch @@ -10,7 +9,7 @@ from transformers import PreTrainedTokenizerBase -class BinPackWrapper: +class BinPackCollator: """Utility collator for packing to reduce padding.""" def __init__(self, @@ -33,13 +32,10 @@ def __init__(self, if self.pad_token_id < 0: raise ValueError(f'{pad_token_id=} must be >=0.') - if max_leftover_bins_to_keep is None: - self.max_leftover_bins_to_keep = int(10 * self.out_size) - elif max_leftover_bins_to_keep < 0: + if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0: raise ValueError( f'{max_leftover_bins_to_keep=} must be >=0 or None.') - else: - self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep) + self.max_leftover_bins_to_keep = max_leftover_bins_to_keep self.n_packed_tokens = 0 self.n_total_tokens = 0 @@ -60,7 +56,9 @@ def __call__( self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: batch = self.base_collator(examples) + return self.pack(batch) + def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: assert 'attention_mask' in batch assert 'input_ids' in batch @@ -75,12 +73,12 @@ def __call__( # Cut everything down to size sizes, trimmed_examples = [], [] for idx in range(batch['attention_mask'].shape[0]): - size, trimmed_example = extract_trim_batch_idx(batch, idx) + size, trimmed_example = _extract_trim_batch_idx(batch, idx) sizes.append(size) trimmed_examples.append(trimmed_example) # Apply our CS 101 bin packing algorithm. - packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing( + packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing( sizes=sizes, examples=trimmed_examples, num_bins=self.out_size, @@ -93,15 +91,15 @@ def __call__( self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep] # Re-pad to max_seq_len and batch - batch = repad(packed_examples, - max_seq_len=self.max_seq_len, - pad_token_id=self.pad_token_id, - padding_side=self.padding_side) + batch = _repad(packed_examples, + max_seq_len=self.max_seq_len, + pad_token_id=self.pad_token_id, + padding_side=self.padding_side) return batch -def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], - idx: int) -> Tuple[int, Dict[str, torch.Tensor]]: +def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], + idx: int) -> Tuple[int, Dict[str, torch.Tensor]]: example = {k: v[idx] for k, v in batch.items()} keep = example['attention_mask'] == 1 @@ -112,7 +110,7 @@ def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], return size, trim_example -def combine_in_place( +def _combine_in_place( example: Dict[str, torch.Tensor], add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if 'labels' in add_on: @@ -129,7 +127,7 @@ def combine_in_place( return example -def first_fit_bin_packing( +def _first_fit_bin_packing( sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]] ) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[ @@ -194,7 +192,7 @@ def first_fit_bin_packing( if bins[bidx][0] + size <= max_bin_size: bin_size, packed_example = bins.pop(bidx) bin_size = bin_size + size - packed_example = combine_in_place(packed_example, example) + packed_example = _combine_in_place(packed_example, example) bins.append((bin_size, packed_example)) added = True break @@ -225,8 +223,8 @@ def first_fit_bin_packing( bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:] -def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, - pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]: +def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, + pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]: def pad_tensor(tensor: torch.Tensor, pad_value: int): if len(tensor) == max_seq_len: @@ -260,14 +258,168 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int): return batch +def auto_packing_ratio(dataloader_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + device_batch_size: int, + num_packing_ratios: int = 20) -> float: + """Find a packing ratio that minimizes padding with zero waste. + + By packing examples, we can increase training efficiency, training on more data with less batches. + However, in practice, the selected packing_ratio may produce some waste because profiling is done on only + a subset of the dataset. + + We select a min_ratio of 1 and a max_ratio that is the max_seq_len / 100, and profile up to + num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive. + When a packing_ratio with non-zero waste is found, we stop and select the previous ratio, + which has zero waste. + + Args: + dataloader_cfg (DictConfig): The dataloader configuration for profiling. + tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. + device_batch_size (int): The size of the batches (number of examples) per device. + num_packing_ratio (int): The number of packing ratios to try. + + Returns: + A packing ratio that minimizes padding while maintaining zero waste. + """ + from composer.utils import dist, get_device, reproducibility + + # Stash the rng state to restore later. + rng_state = reproducibility.get_rng_state() + # Set the seed so that auto packing is deterministic. + reproducibility.seed_all(0) + + min_ratio = 1 + max_ratio = dataloader_cfg.dataset.max_seq_len / 100 + profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, + max_ratio, num_packing_ratios, + device_batch_size) + + # Obtain the maximum packing_ratio/minimum padding that has no waste. + # profiling_results are sorted from smallest to largest packing_ratio. + packing_ratio = 1 + for packing_ratio_candidate, _, waste in profiling_results: + if waste > 0: + break + packing_ratio = packing_ratio_candidate + + # Select the minimum packing ratio across all ranks. + if dist.is_available() and dist.is_initialized(): + device = get_device(None) + packing_ratio_tensor = device.tensor_to_device( + torch.tensor(packing_ratio)) + dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN') + packing_ratio = packing_ratio_tensor.item() + + # Restore rng state. + reproducibility.load_rng_state(rng_state) + + return packing_ratio + + +def profile_packing( + dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + min_ratio: float, max_ratio: float, num_packing_ratios: int, + device_batch_size: int) -> Iterable[Tuple[float, float, float]]: + """Generator function that profiles example packing across packing ratios. + + Args: + dataloader_cfg (DictConfig): The dataloader configuration for profiling. + tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. + min_ratio (float): Smallest packing_ratio to test. Must be >=1. + max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`. + num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try. + device_batch_size (int): The size of the batches (number of examples) per device. + + Returns: + An iterable of tuples of packing ratio, padding, and waste, sorted by smallest to largest packing ratio. + """ + import copy + + from llmfoundry.data.dataloader import build_dataloader + + max_seq_len = dataloader_cfg.dataset.get('max_seq_len') + max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', + None) + + # Turn off packing for the dataloader (we want raw, pre-packed examples) + dataloader_cfg = copy.deepcopy(dataloader_cfg) + dataloader_cfg.dataset.packing_ratio = None + dataloader_cfg.drop_last = False + dataloader_cfg.num_workers = 0 + dataloader_cfg.prefetch_factor = None + + # Determine the packing_ratio values we'll try + packing_ratios, raw_batch_sizes = [], [] + for packing_ratio in np.linspace(min_ratio, + max_ratio, + num_packing_ratios, + endpoint=True): + packing_ratio = np.round(10 * packing_ratio) / 10 + raw_batch_size = int(packing_ratio * device_batch_size) + if raw_batch_size not in raw_batch_sizes: + packing_ratios.append(packing_ratio) + raw_batch_sizes.append(raw_batch_size) + + n_profile_examples = max(raw_batch_sizes) * 100 + + train_dataspec = build_dataloader(dataloader_cfg, tokenizer, + n_profile_examples) + train_dataloader = train_dataspec.dataloader + + # Get a bunch of raw examples + big_batch = next(iter(train_dataloader)) + + def split_big_batch(raw_batch_size: int) -> List: + input_ids = big_batch['input_ids'].split(raw_batch_size) + batches = [{'input_ids': x} for x in input_ids] + + for key in big_batch.keys(): + if key == 'input_ids': + continue + for idx, split in enumerate(big_batch[key].split(raw_batch_size)): + batches[idx].update({key: split}) + return batches + + def profile(raw_batch_size: int) -> Tuple[float, float]: + packer = BinPackCollator( + collator=lambda x: x, + target_batch_size=device_batch_size, + max_seq_len=max_seq_len, + pad_token_id=0, # <-- Doesn't need to be correct for profiling + padding_side='left', # <-- Doesn't need to be correct for profiling + max_leftover_bins_to_keep=max_leftovers_to_keep) + + # Simulate feeding the packing collator a bunch of data + for batch in split_big_batch(raw_batch_size): + if batch['input_ids'].shape[0] < device_batch_size: + continue + _ = packer.pack(batch) + + # Return the padding / waste stats over that bunch of data + padding_percent = 100 * (1 - packer.efficiency) + waste_percent = 100 * packer.waste + return padding_percent, waste_percent + + for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes): + padding, waste = profile(raw_batch_size) + yield (packing_ratio, padding, waste) + + if __name__ == '__main__': + + import warnings + + warnings.warn( + DeprecationWarning( + 'Please use scripts/misc/profile_packing.py to profile packing.' + + 'This script will be removed in later releases.')) + + import os from argparse import ArgumentParser, Namespace from omegaconf import OmegaConf as om - from llmfoundry import (build_finetuning_dataloader, - build_text_denoising_dataloader) - from llmfoundry.data import build_text_dataloader from llmfoundry.utils import build_tokenizer def parse_args() -> Namespace: @@ -296,7 +448,7 @@ def parse_args() -> Namespace: parser.add_argument( '--num-packing-ratios', type=int, - default=10, + default=20, help= 'Number of packing_ratio values (spaced between `min` and `max) to try.' ) @@ -316,20 +468,6 @@ def parse_args() -> Namespace: raise ValueError('`num_packing_ratios` must be a positive integer.') return args - def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int): - if cfg.name == 'text': - return build_text_dataloader(cfg, tokenizer, device_batch_size) - elif cfg.name == 'text_denoising': - return build_text_denoising_dataloader(cfg, tokenizer, - device_batch_size) - elif cfg.name == 'finetuning': - return build_finetuning_dataloader(cfg, tokenizer, - device_batch_size) - else: - raise ValueError( - f'Not sure how to build dataloader with config: {cfg}') - args = parse_args() with open(args.yaml_path) as f: @@ -339,26 +477,11 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, cfg = om.create(cfg) device_batch_size = cfg.global_train_batch_size // args.num_devices - # Determine the packing_ratio values we'll try - packing_ratios, raw_batch_sizes = [], [] - for packing_ratio in np.linspace(args.min, - args.max, - args.num_packing_ratios, - endpoint=True): - packing_ratio = np.round(10 * packing_ratio) / 10 - raw_batch_size = int(packing_ratio * device_batch_size) - if raw_batch_size not in raw_batch_sizes: - packing_ratios.append(packing_ratio) - raw_batch_sizes.append(raw_batch_size) - # Fetch a bunch of raw examples once, which we'll re-use if 'train_loader' not in cfg: raise ValueError('config must define train_loader') dataloader_cfg = cfg.train_loader - max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', - None) - # build tokenizer if 'tokenizer' not in cfg: raise ValueError('config must define tokenizer') @@ -367,57 +490,19 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, if not isinstance(resolved_tokenizer_cfg, Dict): raise ValueError( 'tokenizer config needs to be resolved by omegaconf into a Dict.') - tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg + tokenizer_cfg = resolved_tokenizer_cfg tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - # Turn off packing for the dataloader (we want raw, pre-packed examples) - dataloader_cfg.dataset.packing_ratio = None - dataloader_cfg.dataset.max_leftovers_to_keep = None - train_dataloader = build_dataloader(dataloader_cfg, tokenizer, - max(raw_batch_sizes) * 100).dataloader - - # Get a bunch of raw examples - big_batch = next(iter(train_dataloader)) - - def split_big_batch(raw_batch_size: int) -> List: - input_ids = big_batch['input_ids'].split(raw_batch_size) - batches = [{'input_ids': x} for x in input_ids] - - for key in big_batch.keys(): - if key == 'input_ids': - continue - for idx, split in enumerate(big_batch[key].split(raw_batch_size)): - batches[idx].update({key: split}) - return batches - - def profile_packing(raw_batch_size: int) -> Tuple[float, float]: - packer = BinPackWrapper( - collator=lambda x: x, - target_batch_size=device_batch_size, - max_seq_len=dataloader_cfg.dataset.max_seq_len, - pad_token_id=0, # <-- Doesn't need to be correct for profiling - padding_side='left', # <-- Doesn't need to be correct for profiling - max_leftover_bins_to_keep=max_leftovers_to_keep) - - # Simulate feeding the packing collator a bunch of data - for batch in split_big_batch(raw_batch_size): - if batch['input_ids'].shape[0] < device_batch_size: - continue - _ = packer(batch) - - # Return the padding / waste stats over that bunch of data - padding_percent = 100 * (1 - packer.efficiency) - waste_percent = 100 * packer.waste - return padding_percent, waste_percent + results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max, + args.num_packing_ratios, device_batch_size) header = '\n\n\n packing_ratio | % PADDING | % WASTE' fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%' print(header) print('-' * len(header)) - for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes): - padding, waste = profile_packing(raw_batch_size) + for packing_ratio, padding, waste in results: print(fstr.format(packing_ratio, padding, waste)) diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index ae8f57abb6..93d46f57e3 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -56,7 +56,10 @@ parameters: allow_pad_trimming: false decoder_only_format: true shuffle: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/misc/profile_packing.py b/scripts/misc/profile_packing.py new file mode 100644 index 0000000000..51841d669e --- /dev/null +++ b/scripts/misc/profile_packing.py @@ -0,0 +1,100 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to profile example packing.""" +import os +from typing import Dict + +from llmfoundry.data.packing import profile_packing + +if __name__ == '__main__': + from argparse import ArgumentParser, Namespace + + from omegaconf import OmegaConf as om + + from llmfoundry.utils import build_tokenizer + + def parse_args() -> Namespace: + """Parse commandline arguments.""" + parser = ArgumentParser( + description= + 'Profile packing_ratio choices for a particular workload.') + parser.add_argument( + '--yaml-path', + type=str, + required=True, + help='Path to the YAML that defines the workload to profile.') + parser.add_argument('--num-devices', + type=int, + default=None, + help='How many devices your run will use.') + parser.add_argument('--min', + type=float, + required=True, + help='Smallest packing_ratio to test. Must be >=1.') + parser.add_argument( + '--max', + type=float, + required=True, + help='Largest packing_ratio to test. Must be larger than `min`.') + parser.add_argument( + '--num-packing-ratios', + type=int, + default=20, + help= + 'Number of packing_ratio values (spaced between `min` and `max) to try.' + ) + + args = parser.parse_args() + + if not os.path.isfile(args.yaml_path): + raise FileNotFoundError( + '`yaml_path` does not correspond to any existing file.') + if args.num_devices < 1: + raise ValueError('`num_devices` must be a positive integer.') + if args.min < 1.0: + raise ValueError('`min` must be >=1.0.') + if args.max < args.min: + raise ValueError('`max` cannot be less than `min`.') + if args.num_packing_ratios < 1: + raise ValueError('`num_packing_ratios` must be a positive integer.') + return args + + args = parse_args() + + with open(args.yaml_path) as f: + cfg = om.load(f) + if 'parameters' in cfg: + cfg = om.to_container(cfg.parameters) + cfg = om.create(cfg) + device_batch_size = cfg.global_train_batch_size // args.num_devices + + # Fetch a bunch of raw examples once, which we'll re-use + if 'train_loader' not in cfg: + raise ValueError('config must define train_loader') + dataloader_cfg = cfg.train_loader + + # build tokenizer + if 'tokenizer' not in cfg: + raise ValueError('config must define tokenizer') + + resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True) + if not isinstance(resolved_tokenizer_cfg, Dict): + raise ValueError( + 'tokenizer config needs to be resolved by omegaconf into a Dict.') + tokenizer_cfg = resolved_tokenizer_cfg + + tokenizer_name = tokenizer_cfg['name'] + tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) + tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + + results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max, + args.num_packing_ratios, device_batch_size) + + header = '\n\n\n packing_ratio | % PADDING | % WASTE' + fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%' + + print(header) + print('-' * len(header)) + for packing_ratio, padding, waste in results: + print(fstr.format(packing_ratio, padding, waste)) diff --git a/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml b/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml index 2c3fb11496..ed2e9fcac0 100644 --- a/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml +++ b/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml @@ -41,7 +41,10 @@ train_loader: shuffle: true max_seq_len: ${max_seq_len} decoder_only_format: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/train/train.py b/scripts/train/train.py index e29f2c9a47..60ee55955e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -24,9 +24,8 @@ from transformers import PreTrainedTokenizerBase from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, - MPTForCausalLM, build_finetuning_dataloader, - build_text_denoising_dataloader) -from llmfoundry.data.text_data import build_text_dataloader + MPTForCausalLM) +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.utils.builders import (build_algorithm, build_callback, build_icl_data_and_gauntlet, build_logger, build_optimizer, @@ -169,30 +168,6 @@ def print_trainable_parameters(model: torch.nn.Module) -> None: ) -def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int): - if cfg.name == 'text': - return build_text_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'text_denoising': - return build_text_denoising_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'finetuning': - return build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - else: - raise ValueError(f'Not sure how to build dataloader with config: {cfg}') - - def main(cfg: DictConfig) -> Trainer: # Filter deprecation warning from torch internal usage warnings.filterwarnings( diff --git a/scripts/train/yamls/finetune/1b_local_data_sft.yaml b/scripts/train/yamls/finetune/1b_local_data_sft.yaml index 45dca2f1e0..d6f72b0c8e 100644 --- a/scripts/train/yamls/finetune/1b_local_data_sft.yaml +++ b/scripts/train/yamls/finetune/1b_local_data_sft.yaml @@ -49,7 +49,10 @@ train_loader: &train_loader allow_pad_trimming: false decoder_only_format: true shuffle: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/train/yamls/finetune/7b_dolly_sft.yaml b/scripts/train/yamls/finetune/7b_dolly_sft.yaml index 6483dd31f5..c5813235d9 100644 --- a/scripts/train/yamls/finetune/7b_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/7b_dolly_sft.yaml @@ -41,7 +41,10 @@ train_loader: allow_pad_trimming: false decoder_only_format: true shuffle: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml b/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml index 9686317bef..2f23d8e55a 100644 --- a/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml @@ -31,7 +31,10 @@ train_loader: max_seq_len: ${max_seq_len} allow_pad_trimming: false decoder_only_format: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 656b6d52a6..2080ec32ec 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -8,7 +8,7 @@ import sys import tempfile from argparse import Namespace -from typing import Optional +from typing import Literal, Optional, Union from unittest.mock import MagicMock import pytest @@ -248,10 +248,11 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool, @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('allow_pad_trimming', [True, False]) -@pytest.mark.parametrize('packing_ratio', [10.0, None]) +@pytest.mark.parametrize('packing_ratio', [10.0, None, 'auto']) def test_finetuning_dataloader(decoder_only_format: bool, allow_pad_trimming: bool, - packing_ratio: Optional[float]): + packing_ratio: Optional[Union[float, + Literal['auto']]]): # Use the datasets just built in the last test tokenizer_name = 'gpt2' if decoder_only_format else 't5-base' max_seq_len = 2048 if decoder_only_format else 1024 diff --git a/tests/test_packing.py b/tests/test_packing.py new file mode 100644 index 0000000000..cbeca8b7b1 --- /dev/null +++ b/tests/test_packing.py @@ -0,0 +1,191 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List +from unittest.mock import Mock, patch + +import pytest +import torch +from composer.utils import dist, reproducibility +from omegaconf import DictConfig +from pytest import approx +from torch.utils.data import DataLoader + +from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio +from llmfoundry.utils.builders import build_tokenizer + + +def _data_to_batch(data: List[List[int]], max_seq_len: int, + pad_token_id: int) -> Dict[str, torch.Tensor]: + """Helper function to create a proper batch of data.""" + input_ids = torch.stack([ + torch.tensor(d + [pad_token_id] * (max_seq_len - len(d))) for d in data + ]) + + attention_mask = torch.stack([ + torch.tensor([1] * len(d) + [pad_token_id] * (max_seq_len - len(d))) + for d in data + ]) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def test_packing(): + """Tests that packing works for a single batch.""" + pad_token_id = 0 + max_seq_len = 5 + packer = BinPackCollator(collator=lambda x: x, + target_batch_size=2, + max_seq_len=max_seq_len, + pad_token_id=pad_token_id, + padding_side='right') + + batch = _data_to_batch([ + [1], + [2] * 2, + [4] * 4, + [3] * 3, + ], max_seq_len, pad_token_id) + + packed_samples = packer.pack(batch) + + assert torch.equal(packed_samples['input_ids'], + torch.Tensor([[3, 3, 3, 2, 2], [4, 4, 4, 4, 1]])) + assert torch.all(packed_samples['attention_mask'] == 1) + + +def test_packing_with_leftovers(): + """Tests that packing handles leftovers and computes waste correctly.""" + pad_token_id = 0 + max_seq_len = 5 + packer = BinPackCollator(collator=lambda x: x, + target_batch_size=2, + max_seq_len=max_seq_len, + pad_token_id=pad_token_id, + padding_side='right') + + batch = _data_to_batch([ + [1], + [2] * 2, + [4] * 4, + [4] * 4, + ], max_seq_len, pad_token_id) + + packed_batch = packer.pack(batch) + + assert torch.equal(packed_batch['input_ids'], + torch.Tensor([[4, 4, 4, 4, 1], [4, 4, 4, 4, 0]])) + assert torch.equal(packed_batch['attention_mask'], + torch.Tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) + + # Check leftovers and waste. + assert len(packer._leftover_bins) == 1 + leftover_size, leftover = packer._leftover_bins[0] + assert leftover_size == 2 + assert torch.equal(leftover['input_ids'], torch.Tensor([2, 2])) + assert torch.equal(leftover['attention_mask'], torch.Tensor([1, 1])) + assert packer.waste == approx(2 / 11) # 2 tokens wasted of 11 tokens total + + # Ensure that leftovers are used in the next batch if possible. + batch = _data_to_batch([[1]], max_seq_len, pad_token_id) + packed_batch = packer.pack(batch) + assert torch.equal(packed_batch['input_ids'], + torch.Tensor([[2, 2, 0, 0, 0], [1, 0, 0, 0, 0]])) + assert torch.equal(packed_batch['attention_mask'], + torch.Tensor([[1, 1, 0, 0, 0], [1, 0, 0, 0, 0]])) + + +@patch('llmfoundry.data.packing.profile_packing') +def test_auto_packing(profile_packing: Mock): + """Tests that auto packing selects the highest packing ratio with zero. + + waste. + """ + # List of tuples of packing_ratio, padding, waste, sorted by packing ratio + profile_packing.return_value = [(1, .9, 0), (2, .8, 0), (3, .7, .5)] + + packing_ratio = auto_packing_ratio( + dataloader_cfg=DictConfig({'dataset': { + 'max_seq_len': 2048 + }}), + tokenizer=None, + device_batch_size=1, + ) # Dummy values, profiling results are already set. + + # auto packing ratio should choose 2 because packing ratio is maximized while waste is 0. + assert packing_ratio == 2 + + +@pytest.mark.world_size(2) +@pytest.mark.gpu +@patch('llmfoundry.data.packing.profile_packing') +def test_dist_auto_packing(profile_packing: Mock): + """Tests that auto packing works with world size > 1.""" + dist.initialize_dist('gpu') + + # List of tuples of packing_ratio, padding, waste, sorted by packing ratio + if dist.get_global_rank() == 0: + profile_packing.return_value = [(1, .9, 0), (2, .8, 0), + (3, .7, 0)] # should pick 3 + else: + profile_packing.return_value = [(1, .9, 0), (2, .8, 0), + (3, .7, .5)] # should pick 2 + + packing_ratio = auto_packing_ratio( + dataloader_cfg=DictConfig({'dataset': { + 'max_seq_len': 2048 + }}), + tokenizer=None, + device_batch_size=1, + ) # Dummy values, profiling results are already set. + + # auto packing ratio should choose 2 because it's the minimum between ranks. + assert packing_ratio == 2 + + +@pytest.mark.parametrize('packing_ratio', ['auto', 2.0]) +def test_packing_with_dataloader(packing_ratio: Any): + """Tests that packing works with a dataloader.""" + reproducibility.seed_all(17) + tokenizer = build_tokenizer('gpt2', {}) + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'hf_name': 'tatsu-lab/alpaca', + 'split': 'train', + 'max_seq_len': 2048, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': packing_ratio, + 'shuffle': False, + }, + 'drop_last': False, + # Need to test with 0 num_workers because the packing collator object + # Gets copied per worker and we cannot check the waste for child processes. + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + }) + + loader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size=6).dataloader + + assert isinstance(loader, DataLoader) + pack_collator = loader.collate_fn + assert isinstance(pack_collator, BinPackCollator) + + batch_ix = 0 + for _ in loader: + batch_ix += 1 + if batch_ix >= 3: + break + + padding = (1 - pack_collator.efficiency) + if packing_ratio == 'auto': + assert pack_collator.waste == approx(0) + assert padding == approx(0.1197916, rel=.01) + else: + assert pack_collator.waste == approx(0) + assert padding == approx(0.873720, rel=.01)