From 6c412412da951b22bfaabb1d0f0333bc4b31a6d0 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 2 Nov 2023 15:38:50 -0700 Subject: [PATCH 01/15] Fix HF local module copy contention with a meta init on local rank 0 (#710) --- llmfoundry/models/hf/hf_causal_lm.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index eb90b07045..d52633a09b 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -5,6 +5,7 @@ import logging import os +import warnings from typing import Mapping, Union # required for loading a python model into composer @@ -157,6 +158,24 @@ def __init__(self, om_model_config: Union[DictConfig, if dist.get_local_rank() != 0 and init_device == 'mixed': om_model_config.pretrained = False + # If the HuggingFace model is coming from a local folder, Hugging Face copies the modules into the + # transformers modules cache. On particular systems, this operation seems to cause contention between + # the different processes. To avoid this contention, we first create the model (on meta device) on local rank + # zero. This will set up the transformers model cache and avoid the future contention. + if dist.get_local_rank() == 0 and os.path.isdir( + om_model_config.pretrained_model_name_or_path): + with init_empty_weights(include_buffers=False): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + AutoModelForCausalLM.from_pretrained( + om_model_config.pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + config=config, + ) + + dist.barrier() + # initialize the model on the correct device if resolved_init_device == 'cpu': if om_model_config.pretrained: From ca8e6b5cbb5da78d688ca1862e69f4dc948d866f Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Sat, 4 Nov 2023 19:40:53 -0700 Subject: [PATCH 02/15] Add support for auto packing ratio (#683) --- llmfoundry/data/__init__.py | 2 + llmfoundry/data/dataloader.py | 44 +++ llmfoundry/data/denoising.py | 16 +- llmfoundry/data/finetuning/dataloader.py | 50 ++-- llmfoundry/data/packing.py | 277 ++++++++++++------ mcli/mcli-llama2-finetune.yaml | 5 +- scripts/misc/profile_packing.py | 100 +++++++ .../mpt-7b-arc-easy--gpu.yaml | 5 +- scripts/train/train.py | 29 +- .../yamls/finetune/1b_local_data_sft.yaml | 5 +- .../train/yamls/finetune/7b_dolly_sft.yaml | 5 +- .../yamls/finetune/mpt-7b_dolly_sft.yaml | 5 +- tests/test_dataloader.py | 7 +- tests/test_packing.py | 191 ++++++++++++ 14 files changed, 587 insertions(+), 154 deletions(-) create mode 100644 llmfoundry/data/dataloader.py create mode 100644 scripts/misc/profile_packing.py create mode 100644 tests/test_packing.py diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py index c997c865dd..8da436b9b1 100644 --- a/llmfoundry/data/__init__.py +++ b/llmfoundry/data/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.data.denoising import (MixtureOfDenoisersCollator, build_text_denoising_dataloader) from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator, @@ -18,4 +19,5 @@ 'build_text_dataloader', 'NoConcatDataset', 'ConcatTokensDataset', + 'build_dataloader', ] diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py new file mode 100644 index 
0000000000..12741717be --- /dev/null +++ b/llmfoundry/data/dataloader.py @@ -0,0 +1,44 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Dataloader builder utilities.""" + +from composer import DataSpec +from omegaconf import DictConfig +from transformers import PreTrainedTokenizerBase + +from llmfoundry.data.denoising import build_text_denoising_dataloader +from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.text_data import build_text_dataloader + + +def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + device_batch_size: int) -> DataSpec: + """Builds a dataloader from a config. + + Args: + cfg (DictConfig): An omegaconf dictionary used to configure the loader. + tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use. + device_batch_size (int): The size of the batches (number of examples) + that the dataloader will produce. + """ + if cfg.name == 'text': + return build_text_dataloader( + cfg, + tokenizer, + device_batch_size, + ) + elif cfg.name == 'text_denoising': + return build_text_denoising_dataloader( + cfg, + tokenizer, + device_batch_size, + ) + elif cfg.name == 'finetuning': + return build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ) + else: + raise ValueError(f'Not sure how to build dataloader with config: {cfg}') diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index bc41945076..7d497b4efd 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from llmfoundry.data.packing import BinPackWrapper +from llmfoundry.data.packing import BinPackCollator from llmfoundry.data.text_data import (StreamingTextDataset, get_tokens_per_batch_func) from llmfoundry.models import utils @@ -375,19 +375,25 @@ def build_text_denoising_dataloader( cfg.dataset.max_seq_len (int): The maximum length of sequences in the batch. See :class:`MixtureOfDenoisersCollator` docstring for details. - cfg.dataset.packing_ratio (float, optional): If provided, this invokes + cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes a collator wrapper that packs device_batch_size*packing_ratio raw examples into device_batch_size packed examples. This helps minimize padding while preserving sequence integrity. This adds `sequence_id` to the batch, which indicates which unique sequence each token belongs to. + + If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with + zero waste is selected. + In practice, this may result in > 0 waste because profiling is done on only a portion + of the dataset. + Note: Using this feature will not change device_batch_size but it will determine the number of raw examples consumed by the dataloader per batch. Some examples may be discarded if they do not fit when packing. Select packing_ratio **carefully** based on the dataset statistics, max_seq_len, and tolerance for discarding samples! - The packing code in `./packing.py` provides a script that can help + The script `scripts/misc/profile_packing.py` can help you choose the best packing_ratio. See :class:`StreamingTextDataset` for info on other standard config options within `cfg.dataset`. @@ -419,7 +425,7 @@ def build_text_denoising_dataloader( that the dataloader will produce. 
Note: - You can run the script inside `./packing.py` to quickly test the + You can use the script `scripts/misc/profile_packing.py` to quickly test the padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. """ @@ -492,7 +498,7 @@ def build_text_denoising_dataloader( raise NotImplementedError( 'On-the-fly packing is currently only supported for decoder-only formats.' ) - collate_fn = BinPackWrapper( + collate_fn = BinPackCollator( collator=collate_fn, target_batch_size=device_batch_size, max_seq_len=cfg.dataset.max_seq_len, diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 2dde563ac6..6e988ac149 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -14,7 +14,7 @@ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.tasks import dataset_constructor -from llmfoundry.data.packing import BinPackWrapper +from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import get_tokens_per_batch_func log = logging.getLogger(__name__) @@ -74,20 +74,26 @@ def build_finetuning_dataloader(cfg: DictConfig, cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow the collator to trim padding. See :class:`Seq2SeqFinetuningCollator` docstring for details. Default: ``False``. - cfg.dataset.packing_ratio (float, optional): If provided, this invokes - a collator wrapper that packs `device_batch_size*packing_ratio` - raw examples into `device_batch_size` packed examples. This helps + cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes + a collator wrapper that packs device_batch_size*packing_ratio + raw examples into device_batch_size packed examples. This helps minimize padding while preserving sequence integrity. This adds `sequence_id` to the batch, which indicates which unique sequence each token belongs to. + + If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with + zero waste is selected. + In practice, this may result in > 0 waste because profiling is done on only a portion + of the dataset. + Note: Using this feature will not change device_batch_size but it will determine the number of raw examples consumed by the dataloader per batch. Some examples may be discarded if they do not fit when packing. - Select `packing_ratio` **carefully** based on the dataset - statistics, `max_seq_len`, and tolerance for discarding samples! - The packing code in `../packing.py` provides a script that can help - you choose the best `packing_ratio`. + Select packing_ratio **carefully** based on the dataset + statistics, max_seq_len, and tolerance for discarding samples! + The script `scripts/misc/profile_packing.py` can help + you choose the best packing_ratio. cfg.dataset.shuffle (bool): Whether to shuffle the dataset. ___ See :class:`StreamingFinetuningDataset` for info on other standard config @@ -106,7 +112,7 @@ def build_finetuning_dataloader(cfg: DictConfig, A pytorch dataloader Note: - You can run the script inside `../packing.py` to quickly test the + You can run the script inside `scripts/misc/profile_packing.py` to quickly test the padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. 
""" @@ -143,7 +149,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) collate_fn, dataloader_batch_size = _build_collate_fn( - cfg.dataset, tokenizer, device_batch_size) + cfg, tokenizer, device_batch_size) dl = DataLoader( dataset, @@ -174,7 +180,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) collate_fn, dataloader_batch_size = _build_collate_fn( - cfg.dataset, tokenizer, device_batch_size) + cfg, tokenizer, device_batch_size) if cfg.drop_last: world_size = dist.get_world_size() @@ -367,25 +373,33 @@ def _build_hf_dataset_from_remote( def _build_collate_fn( - dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int -) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]: +) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]: + dataset_cfg = dataloader_cfg.dataset + max_seq_len = dataset_cfg.max_seq_len + collate_fn = Seq2SeqFinetuningCollator( tokenizer=tokenizer, - max_seq_len=dataset_cfg.max_seq_len, + max_seq_len=max_seq_len, decoder_only_format=dataset_cfg.decoder_only_format, allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False), ) packing_ratio = dataset_cfg.get('packing_ratio') + max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep') if packing_ratio is None: - if dataset_cfg.get('max_leftover_bins_to_keep') is not None: + if max_leftover_bins_to_keep is not None: raise ValueError( 'dataset.max_leftover_bins_to_keep has been defined, ' +\ 'but dataset.packing_ratio has not been set. Please set ' +\ 'the latter to turn on packing or remove the former from the config.') return collate_fn, device_batch_size + if packing_ratio == 'auto': + packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, + device_batch_size) + if packing_ratio == 1.0: return collate_fn, device_batch_size elif packing_ratio < 1.0: @@ -396,13 +410,13 @@ def _build_collate_fn( 'On-the-fly packing is currently only supported for decoder-only formats.' 
) - collate_fn = BinPackWrapper( + collate_fn = BinPackCollator( collator=collate_fn, target_batch_size=device_batch_size, - max_seq_len=dataset_cfg.max_seq_len, + max_seq_len=max_seq_len, pad_token_id=tokenizer.pad_token_id, padding_side=tokenizer.padding_side, - max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'), + max_leftover_bins_to_keep=max_leftover_bins_to_keep, ) n_examples_to_pack = int(device_batch_size * packing_ratio) return collate_fn, n_examples_to_pack diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 1532de276e..1ae9efcce5 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -1,8 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import os -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple import numpy as np import torch @@ -10,7 +9,7 @@ from transformers import PreTrainedTokenizerBase -class BinPackWrapper: +class BinPackCollator: """Utility collator for packing to reduce padding.""" def __init__(self, @@ -33,13 +32,10 @@ def __init__(self, if self.pad_token_id < 0: raise ValueError(f'{pad_token_id=} must be >=0.') - if max_leftover_bins_to_keep is None: - self.max_leftover_bins_to_keep = int(10 * self.out_size) - elif max_leftover_bins_to_keep < 0: + if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0: raise ValueError( f'{max_leftover_bins_to_keep=} must be >=0 or None.') - else: - self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep) + self.max_leftover_bins_to_keep = max_leftover_bins_to_keep self.n_packed_tokens = 0 self.n_total_tokens = 0 @@ -60,7 +56,9 @@ def __call__( self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: batch = self.base_collator(examples) + return self.pack(batch) + def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: assert 'attention_mask' in batch assert 'input_ids' in batch @@ -75,12 +73,12 @@ def __call__( # Cut everything down to size sizes, trimmed_examples = [], [] for idx in range(batch['attention_mask'].shape[0]): - size, trimmed_example = extract_trim_batch_idx(batch, idx) + size, trimmed_example = _extract_trim_batch_idx(batch, idx) sizes.append(size) trimmed_examples.append(trimmed_example) # Apply our CS 101 bin packing algorithm. 
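        # A small worked illustration (made-up sizes): with max_bin_size=5 and
        # trimmed example sizes [4, 3, 2, 1], first-fit places each example into
        # the first existing bin with enough room and opens a new bin otherwise,
        # yielding bins of sizes 4+1 and 3+2 with no padding. Bins beyond
        # self.out_size are returned as leftovers and re-offered on later calls.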
- packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing( + packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing( sizes=sizes, examples=trimmed_examples, num_bins=self.out_size, @@ -93,15 +91,15 @@ def __call__( self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep] # Re-pad to max_seq_len and batch - batch = repad(packed_examples, - max_seq_len=self.max_seq_len, - pad_token_id=self.pad_token_id, - padding_side=self.padding_side) + batch = _repad(packed_examples, + max_seq_len=self.max_seq_len, + pad_token_id=self.pad_token_id, + padding_side=self.padding_side) return batch -def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], - idx: int) -> Tuple[int, Dict[str, torch.Tensor]]: +def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], + idx: int) -> Tuple[int, Dict[str, torch.Tensor]]: example = {k: v[idx] for k, v in batch.items()} keep = example['attention_mask'] == 1 @@ -112,7 +110,7 @@ def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], return size, trim_example -def combine_in_place( +def _combine_in_place( example: Dict[str, torch.Tensor], add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if 'labels' in add_on: @@ -129,7 +127,7 @@ def combine_in_place( return example -def first_fit_bin_packing( +def _first_fit_bin_packing( sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]] ) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[ @@ -194,7 +192,7 @@ def first_fit_bin_packing( if bins[bidx][0] + size <= max_bin_size: bin_size, packed_example = bins.pop(bidx) bin_size = bin_size + size - packed_example = combine_in_place(packed_example, example) + packed_example = _combine_in_place(packed_example, example) bins.append((bin_size, packed_example)) added = True break @@ -225,8 +223,8 @@ def first_fit_bin_packing( bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:] -def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, - pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]: +def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, + pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]: def pad_tensor(tensor: torch.Tensor, pad_value: int): if len(tensor) == max_seq_len: @@ -260,14 +258,168 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int): return batch +def auto_packing_ratio(dataloader_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + device_batch_size: int, + num_packing_ratios: int = 20) -> float: + """Find a packing ratio that minimizes padding with zero waste. + + By packing examples, we can increase training efficiency, training on more data with less batches. + However, in practice, the selected packing_ratio may produce some waste because profiling is done on only + a subset of the dataset. + + We select a min_ratio of 1 and a max_ratio that is the max_seq_len / 100, and profile up to + num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive. + When a packing_ratio with non-zero waste is found, we stop and select the previous ratio, + which has zero waste. + + Args: + dataloader_cfg (DictConfig): The dataloader configuration for profiling. + tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. + device_batch_size (int): The size of the batches (number of examples) per device. + num_packing_ratio (int): The number of packing ratios to try. 
+ + Returns: + A packing ratio that minimizes padding while maintaining zero waste. + """ + from composer.utils import dist, get_device, reproducibility + + # Stash the rng state to restore later. + rng_state = reproducibility.get_rng_state() + # Set the seed so that auto packing is deterministic. + reproducibility.seed_all(0) + + min_ratio = 1 + max_ratio = dataloader_cfg.dataset.max_seq_len / 100 + profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, + max_ratio, num_packing_ratios, + device_batch_size) + + # Obtain the maximum packing_ratio/minimum padding that has no waste. + # profiling_results are sorted from smallest to largest packing_ratio. + packing_ratio = 1 + for packing_ratio_candidate, _, waste in profiling_results: + if waste > 0: + break + packing_ratio = packing_ratio_candidate + + # Select the minimum packing ratio across all ranks. + if dist.is_available() and dist.is_initialized(): + device = get_device(None) + packing_ratio_tensor = device.tensor_to_device( + torch.tensor(packing_ratio)) + dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN') + packing_ratio = packing_ratio_tensor.item() + + # Restore rng state. + reproducibility.load_rng_state(rng_state) + + return packing_ratio + + +def profile_packing( + dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + min_ratio: float, max_ratio: float, num_packing_ratios: int, + device_batch_size: int) -> Iterable[Tuple[float, float, float]]: + """Generator function that profiles example packing across packing ratios. + + Args: + dataloader_cfg (DictConfig): The dataloader configuration for profiling. + tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. + min_ratio (float): Smallest packing_ratio to test. Must be >=1. + max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`. + num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try. + device_batch_size (int): The size of the batches (number of examples) per device. + + Returns: + An iterable of tuples of packing ratio, padding, and waste, sorted by smallest to largest packing ratio. 
+ """ + import copy + + from llmfoundry.data.dataloader import build_dataloader + + max_seq_len = dataloader_cfg.dataset.get('max_seq_len') + max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', + None) + + # Turn off packing for the dataloader (we want raw, pre-packed examples) + dataloader_cfg = copy.deepcopy(dataloader_cfg) + dataloader_cfg.dataset.packing_ratio = None + dataloader_cfg.drop_last = False + dataloader_cfg.num_workers = 0 + dataloader_cfg.prefetch_factor = None + + # Determine the packing_ratio values we'll try + packing_ratios, raw_batch_sizes = [], [] + for packing_ratio in np.linspace(min_ratio, + max_ratio, + num_packing_ratios, + endpoint=True): + packing_ratio = np.round(10 * packing_ratio) / 10 + raw_batch_size = int(packing_ratio * device_batch_size) + if raw_batch_size not in raw_batch_sizes: + packing_ratios.append(packing_ratio) + raw_batch_sizes.append(raw_batch_size) + + n_profile_examples = max(raw_batch_sizes) * 100 + + train_dataspec = build_dataloader(dataloader_cfg, tokenizer, + n_profile_examples) + train_dataloader = train_dataspec.dataloader + + # Get a bunch of raw examples + big_batch = next(iter(train_dataloader)) + + def split_big_batch(raw_batch_size: int) -> List: + input_ids = big_batch['input_ids'].split(raw_batch_size) + batches = [{'input_ids': x} for x in input_ids] + + for key in big_batch.keys(): + if key == 'input_ids': + continue + for idx, split in enumerate(big_batch[key].split(raw_batch_size)): + batches[idx].update({key: split}) + return batches + + def profile(raw_batch_size: int) -> Tuple[float, float]: + packer = BinPackCollator( + collator=lambda x: x, + target_batch_size=device_batch_size, + max_seq_len=max_seq_len, + pad_token_id=0, # <-- Doesn't need to be correct for profiling + padding_side='left', # <-- Doesn't need to be correct for profiling + max_leftover_bins_to_keep=max_leftovers_to_keep) + + # Simulate feeding the packing collator a bunch of data + for batch in split_big_batch(raw_batch_size): + if batch['input_ids'].shape[0] < device_batch_size: + continue + _ = packer.pack(batch) + + # Return the padding / waste stats over that bunch of data + padding_percent = 100 * (1 - packer.efficiency) + waste_percent = 100 * packer.waste + return padding_percent, waste_percent + + for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes): + padding, waste = profile(raw_batch_size) + yield (packing_ratio, padding, waste) + + if __name__ == '__main__': + + import warnings + + warnings.warn( + DeprecationWarning( + 'Please use scripts/misc/profile_packing.py to profile packing.' + + 'This script will be removed in later releases.')) + + import os from argparse import ArgumentParser, Namespace from omegaconf import OmegaConf as om - from llmfoundry import (build_finetuning_dataloader, - build_text_denoising_dataloader) - from llmfoundry.data import build_text_dataloader from llmfoundry.utils import build_tokenizer def parse_args() -> Namespace: @@ -296,7 +448,7 @@ def parse_args() -> Namespace: parser.add_argument( '--num-packing-ratios', type=int, - default=10, + default=20, help= 'Number of packing_ratio values (spaced between `min` and `max) to try.' 
) @@ -316,20 +468,6 @@ def parse_args() -> Namespace: raise ValueError('`num_packing_ratios` must be a positive integer.') return args - def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int): - if cfg.name == 'text': - return build_text_dataloader(cfg, tokenizer, device_batch_size) - elif cfg.name == 'text_denoising': - return build_text_denoising_dataloader(cfg, tokenizer, - device_batch_size) - elif cfg.name == 'finetuning': - return build_finetuning_dataloader(cfg, tokenizer, - device_batch_size) - else: - raise ValueError( - f'Not sure how to build dataloader with config: {cfg}') - args = parse_args() with open(args.yaml_path) as f: @@ -339,26 +477,11 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, cfg = om.create(cfg) device_batch_size = cfg.global_train_batch_size // args.num_devices - # Determine the packing_ratio values we'll try - packing_ratios, raw_batch_sizes = [], [] - for packing_ratio in np.linspace(args.min, - args.max, - args.num_packing_ratios, - endpoint=True): - packing_ratio = np.round(10 * packing_ratio) / 10 - raw_batch_size = int(packing_ratio * device_batch_size) - if raw_batch_size not in raw_batch_sizes: - packing_ratios.append(packing_ratio) - raw_batch_sizes.append(raw_batch_size) - # Fetch a bunch of raw examples once, which we'll re-use if 'train_loader' not in cfg: raise ValueError('config must define train_loader') dataloader_cfg = cfg.train_loader - max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', - None) - # build tokenizer if 'tokenizer' not in cfg: raise ValueError('config must define tokenizer') @@ -367,57 +490,19 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, if not isinstance(resolved_tokenizer_cfg, Dict): raise ValueError( 'tokenizer config needs to be resolved by omegaconf into a Dict.') - tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg + tokenizer_cfg = resolved_tokenizer_cfg tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - # Turn off packing for the dataloader (we want raw, pre-packed examples) - dataloader_cfg.dataset.packing_ratio = None - dataloader_cfg.dataset.max_leftovers_to_keep = None - train_dataloader = build_dataloader(dataloader_cfg, tokenizer, - max(raw_batch_sizes) * 100).dataloader - - # Get a bunch of raw examples - big_batch = next(iter(train_dataloader)) - - def split_big_batch(raw_batch_size: int) -> List: - input_ids = big_batch['input_ids'].split(raw_batch_size) - batches = [{'input_ids': x} for x in input_ids] - - for key in big_batch.keys(): - if key == 'input_ids': - continue - for idx, split in enumerate(big_batch[key].split(raw_batch_size)): - batches[idx].update({key: split}) - return batches - - def profile_packing(raw_batch_size: int) -> Tuple[float, float]: - packer = BinPackWrapper( - collator=lambda x: x, - target_batch_size=device_batch_size, - max_seq_len=dataloader_cfg.dataset.max_seq_len, - pad_token_id=0, # <-- Doesn't need to be correct for profiling - padding_side='left', # <-- Doesn't need to be correct for profiling - max_leftover_bins_to_keep=max_leftovers_to_keep) - - # Simulate feeding the packing collator a bunch of data - for batch in split_big_batch(raw_batch_size): - if batch['input_ids'].shape[0] < device_batch_size: - continue - _ = packer(batch) - - # Return the padding / waste stats over that bunch of data - padding_percent = 100 * (1 - packer.efficiency) 
- waste_percent = 100 * packer.waste - return padding_percent, waste_percent + results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max, + args.num_packing_ratios, device_batch_size) header = '\n\n\n packing_ratio | % PADDING | % WASTE' fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%' print(header) print('-' * len(header)) - for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes): - padding, waste = profile_packing(raw_batch_size) + for packing_ratio, padding, waste in results: print(fstr.format(packing_ratio, padding, waste)) diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index ae8f57abb6..93d46f57e3 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -56,7 +56,10 @@ parameters: allow_pad_trimming: false decoder_only_format: true shuffle: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/misc/profile_packing.py b/scripts/misc/profile_packing.py new file mode 100644 index 0000000000..51841d669e --- /dev/null +++ b/scripts/misc/profile_packing.py @@ -0,0 +1,100 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to profile example packing.""" +import os +from typing import Dict + +from llmfoundry.data.packing import profile_packing + +if __name__ == '__main__': + from argparse import ArgumentParser, Namespace + + from omegaconf import OmegaConf as om + + from llmfoundry.utils import build_tokenizer + + def parse_args() -> Namespace: + """Parse commandline arguments.""" + parser = ArgumentParser( + description= + 'Profile packing_ratio choices for a particular workload.') + parser.add_argument( + '--yaml-path', + type=str, + required=True, + help='Path to the YAML that defines the workload to profile.') + parser.add_argument('--num-devices', + type=int, + default=None, + help='How many devices your run will use.') + parser.add_argument('--min', + type=float, + required=True, + help='Smallest packing_ratio to test. Must be >=1.') + parser.add_argument( + '--max', + type=float, + required=True, + help='Largest packing_ratio to test. Must be larger than `min`.') + parser.add_argument( + '--num-packing-ratios', + type=int, + default=20, + help= + 'Number of packing_ratio values (spaced between `min` and `max) to try.' 
+ ) + + args = parser.parse_args() + + if not os.path.isfile(args.yaml_path): + raise FileNotFoundError( + '`yaml_path` does not correspond to any existing file.') + if args.num_devices < 1: + raise ValueError('`num_devices` must be a positive integer.') + if args.min < 1.0: + raise ValueError('`min` must be >=1.0.') + if args.max < args.min: + raise ValueError('`max` cannot be less than `min`.') + if args.num_packing_ratios < 1: + raise ValueError('`num_packing_ratios` must be a positive integer.') + return args + + args = parse_args() + + with open(args.yaml_path) as f: + cfg = om.load(f) + if 'parameters' in cfg: + cfg = om.to_container(cfg.parameters) + cfg = om.create(cfg) + device_batch_size = cfg.global_train_batch_size // args.num_devices + + # Fetch a bunch of raw examples once, which we'll re-use + if 'train_loader' not in cfg: + raise ValueError('config must define train_loader') + dataloader_cfg = cfg.train_loader + + # build tokenizer + if 'tokenizer' not in cfg: + raise ValueError('config must define tokenizer') + + resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True) + if not isinstance(resolved_tokenizer_cfg, Dict): + raise ValueError( + 'tokenizer config needs to be resolved by omegaconf into a Dict.') + tokenizer_cfg = resolved_tokenizer_cfg + + tokenizer_name = tokenizer_cfg['name'] + tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) + tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + + results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max, + args.num_packing_ratios, device_batch_size) + + header = '\n\n\n packing_ratio | % PADDING | % WASTE' + fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%' + + print(header) + print('-' * len(header)) + for packing_ratio, padding, waste in results: + print(fstr.format(packing_ratio, padding, waste)) diff --git a/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml b/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml index 2c3fb11496..ed2e9fcac0 100644 --- a/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml +++ b/scripts/train/finetune_example/mpt-7b-arc-easy--gpu.yaml @@ -41,7 +41,10 @@ train_loader: shuffle: true max_seq_len: ${max_seq_len} decoder_only_format: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. 
+ # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/train/train.py b/scripts/train/train.py index e29f2c9a47..60ee55955e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -24,9 +24,8 @@ from transformers import PreTrainedTokenizerBase from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, - MPTForCausalLM, build_finetuning_dataloader, - build_text_denoising_dataloader) -from llmfoundry.data.text_data import build_text_dataloader + MPTForCausalLM) +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.utils.builders import (build_algorithm, build_callback, build_icl_data_and_gauntlet, build_logger, build_optimizer, @@ -169,30 +168,6 @@ def print_trainable_parameters(model: torch.nn.Module) -> None: ) -def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int): - if cfg.name == 'text': - return build_text_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'text_denoising': - return build_text_denoising_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'finetuning': - return build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - else: - raise ValueError(f'Not sure how to build dataloader with config: {cfg}') - - def main(cfg: DictConfig) -> Trainer: # Filter deprecation warning from torch internal usage warnings.filterwarnings( diff --git a/scripts/train/yamls/finetune/1b_local_data_sft.yaml b/scripts/train/yamls/finetune/1b_local_data_sft.yaml index 45dca2f1e0..d6f72b0c8e 100644 --- a/scripts/train/yamls/finetune/1b_local_data_sft.yaml +++ b/scripts/train/yamls/finetune/1b_local_data_sft.yaml @@ -49,7 +49,10 @@ train_loader: &train_loader allow_pad_trimming: false decoder_only_format: true shuffle: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/train/yamls/finetune/7b_dolly_sft.yaml b/scripts/train/yamls/finetune/7b_dolly_sft.yaml index 6483dd31f5..c5813235d9 100644 --- a/scripts/train/yamls/finetune/7b_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/7b_dolly_sft.yaml @@ -41,7 +41,10 @@ train_loader: allow_pad_trimming: false decoder_only_format: true shuffle: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. 
+ # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml b/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml index 9686317bef..2f23d8e55a 100644 --- a/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml @@ -31,7 +31,10 @@ train_loader: max_seq_len: ${max_seq_len} allow_pad_trimming: false decoder_only_format: true - # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with + # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion + # # of the dataset. + # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...` # # to profile this run's optimal packing_ratio as it depends on GPU count, # # batch size, sequence length # packing_ratio: diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 656b6d52a6..2080ec32ec 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -8,7 +8,7 @@ import sys import tempfile from argparse import Namespace -from typing import Optional +from typing import Literal, Optional, Union from unittest.mock import MagicMock import pytest @@ -248,10 +248,11 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool, @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('allow_pad_trimming', [True, False]) -@pytest.mark.parametrize('packing_ratio', [10.0, None]) +@pytest.mark.parametrize('packing_ratio', [10.0, None, 'auto']) def test_finetuning_dataloader(decoder_only_format: bool, allow_pad_trimming: bool, - packing_ratio: Optional[float]): + packing_ratio: Optional[Union[float, + Literal['auto']]]): # Use the datasets just built in the last test tokenizer_name = 'gpt2' if decoder_only_format else 't5-base' max_seq_len = 2048 if decoder_only_format else 1024 diff --git a/tests/test_packing.py b/tests/test_packing.py new file mode 100644 index 0000000000..cbeca8b7b1 --- /dev/null +++ b/tests/test_packing.py @@ -0,0 +1,191 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List +from unittest.mock import Mock, patch + +import pytest +import torch +from composer.utils import dist, reproducibility +from omegaconf import DictConfig +from pytest import approx +from torch.utils.data import DataLoader + +from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio +from llmfoundry.utils.builders import build_tokenizer + + +def _data_to_batch(data: List[List[int]], max_seq_len: int, + pad_token_id: int) -> Dict[str, torch.Tensor]: + """Helper function to create a proper batch of data.""" + input_ids = torch.stack([ + torch.tensor(d + [pad_token_id] * (max_seq_len - len(d))) for d in data + ]) + + attention_mask = torch.stack([ + torch.tensor([1] * len(d) + [pad_token_id] * (max_seq_len - len(d))) + for d in data + ]) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def test_packing(): + """Tests that packing works for a single batch.""" + pad_token_id = 0 + max_seq_len = 5 + packer = BinPackCollator(collator=lambda x: 
x, + target_batch_size=2, + max_seq_len=max_seq_len, + pad_token_id=pad_token_id, + padding_side='right') + + batch = _data_to_batch([ + [1], + [2] * 2, + [4] * 4, + [3] * 3, + ], max_seq_len, pad_token_id) + + packed_samples = packer.pack(batch) + + assert torch.equal(packed_samples['input_ids'], + torch.Tensor([[3, 3, 3, 2, 2], [4, 4, 4, 4, 1]])) + assert torch.all(packed_samples['attention_mask'] == 1) + + +def test_packing_with_leftovers(): + """Tests that packing handles leftovers and computes waste correctly.""" + pad_token_id = 0 + max_seq_len = 5 + packer = BinPackCollator(collator=lambda x: x, + target_batch_size=2, + max_seq_len=max_seq_len, + pad_token_id=pad_token_id, + padding_side='right') + + batch = _data_to_batch([ + [1], + [2] * 2, + [4] * 4, + [4] * 4, + ], max_seq_len, pad_token_id) + + packed_batch = packer.pack(batch) + + assert torch.equal(packed_batch['input_ids'], + torch.Tensor([[4, 4, 4, 4, 1], [4, 4, 4, 4, 0]])) + assert torch.equal(packed_batch['attention_mask'], + torch.Tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) + + # Check leftovers and waste. + assert len(packer._leftover_bins) == 1 + leftover_size, leftover = packer._leftover_bins[0] + assert leftover_size == 2 + assert torch.equal(leftover['input_ids'], torch.Tensor([2, 2])) + assert torch.equal(leftover['attention_mask'], torch.Tensor([1, 1])) + assert packer.waste == approx(2 / 11) # 2 tokens wasted of 11 tokens total + + # Ensure that leftovers are used in the next batch if possible. + batch = _data_to_batch([[1]], max_seq_len, pad_token_id) + packed_batch = packer.pack(batch) + assert torch.equal(packed_batch['input_ids'], + torch.Tensor([[2, 2, 0, 0, 0], [1, 0, 0, 0, 0]])) + assert torch.equal(packed_batch['attention_mask'], + torch.Tensor([[1, 1, 0, 0, 0], [1, 0, 0, 0, 0]])) + + +@patch('llmfoundry.data.packing.profile_packing') +def test_auto_packing(profile_packing: Mock): + """Tests that auto packing selects the highest packing ratio with zero. + + waste. + """ + # List of tuples of packing_ratio, padding, waste, sorted by packing ratio + profile_packing.return_value = [(1, .9, 0), (2, .8, 0), (3, .7, .5)] + + packing_ratio = auto_packing_ratio( + dataloader_cfg=DictConfig({'dataset': { + 'max_seq_len': 2048 + }}), + tokenizer=None, + device_batch_size=1, + ) # Dummy values, profiling results are already set. + + # auto packing ratio should choose 2 because packing ratio is maximized while waste is 0. + assert packing_ratio == 2 + + +@pytest.mark.world_size(2) +@pytest.mark.gpu +@patch('llmfoundry.data.packing.profile_packing') +def test_dist_auto_packing(profile_packing: Mock): + """Tests that auto packing works with world size > 1.""" + dist.initialize_dist('gpu') + + # List of tuples of packing_ratio, padding, waste, sorted by packing ratio + if dist.get_global_rank() == 0: + profile_packing.return_value = [(1, .9, 0), (2, .8, 0), + (3, .7, 0)] # should pick 3 + else: + profile_packing.return_value = [(1, .9, 0), (2, .8, 0), + (3, .7, .5)] # should pick 2 + + packing_ratio = auto_packing_ratio( + dataloader_cfg=DictConfig({'dataset': { + 'max_seq_len': 2048 + }}), + tokenizer=None, + device_batch_size=1, + ) # Dummy values, profiling results are already set. + + # auto packing ratio should choose 2 because it's the minimum between ranks. 
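    # With the mocked results above, rank 0 observes zero waste up through ratio 3
    # while rank 1 only up through ratio 2; the all_reduce(..., reduce_operation='MIN')
    # inside auto_packing_ratio then lowers every rank to the shared value 2.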
+ assert packing_ratio == 2 + + +@pytest.mark.parametrize('packing_ratio', ['auto', 2.0]) +def test_packing_with_dataloader(packing_ratio: Any): + """Tests that packing works with a dataloader.""" + reproducibility.seed_all(17) + tokenizer = build_tokenizer('gpt2', {}) + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'hf_name': 'tatsu-lab/alpaca', + 'split': 'train', + 'max_seq_len': 2048, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': packing_ratio, + 'shuffle': False, + }, + 'drop_last': False, + # Need to test with 0 num_workers because the packing collator object + # Gets copied per worker and we cannot check the waste for child processes. + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + }) + + loader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size=6).dataloader + + assert isinstance(loader, DataLoader) + pack_collator = loader.collate_fn + assert isinstance(pack_collator, BinPackCollator) + + batch_ix = 0 + for _ in loader: + batch_ix += 1 + if batch_ix >= 3: + break + + padding = (1 - pack_collator.efficiency) + if packing_ratio == 'auto': + assert pack_collator.waste == approx(0) + assert padding == approx(0.1197916, rel=.01) + else: + assert pack_collator.waste == approx(0) + assert padding == approx(0.873720, rel=.01) From be467aee1744566e46e9b993c1ff23ab01fe5c55 Mon Sep 17 00:00:00 2001 From: Theresa Barton Date: Mon, 6 Nov 2023 10:42:26 -0800 Subject: [PATCH 03/15] Remove HumanEval tasks from ICL eval (#715) * add params * fix fsdp config * update * change model config * comment out human eval * remove boolq * actually remove real boolq * unroll boolq * really actully comment out humaneval * remove file * lint fix * lint fix --- scripts/eval/yamls/eval_gauntlet.yaml | 57 ++++++------ scripts/eval/yamls/tasks.yaml | 128 +++++++++++++------------- 2 files changed, 90 insertions(+), 95 deletions(-) diff --git a/scripts/eval/yamls/eval_gauntlet.yaml b/scripts/eval/yamls/eval_gauntlet.yaml index 1d2fa34139..791023abcf 100644 --- a/scripts/eval/yamls/eval_gauntlet.yaml +++ b/scripts/eval/yamls/eval_gauntlet.yaml @@ -133,32 +133,32 @@ eval_gauntlet: - name: boolq num_fewshot: 10 random_baseline: 0.5 - - name: programming - benchmarks: - - name: human_eval - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_cpp - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_js - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_return_simple - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_return_complex - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_25 - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_50 - num_fewshot: 0 - random_baseline: 0.0 - - name: human_eval_75 - num_fewshot: 0 - random_baseline: 0.0 + # - name: programming + # benchmarks: + # - name: human_eval + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_cpp + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_js + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_return_simple + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_return_complex + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_25 + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_50 + # num_fewshot: 0 + # random_baseline: 0.0 + # - name: human_eval_75 + # num_fewshot: 0 + # random_baseline: 0.0 - name: world_knowledge_lm_task_subscore benchmarks: - name: jeopardy @@ -258,8 +258,3 
@@ eval_gauntlet: - name: squad num_fewshot: 10 random_baseline: 0 - - name: programming_lite - benchmarks: - - name: human_eval - num_fewshot: 0 - random_baseline: 0.0 diff --git a/scripts/eval/yamls/tasks.yaml b/scripts/eval/yamls/tasks.yaml index 6b66c116ea..737b08ebeb 100644 --- a/scripts/eval/yamls/tasks.yaml +++ b/scripts/eval/yamls/tasks.yaml @@ -173,67 +173,67 @@ icl_tasks: num_fewshot: [10] icl_task_type: multiple_choice continuation_delimiter: "\nAnswer: " # this separates questions from answers -- - label: human_eval - dataset_uri: eval/local_data/programming/human_eval.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_cpp - dataset_uri: eval/local_data/programming/processed_human_eval_cpp.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_js - dataset_uri: eval/local_data/programming/processed_human_eval_js.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_return_simple - dataset_uri: eval/local_data/programming/human_eval_return_simple.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_return_complex - dataset_uri: eval/local_data/programming/human_eval_return_complex.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_25 - dataset_uri: eval/local_data/programming/human_eval-0.25.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_50 - dataset_uri: eval/local_data/programming/human_eval-0.5.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation -- - label: human_eval_75 - dataset_uri: eval/local_data/programming/human_eval-0.75.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - pass_at_k: 1 - num_beams: 20 - batch_size: 1 - icl_task_type: code_evaluation +# - +# label: human_eval +# dataset_uri: eval/local_data/programming/human_eval.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_cpp +# dataset_uri: eval/local_data/programming/processed_human_eval_cpp.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_js +# dataset_uri: eval/local_data/programming/processed_human_eval_js.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_return_simple +# dataset_uri: eval/local_data/programming/human_eval_return_simple.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_return_complex +# dataset_uri: eval/local_data/programming/human_eval_return_complex.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_25 +# dataset_uri: 
eval/local_data/programming/human_eval-0.25.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_50 +# dataset_uri: eval/local_data/programming/human_eval-0.5.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation +# - +# label: human_eval_75 +# dataset_uri: eval/local_data/programming/human_eval-0.75.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# pass_at_k: 1 +# num_beams: 20 +# batch_size: 1 +# icl_task_type: code_evaluation From ffb58f18db01da470720366c649f7a267b9c27a5 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 6 Nov 2023 13:16:12 -0800 Subject: [PATCH 04/15] Allow logging metadata (#714) * metadata * precommit * add to config for other exp trackers * fix * pop off of config --- scripts/train/train.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/scripts/train/train.py b/scripts/train/train.py index 60ee55955e..88f776375f 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -383,6 +383,12 @@ def main(cfg: DictConfig) -> Trainer: 'compile_config', must_exist=False, default_value=None) + metadata: Optional[Dict[str, str]] = pop_config(cfg, + 'metadata', + must_exist=False, + default_value=None, + convert=True) + # Enable autoresume from model checkpoints if possible autoresume_default: bool = False if logged_cfg.get('run_name', None) is not None \ @@ -460,6 +466,14 @@ def main(cfg: DictConfig) -> Trainer: mosaicml_logger = MosaicMLLogger() loggers.append(mosaicml_logger) + if metadata is not None: + # Flatten the metadata for logging + logged_cfg.pop('metadata', None) + logged_cfg.update(metadata, merge=True) + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(metadata) + mosaicml_logger._flush_metadata(force_flush=True) + # Profiling profiler: Optional[Profiler] = None profiler_cfg: Optional[DictConfig] = pop_config(cfg, From c2f5742d5d15e26b510bead331b35a82258b6c44 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 6 Nov 2023 14:01:42 -0800 Subject: [PATCH 05/15] Run HF dataset processing on local rank 0 first (#716) --- llmfoundry/data/finetuning/tasks.py | 40 ++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index edbfcc28c7..3673a48217 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -38,6 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from typing import Any, Callable, Dict, List, Optional, Union import datasets as hf_datasets +from composer.utils import dist from omegaconf import DictConfig from streaming import StreamingDataset from transformers import PreTrainedTokenizerBase @@ -332,6 +333,16 @@ def build_from_hf( preprocessing_fn = self.get_preprocessing_fn_from_str( proto_preprocessing_fn, dataset_name) + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' + + # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. + # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks + # can just read them. 
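        # The signal file name includes the node rank, so each node coordinates only
        # its own local ranks. dist.local_rank_zero_download_and_wait blocks the
        # non-zero local ranks until that file exists; local rank 0 creates it further
        # below, once the map/filter passes have populated the HF datasets cache.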
+ if dist.get_local_rank() != 0: + log.debug('Waiting for local_rank 0 to finish data prep') + with dist.local_rank_zero_download_and_wait(signal_file_path): + pass + dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs) def dataset_mapper(example: Dict): @@ -340,7 +351,8 @@ def dataset_mapper(example: Dict): return _tokenize_formatted_example(example, tokenizer) detected_cpu_count = os.cpu_count() or 1 - num_cpus_to_use = max(1, detected_cpu_count - 4) + detected_cpus_with_margin = detected_cpu_count - 8 + num_cpus_to_use = max(1, detected_cpus_with_margin) columns_to_remove = list(dataset[0].keys()) tokenized_dataset = dataset.map( @@ -348,10 +360,12 @@ def dataset_mapper(example: Dict): batched=False, remove_columns=columns_to_remove, num_proc=num_cpus_to_use, + desc='Tokenizing dataset', ) prompt_length_filtered_dataset = tokenized_dataset.filter( lambda example: len(example['input_ids']) < max_seq_len, num_proc=num_cpus_to_use, + desc='Filtering out long prompts', ) examples_removed = len(tokenized_dataset) - len( @@ -361,10 +375,16 @@ def dataset_mapper(example: Dict): f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.' ) + pad_token_id = tokenizer.pad_token_id empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter( lambda example: len(example['input_ids']) > 0 and len(example[ - 'labels']) > 0 and any(token_id != tokenizer.pad_token_id - for token_id in example['labels'])) + 'labels']) > 0 and any(token_id != pad_token_id + for token_id in example['labels']), + num_proc=num_cpus_to_use, + desc='Filtering out empty examples') + + log.debug('Done tokenizing and filtering examples.') + empty_examples_removed = len(prompt_length_filtered_dataset) - len( empty_examples_dropped_dataset) if empty_examples_removed > 0: @@ -372,6 +392,20 @@ def dataset_mapper(example: Dict): f'Dropped {empty_examples_removed} examples where the prompt or response was empty, ' + 'or the response was only padding tokens.') + # Now local rank 0 indicates to the other ranks that it is done + if dist.get_local_rank() == 0: + log.debug('Local rank 0 finished data prep') + with open(signal_file_path, 'wb') as f: + f.write(b'local_rank0_completed_data_prep') + + # All ranks sync up at this barrier, having completed data processing + dist.barrier() + + # Last, local rank 0 cleans up the signal file + if dist.get_local_rank() == 0: + os.remove(signal_file_path) + + log.debug('All ranks finished data prep') return empty_examples_dropped_dataset def build_from_streaming(self, *args: Any, From 58d7cf3e3bcbbd21a77c71ff3a37e7be50d46bbe Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Mon, 6 Nov 2023 14:31:27 -0800 Subject: [PATCH 06/15] Add Hugging Face model download script (#708) * Add Hugging Face model download script * Decode response bytes to string * Clean * Move download functions to foundry utils * Clean up script * Add bs4 dependency * Fix typing * Doc formatting * Doc formatting * Fix weights preference logic * Unit tests for weights preference logic in download_from_hf_hub * Unit tests for download_from_cache_server * Add retries and unit tests * pyright * code quality checks * precommit --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/utils/__init__.py | 4 + llmfoundry/utils/model_download_utils.py | 228 +++++++++++++++++++++ scripts/misc/download_hf_model.py | 67 ++++++ setup.py | 5 +- tests/test_model_download_utils.py | 248 +++++++++++++++++++++++ 5 files changed, 551 insertions(+), 1 
deletion(-) create mode 100644 llmfoundry/utils/model_download_utils.py create mode 100644 scripts/misc/download_hf_model.py create mode 100644 tests/test_model_download_utils.py diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 38cc562c9d..7abe4dcf75 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -11,6 +11,8 @@ from llmfoundry.utils.config_utils import (calculate_batch_size_info, log_config, pop_config, update_batch_size_info) + from llmfoundry.utils.model_download_utils import ( + download_from_cache_server, download_from_hf_hub) except ImportError as e: raise ImportError( 'Please make sure to pip install . to get requirements for llm-foundry.' @@ -26,6 +28,8 @@ 'build_tokenizer', 'calculate_batch_size_info', 'convert_and_save_ft_weights', + 'download_from_cache_server', + 'download_from_hf_hub', 'get_hf_tokenizer_from_composer_state_dict', 'update_batch_size_info', 'log_config', diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py new file mode 100644 index 0000000000..d268cb78b7 --- /dev/null +++ b/llmfoundry/utils/model_download_utils.py @@ -0,0 +1,228 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for downloading models.""" +import copy +import logging +import os +import time +from http import HTTPStatus +from typing import Optional +from urllib.parse import urljoin + +import huggingface_hub as hf_hub +import requests +import tenacity +from bs4 import BeautifulSoup +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME +from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME + +DEFAULT_IGNORE_PATTERNS = [ + '*.ckpt', + '*.h5', + '*.msgpack', +] +PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' +SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' + +log = logging.getLogger(__name__) + + +@tenacity.retry(retry=tenacity.retry_if_not_exception_type( + (ValueError, hf_hub.utils.RepositoryNotFoundError)), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=1, max=10)) +def download_from_hf_hub( + repo_id: str, + save_dir: Optional[str] = None, + prefer_safetensors: bool = True, + token: Optional[str] = None, +): + """Downloads model files from a Hugging Face Hub model repo. + + Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the + Safetensors weights will be downloaded unless `prefer_safetensors` is set to False. + + Args: + repo_id (str): The Hugging Face Hub repo ID. + save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads + from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory. + prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are + available. Defaults to True. + token (str, optional): The HuggingFace API token. If not provided, the token will be read from the + `HUGGING_FACE_HUB_TOKEN` environment variable. + + Raises: + RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized. + ValueError: If the model repo doesn't contain any supported model weights. + """ + repo_files = set(hf_hub.list_repo_files(repo_id)) + + # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer. 
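    # For example, if a repo has both formats and prefer_safetensors=True, the final
    # ignore_patterns is ['*.ckpt', '*.h5', '*.msgpack', 'pytorch_model*.bin*'], so the
    # snapshot_download call below skips the PyTorch weights and keeps everything else.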
+ ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS) + + safetensors_available = (SAFE_WEIGHTS_NAME in repo_files or + SAFE_WEIGHTS_INDEX_NAME in repo_files) + pytorch_available = (PYTORCH_WEIGHTS_NAME in repo_files or + PYTORCH_WEIGHTS_INDEX_NAME in repo_files) + + if safetensors_available and pytorch_available: + if prefer_safetensors: + log.info( + 'Safetensors available and preferred. Excluding pytorch weights.' + ) + ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN) + else: + log.info( + 'Pytorch available and preferred. Excluding safetensors weights.' + ) + ignore_patterns.append(SAFE_WEIGHTS_PATTERN) + elif safetensors_available: + log.info('Only safetensors available. Ignoring weights preference.') + elif pytorch_available: + log.info('Only pytorch available. Ignoring weights preference.') + else: + raise ValueError( + f'No supported model weights found in repo {repo_id}.' + + ' Please make sure the repo contains either safetensors or pytorch weights.' + ) + + download_start = time.time() + hf_hub.snapshot_download(repo_id, + cache_dir=save_dir, + ignore_patterns=ignore_patterns, + token=token) + download_duration = time.time() - download_start + log.info( + f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds' + ) + + +def _extract_links_from_html(html: str): + """Extracts links from HTML content. + + Args: + html (str): The HTML content + + Returns: + list[str]: A list of links to download. + """ + soup = BeautifulSoup(html, 'html.parser') + links = [a['href'] for a in soup.find_all('a')] + return links + + +def _recursive_download( + session: requests.Session, + base_url: str, + path: str, + save_dir: str, + ignore_cert: bool = False, +): + """Downloads all files/subdirectories from a directory on a remote server. + + Args: + session: A requests.Session through which to make requests to the remote server. + url (str): The base URL where the files are located. + path (str): The path from the base URL to the files to download. The full URL for the download is equal to + '/'. + save_dir (str): The directory to save downloaded files to. + ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server. + Defaults to False. + WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. + + Raises: + PermissionError: If the remote server returns a 401 Unauthorized status code. + ValueError: If the remote server returns a 404 Not Found status code. + RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized. + """ + url = urljoin(base_url, path) + response = session.get(url, verify=(not ignore_cert)) + + if response.status_code == HTTPStatus.UNAUTHORIZED: + raise PermissionError( + f'Not authorized to download file from {url}. Received status code {response.status_code}. ' + ) + elif response.status_code == HTTPStatus.NOT_FOUND: + raise ValueError( + f'Could not find file at {url}. Received status code {response.status_code}' + ) + elif response.status_code != HTTPStatus.OK: + raise RuntimeError( + f'Could not download file from {url}. Received unexpected status code {response.status_code}' + ) + + # Assume that the URL points to a file if it does not end with a slash. 
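As a small illustration of the `_extract_links_from_html` helper defined above (the HTML string is a made-up directory listing, not output captured from any real cache server), before the file-versus-directory branch that follows:

    from llmfoundry.utils.model_download_utils import _extract_links_from_html

    # A toy directory listing with one file link and one subdirectory link.
    listing = '<html><body><a href="file1">file1</a><a href="folder/">folder/</a></body></html>'
    print(_extract_links_from_html(listing))  # ['file1', 'folder/']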
+ if not path.endswith('/'): + save_path = os.path.join(save_dir, path) + parent_dir = os.path.dirname(save_path) + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) + + with open(save_path, 'wb') as f: + f.write(response.content) + + log.info(f'Downloaded file {save_path}') + return + + # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links + # to download. + child_links = _extract_links_from_html(response.content.decode()) + for child_link in child_links: + _recursive_download(session, + base_url, + urljoin(path, child_link), + save_dir, + ignore_cert=ignore_cert) + + +@tenacity.retry(retry=tenacity.retry_if_not_exception_type( + (PermissionError, ValueError)), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=1, max=10)) +def download_from_cache_server( + model_name: str, + cache_base_url: str, + save_dir: str, + token: Optional[str] = None, + ignore_cert: bool = False, +): + """Downloads Hugging Face models from a mirror file server. + + The file server is expected to store the files in the same structure as the Hugging Face cache + structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache. + + Args: + model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face + Hub. + cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob + files from `//blobs/`, where `formatted_model_name` is equal to + `models/` with all slashes replaced with `--`. + save_dir: The directory to save the downloaded files to. + token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` + environment variable. + ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to + False. + WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. 
+ """ + formatted_model_name = f'models/{model_name}'.replace('/', '--') + with requests.Session() as session: + session.headers.update({'Authorization': f'Bearer {token}'}) + + download_start = time.time() + + # Only downloads the blobs in order to avoid downloading model files twice due to the + # symlnks in the Hugging Face cache structure: + _recursive_download( + session, + cache_base_url, + # Trailing slash to indicate directory + f'{formatted_model_name}/blobs/', + save_dir, + ignore_cert=ignore_cert, + ) + download_duration = time.time() - download_start + log.info( + f'Downloaded model {model_name} from cache server in {download_duration} seconds' + ) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py new file mode 100644 index 0000000000..6465a552c2 --- /dev/null +++ b/scripts/misc/download_hf_model.py @@ -0,0 +1,67 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to download model weights from Hugging Face Hub or a cache server.""" +import argparse +import logging +import os +import sys + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + +from llmfoundry.utils.model_download_utils import (download_from_cache_server, + download_from_hf_hub) + +HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' + +log = logging.getLogger(__name__) + +if __name__ == '__main__': + argparser = argparse.ArgumentParser() + argparser.add_argument('--model', type=str, required=True) + argparser.add_argument('--download-from', + type=str, + choices=['hf', 'cache'], + default='hf') + argparser.add_argument('--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + argparser.add_argument('--save-dir', + type=str, + default=HUGGINGFACE_HUB_CACHE) + argparser.add_argument('--cache-url', type=str, default=None) + argparser.add_argument('--ignore-cert', action='store_true', default=False) + argparser.add_argument( + '--fallback', + action='store_true', + default=False, + help= + 'Whether to fallback to downloading from Hugging Face if download from cache fails', + ) + + args = argparser.parse_args(sys.argv[1:]) + if args.download_from == 'hf': + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token) + else: + try: + download_from_cache_server( + args.model, + args.cache_url, + args.save_dir, + token=args.token, + ignore_cert=args.ignore_cert, + ) + except PermissionError: + log.error(f'Not authorized to download {args.model}.') + except Exception as e: + if args.fallback: + log.warn( + f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' + ) + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token) + else: + raise e diff --git a/setup.py b/setup.py index 63aac9d752..f528838d35 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,8 @@ 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python', 'boto3>=1.21.45,<2', 'huggingface-hub>=0.17.0,<1.0', + 'beautifulsoup4>=4.12.2,<5', # required for model download utils + 'tenacity>=8.2.3,<9', ] extra_deps = {} @@ -101,7 +103,8 @@ extra_deps['peft'] = [ 'loralib==0.1.1', # lora core 'bitsandbytes==0.39.1', # 8bit - 'scipy>=1.10.0,<=1.11.0', # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes + # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes + 'scipy>=1.10.0,<=1.11.0', # TODO: pin peft when it stabilizes. 
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'peft==0.4.0', diff --git a/tests/test_model_download_utils.py b/tests/test_model_download_utils.py new file mode 100644 index 0000000000..27b9805cda --- /dev/null +++ b/tests/test_model_download_utils.py @@ -0,0 +1,248 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import unittest.mock as mock +from http import HTTPStatus +from typing import Any, Dict, List +from unittest.mock import MagicMock +from urllib.parse import urljoin + +import pytest +import requests +import tenacity +from huggingface_hub.utils import RepositoryNotFoundError +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME +from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME + +from llmfoundry.utils.model_download_utils import (DEFAULT_IGNORE_PATTERNS, + PYTORCH_WEIGHTS_PATTERN, + SAFE_WEIGHTS_PATTERN, + download_from_cache_server, + download_from_hf_hub) + +# ======================== download_from_hf_hub tests ======================== + + +@pytest.mark.parametrize( + ['prefer_safetensors', 'repo_files', 'expected_ignore_patterns'], + [ + [ # Should use default ignore if only safetensors available + True, + [SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only safetensors available + False, + [SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ # Should use default ignore if only sharded safetensors available + True, + [SAFE_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only sharded safetensors available + False, + [SAFE_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only pytorch available + True, + [PYTORCH_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only pytorch available + False, + [PYTORCH_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only sharded pytorch available + True, + [PYTORCH_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only sharded pytorch available + False, + [PYTORCH_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ # Ignore pytorch if safetensors are preferred + True, + [PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS + [PYTORCH_WEIGHTS_PATTERN], + ], + [ # Ignore safetensors if pytorch is preferred + False, + [PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS + [SAFE_WEIGHTS_PATTERN], + ], + [ # Ignore pytorch if safetensors are preferred + True, + [PYTORCH_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS + [PYTORCH_WEIGHTS_PATTERN], + ], + [ # Ignore safetensors if pytorch is preferred + False, + [PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS + [SAFE_WEIGHTS_PATTERN], + ], + ]) +@mock.patch('huggingface_hub.snapshot_download') +@mock.patch('huggingface_hub.list_repo_files') +def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, + prefer_safetensors: bool, + repo_files: List[str], + expected_ignore_patterns: List[str]): + test_repo_id = 'test_repo_id' + mock_list_repo_files.return_value = repo_files + + download_from_hf_hub(test_repo_id, prefer_safetensors=prefer_safetensors) + mock_snapshot_download.assert_called_once_with( + test_repo_id, + 
cache_dir=None, + ignore_patterns=expected_ignore_patterns, + token=None, + ) + + +@mock.patch('huggingface_hub.snapshot_download') +@mock.patch('huggingface_hub.list_repo_files') +def test_download_from_hf_hub_no_weights( + mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, +): + test_repo_id = 'test_repo_id' + mock_list_repo_files.return_value = [] + + with pytest.raises(ValueError): + download_from_hf_hub(test_repo_id) + + mock_snapshot_download.assert_not_called() + + +@pytest.mark.parametrize(['exception', 'expected_attempts'], [ + [requests.exceptions.RequestException(), 3], + [RepositoryNotFoundError(''), 1], + [ValueError(), 1], +]) +@mock.patch('tenacity.nap.time.sleep') +@mock.patch('huggingface_hub.snapshot_download') +@mock.patch('huggingface_hub.list_repo_files') +def test_download_from_hf_hub_retry( + mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, + mock_sleep: MagicMock, # so the retry wait doesn't actually wait + exception: BaseException, + expected_attempts: int, +): + mock_list_repo_files.return_value = [SAFE_WEIGHTS_INDEX_NAME] + mock_snapshot_download.side_effect = exception + + with pytest.raises((tenacity.RetryError, exception.__class__)): + download_from_hf_hub('test_repo_id') + + assert mock_snapshot_download.call_count == expected_attempts + + +# ======================== download_from_cache_server tests ======================== + +ROOT_HTML = b""" + + + + + + +""" + +SUBFOLDER_HTML = b""" + + + + + + +""" + + +@mock.patch.object(requests.Session, 'get') +@mock.patch('os.makedirs') +@mock.patch('builtins.open') +def test_download_from_cache_server(mock_open: MagicMock, + mock_makedirs: MagicMock, + mock_get: MagicMock): + cache_url = 'https://cache.com/' + model_name = 'model' + formatted_model_name = 'models--model' + save_dir = 'save_dir/' + + mock_open.return_value = MagicMock() + + def _server_response(url: str, **kwargs: Dict[str, Any]): + if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'): + return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML) + if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'): + return MagicMock(status_code=HTTPStatus.OK) + elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'): + return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML) + elif url == urljoin(cache_url, + f'{formatted_model_name}/blobs/folder/file2'): + return MagicMock(status_code=HTTPStatus.OK) + else: + return MagicMock(status_code=HTTPStatus.NOT_FOUND) + + mock_get.side_effect = _server_response + download_from_cache_server(model_name, cache_url, 'save_dir/') + + mock_open.assert_has_calls([ + mock.call(os.path.join(save_dir, formatted_model_name, 'blobs/file1'), + 'wb'), + mock.call( + os.path.join(save_dir, formatted_model_name, 'blobs/folder/file2'), + 'wb'), + ], + any_order=True) + + +@mock.patch.object(requests.Session, 'get') +def test_download_from_cache_server_unauthorized(mock_get: MagicMock): + cache_url = 'https://cache.com/' + model_name = 'model' + save_dir = 'save_dir/' + + mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED) + with pytest.raises(PermissionError): + download_from_cache_server(model_name, cache_url, save_dir) + + +@pytest.mark.parametrize(['exception', 'expected_attempts'], [ + [requests.exceptions.RequestException(), 3], + [PermissionError(), 1], + [ValueError(), 1], +]) +@mock.patch('tenacity.nap.time.sleep') +@mock.patch('llmfoundry.utils.model_download_utils._recursive_download') +def 
test_download_from_cache_server_retry( + mock_recursive_download: MagicMock, + mock_sleep: MagicMock, # so the retry wait doesn't actually wait + exception: BaseException, + expected_attempts: int, +): + mock_recursive_download.side_effect = exception + + with pytest.raises((tenacity.RetryError, exception.__class__)): + download_from_cache_server('model', 'cache_url', 'save_dir') From 1d504c851c26d54e8a07b2a2245fd6cbd4d283e0 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 6 Nov 2023 15:00:19 -0800 Subject: [PATCH 07/15] Adding support for Rotary Position Embeddings (#675) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * removed the roformer impementation of rope * .. * fixed all the lint errors * .. * .. * ../llmfoundry/models/mpt/modeling_mpt.py * .. * .. * .. * added unit test to test rotary embeddings * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/mpt/modeling_mpt.py Accepting the suggestion Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * incorporated some suggestions from the pr * .. * .. * .. * .. * .. * .. * .. * added mark for gpu in the rotary embedding test * .. * .. * .. * removed thecode for hf implementation of rope * .. * .. * added tests * .. * .. * ... * .. * .. * .. * .. * .. * fixed the tests after the merge * minor change * Fixed some tests failing due to a transformers library bug * added check for flash_attention before importing their rotary embedding * added check for flash_attention in tests before using dail rope * fixed tests * .. * .. * temporary fix * .. * .. * fixed a test * .. * minor change * minor changes * added documentation * added documentation * temp commit * made _set_config_defaults recursive * minor changes * reformatted tutorial table * reformatted tutorial table * reformatted tutorial table * added documentation on how to install flash attention 2 * minor changes * minor changes * minor changes * minor changes * minor changes * minor changes * .. 
* resolved some comments from the PR * fixed tests * modified is_flash_v2_installed * minor changes * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * resolved PR comments --------- Co-authored-by: Shashank Rajput Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- TUTORIAL.md | 49 +- llmfoundry/models/layers/attention.py | 71 ++- llmfoundry/models/layers/blocks.py | 43 +- llmfoundry/models/mpt/configuration_mpt.py | 72 ++- llmfoundry/models/mpt/modeling_mpt.py | 129 ++++- tests/test_flash_triton_torch.py | 73 ++- tests/test_model.py | 557 +++++++++++++++++---- tests/test_rope_dail_vs_hf.py | 145 ++++++ 8 files changed, 952 insertions(+), 187 deletions(-) create mode 100644 tests/test_rope_dail_vs_hf.py diff --git a/TUTORIAL.md b/TUTORIAL.md index d019eb9f83..86bd9829e9 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -8,27 +8,42 @@ Forging LLMs can be quite complicated — you have to get your data prepared, se This tutorial will provide a brief intro to the repo’s structure and underlying tools (all courtesy of MosaicML, of course), will go over a few example workflows and point you to the related resources within the repo, and will finally cover a number of FAQs that we have encountered since release. +- [LLM Foundry Tutorial](#llm-foundry-tutorial) - [Intro](#intro) - [How this repo is structured](#how-this-repo-is-structured) - [Key components](#key-components) + - [Composer](#composer) + - [StreamingDataset](#streamingdataset) + - [MCLI](#mcli) - [How the YAMLs work](#how-the-yamls-work) - [Example Workflows](#example-workflows) - [Workflow 1: I want to play with a HF model like MPT-7B locally](#workflow-1-i-want-to-play-with-a-hf-model-like-mpt-7b-locally) - [Workflow 2: I want to deploy an inference endpoint with a HF model like MPT-7B](#workflow-2-i-want-to-deploy-an-inference-endpoint-with-a-hf-model-like-mpt-7b) - [Workflow 3: I want to finetune a HF model like MPT-7B](#workflow-3-i-want-to-finetune-a-hf-model-like-mpt-7b) + - [Supervised FineTuning and Instruction FineTuning](#supervised-finetuning-and-instruction-finetuning) + - [Domain Adaptation and Sequence Length Adaptation](#domain-adaptation-and-sequence-length-adaptation) + - [Data](#data) + - [Modeling](#modeling) - [Workflow 4: I want to train a new HF model from scratch](#workflow-4-i-want-to-train-a-new-hf-model-from-scratch) - [FAQs](#faqs) - - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus) - - [I’m running into an Out-Of-Memory (OOM) error. 
What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do) - - [What hardware can I train on?](#what-hardware-can-i-train-on) - - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on) - - [What is FSDP?](#what-is-fsdp) - - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton-for-mpt-and-which-one-should-i-use) - - [Can I finetune using PEFT / LORA?](#can-i-finetune-using-peft--lora) - - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) - - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer) - - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms) - - [Common installation issues](#common-installation-issues) + - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus) + - [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do) + - [What hardware can I train on?](#what-hardware-can-i-train-on) + - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on) + - [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on) + - [What is FSDP?](#what-is-fsdp) + - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use) + - [Limitations](#limitations) + - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir) + - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus) + - [Support for FlashAttention-2](#support-for-flashattention-2) + - [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support) + - [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora) + - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) + - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer) + - [TransformerEngine and amp\_fp8 support](#transformerengine-and-amp_fp8-support) + - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms) + - [Common installation issues](#common-installation-issues) Let’s get started! @@ -328,6 +343,18 @@ The majority of our training setups use `triton`. --> Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes. What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. +#### Support for FlashAttention-2 +- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention). + +### What kinds of positional embeddings does LLM Foundry support? 
+Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf). + +| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes | +|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Learned Positional Embeddings |
<pre>model:<br>  learned_pos_emb: True</pre> | 65.7 | |
+| ALiBi | <pre>model:<br>  attn_config:<br>    alibi: True</pre> | 64.5 | Requires Triton or Torch attention. |
+| RoPE (Dao-AILab Implementation) | <pre>model:<br>  attn_config:<br>    rope: True<br>    rope_impl: dail</pre> | 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
+| RoPE (Hugging Face Implementation) | <pre>model:<br>  attn_config:<br>    rope: True<br>    rope_impl: hf</pre>
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 39fa7162ac..0503d6d75a 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import Any, List, Optional, Tuple +from typing import Any, Optional import torch import torch.nn as nn @@ -17,12 +17,13 @@ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -def is_flash_v2_installed(): +def is_flash_v2_installed(v2_version: str = '2.0.0'): + assert version.parse(v2_version) >= version.parse('2.0.0') try: import flash_attn as flash_attn except: return False - return version.parse(flash_attn.__version__) >= version.parse('2.0.0') + return version.parse(flash_attn.__version__) >= version.parse(v2_version) def is_flash_v1_installed(): @@ -33,6 +34,16 @@ def is_flash_v1_installed(): return version.parse(flash_attn.__version__) < version.parse('2.0.0') +# Before importing any transformers models, we need to disable transformers flash attention if +# we are in an environment with flash attention version <2. Transformers hard errors on a not properly +# gated import otherwise. 
+if is_flash_v1_installed(): + import transformers + transformers.utils.is_flash_attn_available = lambda: False + +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool: # disable causal when it is not needed @@ -70,7 +81,7 @@ def scaled_multihead_dot_product_attention( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -79,7 +90,7 @@ def scaled_multihead_dot_product_attention( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: if multiquery: @@ -183,7 +194,7 @@ def scaled_multihead_dot_product_attention( def check_valid_inputs(*tensors: torch.Tensor, - valid_dtypes: Optional[List[torch.dtype]] = None): + valid_dtypes: Optional[list[torch.dtype]] = None): if valid_dtypes is None: valid_dtypes = [torch.float16, torch.bfloat16] for tensor in tensors: @@ -199,7 +210,7 @@ def flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -208,7 +219,7 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip @@ -337,7 +348,7 @@ def triton_flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -346,7 +357,7 @@ def triton_flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func @@ -552,12 +563,13 @@ def __init__( def forward( self, x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -581,6 +593,39 @@ def forward( query = 
self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) + if rotary_emb_w_meta_info is not None: + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] + seq_len = rotary_emb_w_meta_info['seq_len'] + offset_info = rotary_emb_w_meta_info['offset_info'] + bsz, seqlen = query.shape[:2] + query = query.view(bsz, seqlen, -1, self.head_dim) + key = key.view(bsz, seqlen, -1, self.head_dim) + + if rotary_emb_w_meta_info['impl'] == 'dail': + value = value.view(bsz, seqlen, -1, self.head_dim) + + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb(query, + kv, + seqlen_offset=offset_info, + max_seqlen=seq_len) + [key, value] = torch.unbind(kv, dim=2) + + value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) + elif rotary_emb_w_meta_info['impl'] == 'hf': + (cos, sin) = rotary_emb(value, seq_len) + # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb(query, key, cos, sin, + offset_info) + # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + query = query.view(bsz, seqlen, self.d_model) + key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) + context, attn_weights, past_key_value = self.attn_fn( query, key, @@ -677,7 +722,7 @@ def __init__( def attn_bias_shape( attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, - use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]: + use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]: if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index a08ef6d77f..6605807c6b 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -12,6 +12,31 @@ from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8, + 'rope': False, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +} + class MPTBlock(nn.Module): @@ -30,18 +55,7 @@ def __init__( **kwargs: Any, ): if attn_config is None: - attn_config = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, - } + attn_config = attn_config_defaults if ffn_config is None: ffn_config = { @@ -58,7 +72,8 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', - 'alibi_bias_max' + 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', + 'rope_dail_config', 'rope_hf_config' } 
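The nested rope sub-dicts in `attn_config_defaults` above are filled in recursively when a user supplies only a partial `attn_config`; the standalone sketch below mirrors the recursive `_set_config_defaults` change made later in this patch (`merge_defaults` is a hypothetical helper name, not library code):

    from llmfoundry.models.layers.blocks import attn_config_defaults

    def merge_defaults(config: dict, defaults: dict) -> dict:
        # Fill in missing keys, recursing into nested dicts such as rope_dail_config.
        for k, v in defaults.items():
            if k not in config:
                config[k] = v
            elif isinstance(v, dict):
                config[k] = merge_defaults(
                    config[k] if config[k] is not None else {}, v)
        return config

    user_cfg = {'rope': True, 'rope_impl': 'dail', 'rope_dail_config': {'type': 'xpos'}}
    merged = merge_defaults(user_cfg, attn_config_defaults)
    # merged['rope_dail_config'] now also carries 'pos_idx_in_fp32' and 'xpos_scale_base'.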
attn_config_subset_for_attn_class = { k: v @@ -94,6 +109,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, + rotary_emb_w_meta_info: Optional[Dict] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -104,6 +120,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 251e4f5caf..c4ca68d733 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,18 +8,16 @@ from transformers import PretrainedConfig -attn_config_defaults: Dict = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, -} +from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.blocks import attn_config_defaults + +# NOTE: All utils are imported directly even if unused so that +# HuggingFace can detect all the needed files to copy into its modules folder. +# Otherwise, certain modules are missing. +# isort: off +from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) +from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp', @@ -94,6 +92,16 @@ def __init__( Defaults to ``False`` meaning any provided `sequence_id` will be ignored. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. + rope (bool): Whether to use rotary positional embeddings. + rope_theta (int): The base frequency for rope. + rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py). + rope_dail_config (Dict): The configuration for the dail implementation of rope. + type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf). + pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding. + xpos_scale_base (float): The scale base for XPos (if using XPos). + rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length). + type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. 
+ factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp @@ -150,10 +158,12 @@ def __init__( del kwargs['name'] if 'loss_fn' in kwargs: del kwargs['loss_fn'] - if self.attn_config.get('alibi', False): + if self.attn_config.get('alibi', False) or self.attn_config.get( + 'rope', False): self.learned_pos_emb = False warnings.warn( - f'alibi is turned on, setting `learned_pos_emb` to `False.`') + f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`' + ) super().__init__(**kwargs) self._validate_config() @@ -164,6 +174,10 @@ def _set_config_defaults(self, config: Dict[str, Any], for k, v in config_defaults.items(): if k not in config: config[k] = v + elif isinstance(v, dict): + # recursively set default values for any sub-dicts + config[k] = self._set_config_defaults( + config[k] if (config[k] is not None) else {}, v) return config def _validate_config(self) -> None: @@ -206,6 +220,31 @@ def _validate_config(self) -> None: raise NotImplementedError( 'attn_uses_sequence_id only implemented with torch and triton attention.' ) + if self.attn_config['rope'] and (self.attn_config['rope_impl'] + not in ['dail', 'hf']): + raise ValueError( + 'If rope is being used then rope_impl should be either "dail", or "hf".' + ) + if self.attn_config['rope'] and ( + self.attn_config['rope_impl'] + == 'hf') and self.attn_config['rope_hf_config']['type'] not in [ + 'no_scaling', 'linear', 'dynamic' + ]: + raise ValueError( + 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' + ) + if self.attn_config['rope'] and (self.attn_config['rope_impl'] + == 'dail'): + if self.attn_config['rope_dail_config']['type'] not in [ + 'original', 'xpos' + ]: + raise ValueError( + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' + ) + if not is_flash_v2_installed(v2_version='2.0.1'): + raise ImportError( + 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' @@ -217,9 +256,10 @@ def _validate_config(self) -> None: ) if self.init_config.get('name', None) is None: raise ValueError(f"{self.init_config=} 'name' needs to be set.") - if not self.learned_pos_emb and not self.attn_config['alibi']: + if not (self.learned_pos_emb or self.attn_config['alibi'] or + self.attn_config['rope']): warnings.warn( - f'Positional information not being provided to the model using either learned_pos_emb or alibi.' + f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.' 
) if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': try: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4f4581b177..0cb3ebd56c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -23,11 +23,27 @@ from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel from composer.utils import dist + +from llmfoundry.models.layers.attention import is_flash_v2_installed + +if is_flash_v2_installed(): + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn.layers.rotary import \ + RotaryEmbedding as DAILRotaryEmbedding + except Exception as e: + raise e + from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) +from transformers.models.llama.modeling_llama import \ + LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaRotaryEmbedding as HFRotaryEmbedding from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -70,6 +86,50 @@ log = logging.getLogger(__name__) +def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, + rope_dail_config: dict, rope_hf_config: dict, + max_seq_len: int): + if rope_impl == 'dail': + return DAILRotaryEmbedding( + dim=rope_head_dim, + base=rope_theta, + interleaved=False, + scale_base=rope_dail_config['xpos_scale_base'] if + (rope_dail_config['type'] == 'xpos') else None, + pos_idx_in_fp32=rope_dail_config['pos_idx_in_fp32'], + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif rope_impl == 'hf': + if rope_hf_config['type'] == 'no_scaling': + return HFRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif rope_hf_config['type'] == 'linear': + return HFLinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif rope_hf_config['type'] == 'dynamic': + return HFDynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + raise ValueError('rope_impl needs to be either dail or hf') + + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -123,6 +183,18 @@ def __init__(self, config: MPTConfig): ]) self.norm_f = norm_class(config.d_model, device=config.init_device) + self.rope = config.attn_config['rope'] + self.rope_impl = None + if self.rope: + self.rope_impl = config.attn_config['rope_impl'] + self.rotary_embedding = gen_rotary_embedding( + rope_head_dim=config.d_model // config.n_heads, + rope_impl=self.rope_impl, + 
rope_theta=config.attn_config['rope_theta'], + rope_dail_config=config.attn_config['rope_dail_config'], + rope_hf_config=config.attn_config['rope_hf_config'], + max_seq_len=self.config.max_seq_len) + if config.init_device != 'meta': log.info( f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.' @@ -361,8 +433,9 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' - tok_emb = self.wte(input_ids) - if self.learned_pos_emb: + rotary_emb_w_meta_info = None + x = self.wte(input_ids) + if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: if len(past_key_values) != self.config.n_layers: @@ -378,31 +451,44 @@ def forward( if self.attn_impl == 'torch': past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: + if self.learned_pos_emb and (S + past_position > + self.config.max_seq_len): raise ValueError( f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - # adjust the position indices to account for padding tokens - pos = torch.clamp( - pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], - min=0, - ) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - else: - # ALiBi and NoPE use this path (RoPE will also use this path if / when enabled) - x = tok_emb + if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'): + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), + dim=1)[:, past_position:], + min=0, + ) + if self.learned_pos_emb: + x = x + self.wpe(pos) + elif self.rope and self.rope_impl == 'hf': + rotary_emb_w_meta_info = { + 'impl': self.rope_impl, + 'rotary_emb': self.rotary_embedding, + 'offset_info': pos, + 'seq_len': S + past_position, + } + elif self.rope and self.rope_impl == 'dail': + rotary_emb_w_meta_info = { + 'impl': self.rope_impl, + 'rotary_emb': self.rotary_embedding, + 'offset_info': past_position, + 'seq_len': S + past_position, + } if self.embedding_fraction == 1: x = self.emb_drop(x) @@ -439,6 +525,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index e6fe8eb438..3f2c229d6d 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -5,6 +5,9 @@ import torch from omegaconf import OmegaConf as om +from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + def allclose_helper(t0: torch.Tensor, t1: torch.Tensor, @@ -18,7 +21,32 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize('attn_impl_1', ['flash', 'triton', 'torch']) @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize('qk_ln', [True, False]) 
-@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) @pytest.mark.parametrize( 'attn_type', ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) @@ -26,18 +54,24 @@ def test_attn_impl(attn_impl_0: str, attn_impl_1: str, clip_qkv: bool, qk_ln: bool, - alibi: bool, + pos_emb_config: dict, attn_type: str, device: str = 'cuda'): """Compare all attn impl with each other. - Includes testing with and without attn_clip_qkv, attn_qk_ln, and alibi. + Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and + rope. """ from llmfoundry.models.layers import attention - + alibi = pos_emb_config['alibi'] + rope = pos_emb_config['rope'] if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): pytest.xfail('flash attn does not support alibi') + if rope and (pos_emb_config['rope_impl'] + == 'dail') and (not is_flash_v2_installed()): + pytest.skip('dail implementation of rope requires flash attention 2.') + cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, @@ -48,7 +82,7 @@ def test_attn_impl(attn_impl_0: str, }) n, s, f = 2, 16, cfg.d_model - + assert cfg.d_model % cfg.n_heads == 0 if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 @@ -91,16 +125,45 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias = gen_bias(attn0.attn_impl) + + rotary_emb_w_meta_info = None + if rope: + rotary_embedding = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=pos_emb_config['rope_impl'], + rope_theta=pos_emb_config['rope_theta'], + rope_dail_config=pos_emb_config.get('rope_dail_config', {}), + rope_hf_config=pos_emb_config.get('rope_hf_config', {}), + max_seq_len=s).to(device) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + rotary_emb_w_meta_info = { + 'impl': + pos_emb_config['rope_impl'], + 'rotary_emb': + rotary_embedding, + 'offset_info': + pos if (pos_emb_config['rope_impl'] == 'hf') else 0, + 'seq_len': + s, + } + y0, _, _ = attn0(x0, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True) attn_bias = gen_bias(attn1.attn_impl) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/test_model.py b/tests/test_model.py index 1c7033ed48..41b62f0ccf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -16,7 +16,7 @@ from composer.core.precision import Precision, get_precision_context from composer.optim import DecoupledAdamW from composer.trainer.dist_strategy import prepare_fsdp_module -from composer.utils import dist, get_device +from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import 
(AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, @@ -28,6 +28,7 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias +from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -517,16 +518,49 @@ def test_mpt_creation(norm_type: str, no_bias: bool): ('flash', 'gpu'), ('triton', 'gpu'), ('torch', 'gpu')]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_forward_with_padding(attention_impl: str, device: str, + pos_emb_config: dict): # Test that different placement of padding does not affect the output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) + alibi = pos_emb_config['alibi'] if alibi and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + rope = pos_emb_config['rope'] + if rope and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) hf_config = MPTConfig( @@ -540,7 +574,7 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, init_config={ 'name': 'baseline_', @@ -612,23 +646,35 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): attention_mask=batched_attention_mask).logits # check that right padding and left padding produce the same output + right_pad_v_left_pad_rtol = 1e-5 + right_pad_v_left_pad_atol = 1e-6 if attention_impl == 'torch' else 1e-8 + if rope and pos_emb_config['rope_impl'] == 'dail': + # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + right_pad_v_left_pad_rtol = 1e-2 + right_pad_v_left_pad_atol = 1e-2 assert torch.allclose(right_padding_output[0, :3], left_padding_output[0, 3:], - atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not alibi: + rtol=right_pad_v_left_pad_rtol, + atol=right_pad_v_left_pad_atol) + + if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that right padding and middle padding produce the same output # Note: alibi not implemented for middle padding. + # Note: dail implementation of rope does not support middle padding. 
assert torch.allclose( right_padding_output[0, :3], middle_padding_output[0, [0, 1, 5]], atol=1e-6 if attention_impl == 'torch' else 1e-8) + # check that right padding and right padding in a batch produce the same output assert torch.allclose(right_padding_output[0, :3], batched_output[0, :3], atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not alibi: + + if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that middle padding and middle padding in a batch produce the same output # Note: alibi not implemented for middle padding. + # Note: dail implementation of rope does not support middle padding. assert torch.allclose( middle_padding_output[0], batched_output[1, :], @@ -694,17 +740,47 @@ def test_advanced_mask_building(attention_impl: str): ('flash', 'gpu'), ('triton', 'gpu'), ('torch', 'gpu')]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generate(attention_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_generate(attention_impl: str, device: str, pos_emb_config: dict): # Test that generate works, and produces the same output with or without # padding in the input. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) - if alibi and attention_impl == 'flash': + if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) hf_config = MPTConfig( @@ -718,7 +794,7 @@ def test_generate(attention_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, ) mpt = MPTForCausalLM(hf_config) @@ -886,9 +962,54 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): check_hf_model_equivalence(mpt, mpt2) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache_and_padding(alibi: bool): +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_forward_with_cache_and_padding(attn_impl: str, device: str, + pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run 
with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + composer_device = get_device(device) + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -899,8 +1020,8 @@ def test_forward_with_cache_and_padding(alibi: bool): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', - 'alibi': alibi, + 'attn_impl': attn_impl, + **pos_emb_config, }, use_cache=True, init_config={ @@ -910,47 +1031,74 @@ def test_forward_with_cache_and_padding(alibi: bool): ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() - - first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) - first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() - - # start with passing the first three tokens through (no padding) - first_output_no_padding = mpt( - first_input_ids_no_padding, - attention_mask=first_attention_mask_no_padding) - - second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) - second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() - - # pass through the fourth token by itself, using the key-value cache (no padding) - second_output_no_padding = mpt( - second_input_ids_no_padding[:, -1].unsqueeze(-1), - attention_mask=second_attention_mask_no_padding, - past_key_values=first_output_no_padding.past_key_values) - - first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) - first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() - - # start with passing the first three tokens through (with left padding) - first_output_padding = mpt(first_input_ids_padding, - attention_mask=first_attention_mask_padding) - - second_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11, 11274]]) - second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() - - # pass through the fourth token by itself, using the key-value cache (with left padding) - second_output_padding = mpt( - second_input_ids_padding[:, -1].unsqueeze(-1), - attention_mask=second_attention_mask_padding, - past_key_values=first_output_padding.past_key_values) - - # check that the outputs are the same with or without padding - torch.testing.assert_close(second_output_no_padding.logits, - second_output_padding.logits[:, - -1, :].unsqueeze(1), - atol=1e-6, - rtol=1e-6) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) + first_input_ids_no_padding = composer_device.tensor_to_device( + first_input_ids_no_padding) + first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() + first_attention_mask_no_padding = composer_device.tensor_to_device( + first_attention_mask_no_padding) + + # start with passing the first three tokens through (no padding) + first_output_no_padding = mpt( + first_input_ids_no_padding, + attention_mask=first_attention_mask_no_padding) + + second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) + second_input_ids_no_padding = composer_device.tensor_to_device( + second_input_ids_no_padding) + second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() + second_attention_mask_no_padding = composer_device.tensor_to_device( + second_attention_mask_no_padding) + + 
# pass through the fourth token by itself, using the key-value cache (no padding) + second_output_no_padding = mpt( + second_input_ids_no_padding[:, -1].unsqueeze(-1), + attention_mask=second_attention_mask_no_padding, + past_key_values=first_output_no_padding.past_key_values) + + first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) + first_input_ids_padding = composer_device.tensor_to_device( + first_input_ids_padding) + first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() + first_attention_mask_padding = composer_device.tensor_to_device( + first_attention_mask_padding) + + # start with passing the first three tokens through (with left padding) + first_output_padding = mpt(first_input_ids_padding, + attention_mask=first_attention_mask_padding) + + second_input_ids_padding = torch.tensor( + [[50256, 11274, 16390, 11, 11274]]) + second_input_ids_padding = composer_device.tensor_to_device( + second_input_ids_padding) + second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() + second_attention_mask_padding = composer_device.tensor_to_device( + second_attention_mask_padding) + + # pass through the fourth token by itself, using the key-value cache (with left padding) + second_output_padding = mpt( + second_input_ids_padding[:, -1].unsqueeze(-1), + attention_mask=second_attention_mask_padding, + past_key_values=first_output_padding.past_key_values) + + # check that the outputs are the same with or without padding + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_impl'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + torch.testing.assert_close( + second_output_no_padding.logits, + second_output_padding.logits[:, -1, :].unsqueeze(1), + atol=1e-2, + rtol=1e-6) + else: + torch.testing.assert_close( + second_output_no_padding.logits, + second_output_padding.logits[:, -1, :].unsqueeze(1), + atol=1e-6, + rtol=1e-6) @pytest.mark.parametrize('attn_impl,device', [ @@ -959,17 +1107,47 @@ def test_forward_with_cache_and_padding(alibi: bool): ('triton', 'gpu'), ('torch', 'gpu'), ]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): # Test that model forward with and without the key-value cache produces the # same output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
) - if alibi and attn_impl == 'flash': + if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) hf_config = MPTConfig( @@ -983,10 +1161,8 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, - 'alibi': alibi, + **pos_emb_config, }, - attn_impl=attn_impl, - alibi=alibi, use_cache=True, init_config={ 'name': 'baseline_', @@ -1066,8 +1242,53 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): ) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generate_with_past_kv(alibi: bool): +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_generate_with_past_kv(attn_impl: str, device: str, + pos_emb_config: dict): + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + composer_device = get_device(device) + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1078,8 +1299,8 @@ def test_generate_with_past_kv(alibi: bool): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', - 'alibi': alibi, + 'attn_impl': attn_impl, + **pos_emb_config, }, use_cache=True, init_config={ @@ -1088,33 +1309,46 @@ def test_generate_with_past_kv(alibi: bool): }, ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + no_padding_attention_mask = composer_device.tensor_to_device( + no_padding_attention_mask) - with mock.patch.object(MPTForCausalLM, 'forward', - autospec=True) as forward_mocked: - forward_mocked.return_value = CausalLMOutputWithPast( - logits=torch.randn((1, 3, hf_config.vocab_size)), - past_key_values=[(torch.randn(1, 3, hf_config.d_model), - torch.randn(1, 3, hf_config.d_model)) - for _ in range(hf_config.n_layers)]) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - max_new_tokens=2) - - assert forward_mocked.call_count == 2 - _, _, kwargs = forward_mocked.mock_calls[0] - assert kwargs['past_key_values'] is None - _, _, kwargs = forward_mocked.mock_calls[1] - assert 
kwargs['past_key_values'] is not None - assert len(kwargs['past_key_values']) == hf_config.n_layers - assert kwargs['past_key_values'][0][0].shape == (1, 3, - hf_config.d_model) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + with mock.patch.object(MPTForCausalLM, 'forward', + autospec=True) as forward_mocked: + forward_mocked.return_value = CausalLMOutputWithPast( + logits=torch.randn((1, 3, hf_config.vocab_size)), + past_key_values=[(torch.randn(1, 3, hf_config.d_model), + torch.randn(1, 3, hf_config.d_model)) + for _ in range(hf_config.n_layers)]) + _ = mpt.generate(input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + max_new_tokens=2) + + assert forward_mocked.call_count == 2 + _, _, kwargs = forward_mocked.mock_calls[0] + assert kwargs['past_key_values'] is None + _, _, kwargs = forward_mocked.mock_calls[1] + assert kwargs['past_key_values'] is not None + assert len(kwargs['past_key_values']) == hf_config.n_layers + assert kwargs['past_key_values'][0][0].shape == (1, 3, + hf_config.d_model) +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) @pytest.mark.parametrize('generation_kwargs', [{ 'max_new_tokens': 2, 'num_beams': 4 @@ -1126,9 +1360,49 @@ def test_generate_with_past_kv(alibi: bool): 'do_sample': True, 'top_p': 0.95 }]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], - alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_generation_kwargs_dont_crash(attn_impl: str, device: str, + generation_kwargs: Dict[str, Any], + pos_emb_config: dict): + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
+ ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) + if device == 'gpu': # Switch deteminism off + torch.use_deterministic_algorithms(False) hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1139,35 +1413,73 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', - 'alibi': alibi, + 'attn_impl': attn_impl, + **pos_emb_config, }, use_cache=True, ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() - # no padding in the input - no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) - no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + # no padding in the input + no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) + no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + no_padding_attention_mask = composer_device.tensor_to_device( + no_padding_attention_mask) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - **generation_kwargs) + _ = mpt.generate(input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + **generation_kwargs) + if device == 'gpu': # Switch deteminism back on + reproducibility.configure_deterministic_mode() @pytest.mark.gpu @pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton']) -@pytest.mark.parametrize('alibi', [True, False]) -def test_model_to(attention_impl: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_model_to(attention_impl: str, pos_emb_config: dict): # test that moving the model to diff devices and dtypes in diff ways does not break the model if not torch.cuda.is_available(): pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' 
) - if alibi and attention_impl == 'flash': + if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + pytest.skip(f'dail implementation of rope requires flash attention 2.') + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1179,7 +1491,7 @@ def test_model_to(attention_impl: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, use_cache=True, init_config={ @@ -1204,7 +1516,8 @@ def test_model_to(attention_impl: str, alibi: bool): mpt = mpt.to('cpu') # verify the model still works - if attention_impl == 'torch': + if attention_impl == 'torch' and not ( + pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) @@ -1221,7 +1534,8 @@ def test_model_to(attention_impl: str, alibi: bool): mpt = mpt.float() # verify the model still works - if attention_impl == 'torch': + if attention_impl == 'torch' and not ( + pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) mpt = mpt.half() @@ -1258,21 +1572,50 @@ def test_alibi_vs_hf(): ('triton', 'gpu'), ('torch', 'gpu'), ]) -@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( - attn_impl: str, device: str, alibi: bool, output_attentions: bool, - output_hidden_states: bool): + attn_impl: str, device: str, pos_emb_config: dict, + output_attentions: bool, output_hidden_states: bool): # Test that model forward with output_attentions_and_output_hidden_states if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
) - if alibi and attn_impl == 'flash': + if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1289,10 +1632,8 @@ def test_forward_with_output_attentions_and_output_hidden_states( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, - 'alibi': alibi, + **pos_emb_config, }, - attn_impl=attn_impl, - alibi=alibi, use_cache=True, init_config={ 'name': 'baseline_', diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py new file mode 100644 index 0000000000..598e308546 --- /dev/null +++ b/tests/test_rope_dail_vs_hf.py @@ -0,0 +1,145 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from composer.core.precision import get_precision_context +from omegaconf import OmegaConf as om + +from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + + +@pytest.mark.gpu +@pytest.mark.parametrize('clip_qkv', [True, False]) +@pytest.mark.parametrize('qk_ln', [True, False]) +@pytest.mark.parametrize( + 'attn_type', + ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) +@pytest.mark.parametrize('seq_len', [1, 233, 2048]) +def test_rope_dail_vs_hf(clip_qkv: bool, + qk_ln: bool, + attn_type: str, + seq_len: int, + device: str = 'cuda'): + # compare rope rotations for the dail vs hf implementations + if not is_flash_v2_installed(): + pytest.skip('dail implementation of rope requires flash attention 2.') + + from llmfoundry.models.layers import attention + + cfg = om.create({ + 'attn_impl': 'flash', + 'd_model': 128, + 'n_heads': 4, + 'attn_pdrop': 0, + 'clip_qkv': clip_qkv, + 'qk_ln': qk_ln, + }) + + batch_size = 2 + assert cfg.d_model % cfg.n_heads == 0 + if attn_type == 'grouped_query_attention': + cfg.kv_n_heads = 2 + + attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + + attn1.load_state_dict(attn0.state_dict()) + x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device) + x1 = x0.clone().detach() + x0.requires_grad = True + x1.requires_grad = True + attention_mask = torch.ones(batch_size, seq_len).to(device).bool() + + with get_precision_context('amp_bf16'): + dail_rope_config = { + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + } + } + hf_rope_config = { + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + } + } + + dail_rope = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=dail_rope_config['rope_impl'], + rope_theta=dail_rope_config['rope_theta'], + rope_dail_config=dail_rope_config['rope_dail_config'], + rope_hf_config={}, + max_seq_len=seq_len).to('cuda') + dail_rope_w_meta_info = { + 'impl': 'dail', + 'rotary_emb': dail_rope, + 'offset_info': 0, + 'seq_len': seq_len, + } + + hf_rope = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=hf_rope_config['rope_impl'], + 
rope_theta=hf_rope_config['rope_theta'], + rope_dail_config={}, + rope_hf_config=hf_rope_config['rope_hf_config'], + max_seq_len=seq_len).to('cuda') + pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + hf_rope_w_meta_info = { + 'impl': 'hf', + 'rotary_emb': hf_rope, + 'offset_info': pos, + 'seq_len': seq_len, + } + + y0, _, _ = attn0(x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=dail_rope_w_meta_info, + is_causal=True) + + y1, _, _ = attn1(x1, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=hf_rope_w_meta_info, + is_causal=True) + + y0 *= attention_mask.unsqueeze(-1) + y1 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + loss1 = y1.sum() + + loss0.backward() + loss1.backward() + + torch.testing.assert_close(y0, y1, rtol=1e-2, atol=1e-2) + + torch_name_param_map = {n: p for n, p in attn1.named_parameters()} + for n, p in attn0.named_parameters(): + tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None + torch.testing.assert_close(p, tp, rtol=1e-2, atol=1e-2) + # Relaxed to a l2-norm based check. + assert torch.norm(tp.grad - p.grad) <= 1e-2 + 1e-2 * torch.norm(p.grad) + + assert x0.grad is not None + assert x1.grad is not None + # Relaxed to a l2-norm based check. + assert torch.norm(x0.grad - x1.grad) <= 1e-2 + 1e-2 * torch.norm(x0.grad) From 2b74cb25060c0eb7c7961a0f82be3ebc4afc5e07 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 6 Nov 2023 15:33:46 -0800 Subject: [PATCH 08/15] Add databricks dependency (#717) --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index f528838d35..81178686d2 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,10 @@ 'hf_transfer==0.1.3', ] +extra_deps['databricks'] = [ + 'mosaicml[databricks]', +] + extra_deps['tensorboard'] = [ 'mosaicml[tensorboard]>=0.16.1,<0.17', ] From dd15791818fa53ae792de66d3529d94e0dcb83d9 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 6 Nov 2023 23:19:54 -0800 Subject: [PATCH 09/15] Set persistent_workers = False for packing profiling (#718) --- llmfoundry/data/finetuning/dataloader.py | 7 +++++++ llmfoundry/data/packing.py | 1 + 2 files changed, 8 insertions(+) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 6e988ac149..44d6d345f5 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -400,6 +400,13 @@ def _build_collate_fn( packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, device_batch_size) + if isinstance(packing_ratio, str): + raise ValueError( + 'dataset.packing_ratio must be a float or "auto", but it was set to ' + + f'{packing_ratio}.') + + log.info(f'Using packing ratio {packing_ratio}') + if packing_ratio == 1.0: return collate_fn, device_batch_size elif packing_ratio < 1.0: diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 1ae9efcce5..45322c9b2f 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -348,6 +348,7 @@ def profile_packing( dataloader_cfg.drop_last = False dataloader_cfg.num_workers = 0 dataloader_cfg.prefetch_factor = None + dataloader_cfg.persistent_workers = False # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] 
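[Editor's note on the rope test above] The position-id adjustment in the new `tests/test_rope_dail_vs_hf.py` (`torch.clamp(pos - torch.cumsum(...), min=0)`) is what keeps the HF rotary embedding aligned when inputs are left-padded. A minimal standalone sketch of that adjustment, using illustrative tensors that are not taken from the test:

```python
import torch

# Illustrative batch: the first row is left-padded with two pad tokens,
# the second row has no padding.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
]).bool()

seq_len = attention_mask.shape[1]
pos = torch.arange(seq_len).unsqueeze(0)  # [[0, 1, 2, 3, 4]]

# Subtract the running count of pad tokens so the first real token sits at
# position 0, and clamp so the pad slots themselves never go negative.
adjusted_pos = torch.clamp(
    pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1),
    min=0,
)
print(adjusted_pos)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```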
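The gradient comparisons at the end of `test_rope_dail_vs_hf` are deliberately relaxed from elementwise `torch.testing.assert_close` to an l2-norm criterion of the form `norm(a - b) <= atol + rtol * norm(b)`. A short sketch of that criterion as a reusable helper; the helper name and default tolerances are assumptions for illustration, not part of the patch:

```python
import torch


def assert_close_l2(actual: torch.Tensor,
                    expected: torch.Tensor,
                    rtol: float = 1e-2,
                    atol: float = 1e-2) -> None:
    """Assert two tensors agree in aggregate l2 norm rather than elementwise.

    A handful of outlier elements (e.g. from bf16 rotations) still pass as
    long as the overall difference stays small relative to the reference.
    """
    diff = torch.norm(actual - expected).item()
    bound = atol + rtol * torch.norm(expected).item()
    assert diff <= bound, f'l2 diff {diff:.4e} exceeds bound {bound:.4e}'


# usage sketch
a = torch.randn(64, 128)
b = a + 1e-4 * torch.randn_like(a)
assert_close_l2(a, b)
```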
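The `persistent_workers = False` line added to `profile_packing` above joins a set of overrides applied to a copy of the dataloader config before profiling candidate packing ratios. A minimal sketch of that override pattern, assuming an untyped omegaconf `DictConfig`; the helper name and the example base config are illustrative, not part of the patch:

```python
import copy

from omegaconf import DictConfig, OmegaConf


def make_profiling_dataloader_cfg(dataloader_cfg: DictConfig) -> DictConfig:
    """Return a copy of a dataloader config tuned for short profiling runs."""
    cfg = copy.deepcopy(dataloader_cfg)
    cfg.drop_last = False            # keep every example so waste is measured accurately
    cfg.num_workers = 0              # profiling loaders are tiny; worker startup dominates
    cfg.prefetch_factor = None       # only meaningful when num_workers > 0
    cfg.persistent_workers = False   # do not keep worker processes alive between trials
    return cfg


# usage sketch with an illustrative base config
base_cfg = OmegaConf.create({
    'name': 'finetuning',
    'drop_last': True,
    'num_workers': 8,
    'prefetch_factor': 2,
    'persistent_workers': True,
})
print(OmegaConf.to_yaml(make_profiling_dataloader_cfg(base_cfg)))
```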
From 84c86e3b0a3b63c0c71e52f1f762325daa8adc64 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 7 Nov 2023 15:25:09 -0500 Subject: [PATCH 10/15] raise timeout (#719) --- .github/workflows/pr-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 1151837111..ffbfac4585 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -40,7 +40,7 @@ jobs: if: github.repository_owner == 'mosaicml' with: container: ${{ matrix.container }} - mcloud-timeout: 1200 + mcloud-timeout: 1800 name: ${{ matrix.name }} pytest-command: ${{ matrix.pytest_command }} pytest-markers: ${{ matrix.markers }} From ab9b9385ed4a89749e853b59729982144bbb35f6 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 8 Nov 2023 17:22:50 -0800 Subject: [PATCH 11/15] change default overwrite to True (#724) --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3050529a5a..4f400738e4 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -53,7 +53,7 @@ def __init__( save_interval: Union[str, int, Time], huggingface_folder_name: str = 'ba{batch}', precision: str = 'float32', - overwrite: bool = False, + overwrite: bool = True, mlflow_registered_model_name: Optional[str] = None, mlflow_logging_config: Optional[dict] = None, ): From efaa5454304f43a3d3525a54a6445b656b1cef24 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 9 Nov 2023 07:39:12 -0800 Subject: [PATCH 12/15] Attempt to fix a very occasional hang in datasets map/filter (#725) * dont use lambdas * tokenizer building distributed safety --- llmfoundry/data/finetuning/tasks.py | 16 ++++++++++++---- llmfoundry/utils/builders.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 3673a48217..67a27ac239 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -362,8 +362,12 @@ def dataset_mapper(example: Dict): num_proc=num_cpus_to_use, desc='Tokenizing dataset', ) + + def filter_long_prompts(example: Dict) -> bool: + return len(example['input_ids']) < max_seq_len + prompt_length_filtered_dataset = tokenized_dataset.filter( - lambda example: len(example['input_ids']) < max_seq_len, + filter_long_prompts, num_proc=num_cpus_to_use, desc='Filtering out long prompts', ) @@ -376,10 +380,14 @@ def dataset_mapper(example: Dict): ) pad_token_id = tokenizer.pad_token_id + + def filter_empty_examples(example: Dict) -> bool: + return len(example['input_ids']) > 0 and len( + example['labels']) > 0 and any( + token_id != pad_token_id for token_id in example['labels']) + empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter( - lambda example: len(example['input_ids']) > 0 and len(example[ - 'labels']) > 0 and any(token_id != pad_token_id - for token_id in example['labels']), + filter_empty_examples, num_proc=num_cpus_to_use, desc='Filtering out empty examples') diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index f027afb0ce..2251ab5fbd 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -188,6 +188,12 @@ def build_tokenizer( os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' 
os.environ['TOKENIZERS_PARALLELISM'] = 'false' + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup' + + # Make sure the tokenizer files are downloaded and cached first by local rank 0 + with dist.local_rank_zero_download_and_wait(signal_file_path): + pass + if tokenizer_name.startswith('tiktoken'): tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) else: @@ -202,6 +208,15 @@ def build_tokenizer( int(1e30), ) + if dist.get_local_rank() == 0: + with open(signal_file_path, 'wb') as f: + f.write(b'local_rank0_completed_tokenizer_setup') + + dist.barrier() + + if dist.get_local_rank() == 0: + os.remove(signal_file_path) + return tokenizer From d2ddb834650b085337c4d914f77bb80c76201e9d Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:25:13 -0800 Subject: [PATCH 13/15] Add Unity Catalog support to HF checkpointer (#721) --- llmfoundry/callbacks/hf_checkpointer.py | 31 +++++++++++-------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4f400738e4..e02bf03693 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -14,9 +14,10 @@ from composer.core import Callback, Event, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger -from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader from composer.models import HuggingFaceModel -from composer.utils import dist, format_name_with_dist_and_time, parse_uri +from composer.utils import (dist, format_name_with_dist_and_time, + maybe_create_remote_uploader_downloader_from_uri, + parse_uri) from composer.utils.misc import create_interval_scheduler from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -57,8 +58,7 @@ def __init__( mlflow_registered_model_name: Optional[str] = None, mlflow_logging_config: Optional[dict] = None, ): - self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( - save_folder) + _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite self.precision = precision self.dtype = { @@ -93,13 +93,11 @@ def __init__( self.save_interval = save_interval self.check_interval = create_interval_scheduler( save_interval, include_end_of_training=True) - self.upload_to_object_store = (self.backend != '') - if self.upload_to_object_store: - self.remote_ud = RemoteUploaderDownloader( - bucket_uri=f'{self.backend}://{self.bucket_name}', - num_concurrent_uploads=4) - else: - self.remote_ud = None + + self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( + save_folder, loggers=[]) + if self.remote_ud is not None: + self.remote_ud._num_concurrent_uploads = 4 self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] @@ -115,7 +113,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: raise ValueError( f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. 
' + f'Got {type(state.model)} instead.') - if self.upload_to_object_store and self.remote_ud is not None: + if self.remote_ud is not None: self.remote_ud.init(state, logger) state.callbacks.append(self.remote_ud) @@ -169,7 +167,7 @@ def _save_checkpoint(self, state: State, logger: Logger): self.huggingface_folder_name_fstr), state.run_name, state.timestamp) dir_context_mgr = tempfile.TemporaryDirectory( - ) if self.upload_to_object_store else contextlib.nullcontext( + ) if self.remote_ud is not None else contextlib.nullcontext( enter_result=save_dir) with dir_context_mgr as temp_save_dir: @@ -233,11 +231,8 @@ def _save_checkpoint(self, state: State, logger: Logger): log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility(temp_save_dir) - if self.upload_to_object_store: - assert self.remote_ud is not None - log.info( - f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}' - ) + if self.remote_ud is not None: + log.info(f'Uploading HuggingFace formatted checkpoint') for filename in os.listdir(temp_save_dir): self.remote_ud.upload_file( state=state, From 2f91a64a348b0f745ab83f66acdad2a07082cc14 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:40:26 -0800 Subject: [PATCH 14/15] Combine filters into one, to avoid datasets error (#729) --- llmfoundry/data/finetuning/tasks.py | 46 +++++++++++------------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 67a27ac239..6ba6ad96c8 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -363,43 +363,31 @@ def dataset_mapper(example: Dict): desc='Tokenizing dataset', ) - def filter_long_prompts(example: Dict) -> bool: - return len(example['input_ids']) < max_seq_len + pad_token_id = tokenizer.pad_token_id - prompt_length_filtered_dataset = tokenized_dataset.filter( - filter_long_prompts, + def filter_long_or_empty_examples(example: Dict) -> bool: + less_than_max_seq_len = len(example['input_ids']) < max_seq_len + non_empty_input = len(example['input_ids']) > 0 + non_empty_labels = len(example['labels']) > 0 + non_padding_response = any( + token_id != pad_token_id for token_id in example['labels']) + return (less_than_max_seq_len and non_empty_input and + non_empty_labels and non_padding_response) + + filtered_dataset = tokenized_dataset.filter( + filter_long_or_empty_examples, num_proc=num_cpus_to_use, desc='Filtering out long prompts', ) - examples_removed = len(tokenized_dataset) - len( - prompt_length_filtered_dataset) + examples_removed = len(tokenized_dataset) - len(filtered_dataset) if examples_removed > 0: warnings.warn( - f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.' + f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + + + 'the prompt or response was empty, or the response was all padding tokens.' 
) - pad_token_id = tokenizer.pad_token_id - - def filter_empty_examples(example: Dict) -> bool: - return len(example['input_ids']) > 0 and len( - example['labels']) > 0 and any( - token_id != pad_token_id for token_id in example['labels']) - - empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter( - filter_empty_examples, - num_proc=num_cpus_to_use, - desc='Filtering out empty examples') - - log.debug('Done tokenizing and filtering examples.') - - empty_examples_removed = len(prompt_length_filtered_dataset) - len( - empty_examples_dropped_dataset) - if empty_examples_removed > 0: - warnings.warn( - f'Dropped {empty_examples_removed} examples where the prompt or response was empty, ' - + 'or the response was only padding tokens.') - # Now local rank 0 indicates to the other ranks that it is done if dist.get_local_rank() == 0: log.debug('Local rank 0 finished data prep') @@ -414,7 +402,7 @@ def filter_empty_examples(example: Dict) -> bool: os.remove(signal_file_path) log.debug('All ranks finished data prep') - return empty_examples_dropped_dataset + return filtered_dataset def build_from_streaming(self, *args: Any, **kwargs: Any) -> StreamingFinetuningDataset: From 7c4d24a8bc3713d07de889df5c2f21211ae4945e Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 9 Nov 2023 17:16:39 -0800 Subject: [PATCH 15/15] Fix logging verbosity in HF model download script and repair symlinks (#727) * Make logs appear and disable InsecureRequestWarning for ignore_cert * Clean up * Repair symlinks after cache download * Clean up logging --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/utils/model_download_utils.py | 27 +++++++++++++++--------- scripts/misc/download_hf_model.py | 20 ++++++++++++++++-- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index d268cb78b7..2104455e0f 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -6,6 +6,7 @@ import logging import os import time +import warnings from http import HTTPStatus from typing import Optional from urllib.parse import urljoin @@ -14,6 +15,7 @@ import requests import tenacity from bs4 import BeautifulSoup +from requests.packages.urllib3.exceptions import InsecureRequestWarning from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME @@ -212,16 +214,21 @@ def download_from_cache_server( download_start = time.time() - # Only downloads the blobs in order to avoid downloading model files twice due to the - # symlnks in the Hugging Face cache structure: - _recursive_download( - session, - cache_base_url, - # Trailing slash to indicate directory - f'{formatted_model_name}/blobs/', - save_dir, - ignore_cert=ignore_cert, - ) + # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True + with warnings.catch_warnings(): + if ignore_cert: + warnings.simplefilter('ignore', category=InsecureRequestWarning) + + # Only downloads the blobs in order to avoid downloading model files twice due to the + # symlnks in the Hugging Face cache structure: + _recursive_download( + session, + cache_base_url, + # Trailing slash to indicate directory + f'{formatted_model_name}/blobs/', + save_dir, + ignore_cert=ignore_cert, + ) download_duration = time.time() - 
download_start log.info( f'Downloaded model {model_name} from cache server in {download_duration} seconds' diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py index 6465a552c2..58c3445e7d 100644 --- a/scripts/misc/download_hf_model.py +++ b/scripts/misc/download_hf_model.py @@ -14,6 +14,8 @@ HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' +logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', + level=logging.INFO) log = logging.getLogger(__name__) if __name__ == '__main__': @@ -34,7 +36,7 @@ argparser.add_argument( '--fallback', action='store_true', - default=False, + default=True, help= 'Whether to fallback to downloading from Hugging Face if download from cache fails', ) @@ -53,11 +55,25 @@ token=args.token, ignore_cert=args.ignore_cert, ) + + # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. + # This shouldn't actually download any files if the cache server download was successful, but should address + # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. + log.info('Repairing Hugging Face cache symlinks') + + # Hide some noisy logs that aren't important for just the symlink repair. + old_level = logging.getLogger().level + logging.getLogger().setLevel(logging.ERROR) + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token) + logging.getLogger().setLevel(old_level) + except PermissionError: log.error(f'Not authorized to download {args.model}.') except Exception as e: if args.fallback: - log.warn( + log.warning( f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' ) download_from_hf_hub(args.model,