diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index d9450bc657..771033a703 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -1,5 +1,6 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple, Union
@@ -17,6 +18,8 @@
     validate_target_settings,
 )
 from llmfoundry.data.finetuning.tasks import (
+    DEFAULT_TARGET_PROMPTS,
+    DEFAULT_TARGET_RESPONSES,
     DOWNLOADED_FT_DATASETS_DIRPATH,
     SUPPORTED_EXTENSIONS,
     dataset_constructor,
@@ -39,9 +42,15 @@
 # HuggingFace hardcodes the ignore index to -100
 _HF_IGNORE_INDEX = -100
 
-# Default settings to use for target responses and target prompts
-_DEFAULT_TARGET_RESPONSES = 'last'
-_DEFAULT_TARGET_PROMPTS = 'none'
+# Extra keys present in the dataset config dictionary beyond the constructor keys
+_ALLOWED_DATASET_KEYS = {
+    'shuffle',
+    'packing_ratio',
+    'allow_pad_trimming',
+    'seq_parallel_replication',
+    'auto_packing_replication',
+    'max_leftover_bins_to_keep',
+}
 
 
 def build_finetuning_dataloader(
@@ -171,7 +180,26 @@ def build_finetuning_dataloader(
             given a starting workload YAML.
     """
     dataset_cfg = dataset
-    _validate_config(**dataset_cfg)
+    is_streaming = (
+        dataset_cfg.get('remote') is not None or
+        dataset_cfg.get('streams') is not None
+    )
+    if is_streaming:
+        dataset_constructor_keys = inspect.signature(
+            dataset_constructor.streaming_dataset_class,
+        ).parameters.keys()
+    else:
+        dataset_constructor_keys = inspect.signature(
+            dataset_constructor.build_from_hf,
+        ).parameters.keys()
+
+    allowed_dataset_config_keys = set(
+        dataset_constructor_keys,
+    ).union(_ALLOWED_DATASET_KEYS)
+    _validate_config(
+        **dataset_cfg,
+        allowed_dataset_keys=allowed_dataset_config_keys,
+    )
 
     # Use EOS as the pad token if none exists
     if tokenizer.pad_token is None:  # type: ignore (sometimes it's none and that's ok)
@@ -213,9 +241,7 @@ def build_finetuning_dataloader(
 
     streaming_dataset = None  # for pyright
     sampler = None
-    if dataset_cfg.get(
-        'remote',
-    ) is not None or dataset_cfg.get('streams') is not None:
+    if is_streaming:
         # Build streaming dataloader
         streams_cfg = dataset_cfg.get('streams', None)
         streams_cfg = to_dict_container(
@@ -225,34 +251,20 @@ def build_finetuning_dataloader(
             streams_cfg,
         ) if streams_cfg is not None else None
 
-        # note: we don't need to use ** here because we're setting default values for almost all arguments
+        # Take the constructor args from above, minus args that have been created separately
+        dataset_constructor_args = {
+            k: v
+            for k, v in dataset_cfg.items()
+            if k in dataset_constructor_keys and
+            k not in {'streams', 'packing_ratio'}
+        }
         streaming_dataset = dataset_constructor.build_from_streaming(
             tokenizer=tokenizer,
             streams=streams,
-            local=dataset_cfg.get('local', None),
-            remote=dataset_cfg.get('remote', None),
-            split=dataset_cfg.get('split', None),
-            download_retry=dataset_cfg.get('download_retry', 2),
-            download_timeout=dataset_cfg.get('download_timeout', 60),
-            validate_hash=dataset_cfg.get('validate_hash', None),
-            keep_zip=dataset_cfg.get('keep_zip', False),
-            epoch_size=dataset_cfg.get('epoch_size', None),
-            predownload=dataset_cfg.get('predownload', None),
-            cache_limit=dataset_cfg.get('cache_limit', None),
-            partition_algo=dataset_cfg.get('partition_algo', 'relaxed'),
-            num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None),
             batch_size=dataloader_batch_size,
-            shuffle=dataset_cfg.get('shuffle', False),
-            shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'),
-            shuffle_seed=dataset_cfg.get('shuffle_seed', 9176),
-            shuffle_block_size=dataset_cfg.get('shuffle_block_size', None),
-            sampling_method=dataset_cfg.get('sampling_method', 'balanced'),
-            sampling_granularity=dataset_cfg.get('sampling_granularity', 1),
-            batching_method=dataset_cfg.get('batching_method', 'random'),
-            max_seq_len=dataset_cfg['max_seq_len'],
-            allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False),
             replication=replication_factor,
             packing_ratio=dataloader_batch_size / dataset_batch_size,
+            **dataset_constructor_args,
         )
 
     else:
@@ -283,24 +295,19 @@ def build_finetuning_dataloader(
             dataset_name_or_path,
         )
 
-        # Build dataset from HF.
+        # Take the constructor args from above, minus args that have been created separately
+        dataset_constructor_args = {
+            k: v
+            for k, v in dataset_cfg.items()
+            if k in dataset_constructor_keys and
+            k not in {'split', 'preprocessing_fn'}
+        }
         streaming_dataset = dataset_constructor.build_from_hf(
             dataset_name=dataset_name_or_path,
             split=split,
-            safe_load=dataset_cfg.get('safe_load', False),
-            max_seq_len=dataset_cfg['max_seq_len'],
             preprocessing_fn=preprocessing_fn,
             tokenizer=tokenizer,
-            target_prompts=dataset_cfg.get(
-                'target_prompts',
-                _DEFAULT_TARGET_PROMPTS,
-            ),
-            target_responses=dataset_cfg.get(
-                'target_responses',
-                _DEFAULT_TARGET_RESPONSES,
-            ),
-            decoder_only_format=dataset_cfg['decoder_only_format'],
-            hf_kwargs=dataset_cfg.get('hf_kwargs', {}),
+            **dataset_constructor_args,
         )
 
     # Ensure dataset is large enough.
@@ -367,6 +374,7 @@ def _validate_config(
     streams: Optional[Dict[str, Any]] = None,
     target_prompts: Optional[str] = None,
    target_responses: Optional[str] = None,
+    allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS,
     **kwargs: Dict[str, Any],
 ) -> None:
     """Validates the dataset configuration.
@@ -417,6 +425,7 @@ def _validate_config(
             Defaults to "last", meaning only the final response in multi-turn examples
             will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring
             for details.
+        allowed_dataset_keys (set[str], optional): The set of allowed keys for the dataset config.
         kwargs (DictConfig, optional): Additional kwargs to pass
             to `datasets.load_dataset`, which can be used to load
             a dataset from local files.
@@ -424,41 +433,10 @@ def _validate_config(
     Raises:
         ValueError: If the dataset configuration does not meet the requirements.
     """
-    # Check for extraneous keys in the dataset config
-    allowed_additional_kwargs = {
-        'local',
-        'remote',
-        'split',
-        'download_retry',
-        'download_timeout',
-        'validate_hash',
-        'keep_zip',
-        'epoch_size',
-        'predownload',
-        'cache_limit',
-        'partition_algo',
-        'num_canonical_nodes',
-        'batch_size',
-        'shuffle',
-        'shuffle_algo',
-        'shuffle_seed',
-        'shuffle_block_size',
-        'sampling_method',
-        'sampling_granularity',
-        'batching_method',
-        'max_seq_len',
-        'allow_unsafe_types',
-        'replication',
-        'packing_ratio',
-        'allow_pad_trimming',
-        'seq_parallel_replication',
-        'auto_packing_replication',
-        'max_leftover_bins_to_keep',
-    }
-    if not set(kwargs.keys()).issubset(allowed_additional_kwargs):
+    if not set(kwargs.keys()).issubset(allowed_dataset_keys):
         raise ValueError(
             'The dataset config contains the following extraneous keys: ' +\
-            ', '.join(set(kwargs.keys()) - allowed_additional_kwargs),
+            ', '.join(set(kwargs.keys()) - allowed_dataset_keys),
         )
 
     if hf_name is not None:
@@ -542,9 +520,9 @@ def _validate_config(
     # Raise an error if the target_prompts + target_responses + decoder_only_format settings
     # are invalid
     if target_prompts is None:
-        target_prompts = _DEFAULT_TARGET_PROMPTS
+        target_prompts = DEFAULT_TARGET_PROMPTS
     if target_responses is None:
-        target_responses = _DEFAULT_TARGET_RESPONSES
+        target_responses = DEFAULT_TARGET_RESPONSES
     target_prompts, target_responses = target_prompts.lower(
     ), target_responses.lower()
     validate_target_settings(
@@ -646,9 +624,9 @@ def build_collate_fn(
     dataset_cfg = dataloader_cfg['dataset']
     target_responses = dataset_cfg.get(
         'target_responses',
-        _DEFAULT_TARGET_RESPONSES,
+        DEFAULT_TARGET_RESPONSES,
     )
-    target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)
+    target_prompts = dataset_cfg.get('target_prompts', DEFAULT_TARGET_PROMPTS)
     max_seq_len = dataset_cfg['max_seq_len']
     decoder_only_format = dataset_cfg['decoder_only_format']
     allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False)
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 397b619e73..dd9b495ce4 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -47,6 +47,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     Optional,
     Sequence,
     Tuple,
+    Type,
     Union,
     cast,
 )
@@ -115,6 +116,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 )
 SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']
 HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata']
+DEFAULT_TARGET_RESPONSES = 'last'
+DEFAULT_TARGET_PROMPTS = 'none'
 
 PromptResponseDict = Mapping[str, str]
 ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
@@ -805,14 +808,14 @@ def build_from_hf(
         self,
         dataset_name: str,
         split: str,
-        safe_load: bool,
-        max_seq_len: int,
-        preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
-        tokenizer: PreTrainedTokenizerBase,
-        target_prompts: str,
-        target_responses: str,
-        decoder_only_format: bool,
-        hf_kwargs: Dict[str, Any],
+        safe_load: bool = False,
+        max_seq_len: int = 2048,
+        preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        target_prompts: str = DEFAULT_TARGET_PROMPTS,
+        target_responses: str = DEFAULT_TARGET_RESPONSES,
+        decoder_only_format: bool = True,
+        hf_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
                hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
         """Load a HuggingFace Datasets, preprocess, and tokenize.
@@ -851,6 +854,14 @@ def build_from_hf(
         Returns:
             Dataset: The tokenized dataset.
         """
+        if hf_kwargs is None:
+            hf_kwargs = {}
+
+        # None is checked in the function, because argument defaults were added after the function was written and we want
+        # to preserve the ordering of the arguments for backwards compatibility.
+        if tokenizer is None:
+            raise ValueError('A tokenizer must be provided.')
+
         signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
 
         # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
@@ -999,12 +1010,16 @@ def dataset_mapper(example: Dict):
         assert filtered_dataset is not None
         return filtered_dataset
 
+    @property
+    def streaming_dataset_class(self) -> Type[StreamingFinetuningDataset]:
+        return StreamingFinetuningDataset
+
     def build_from_streaming(
         self,
         *args: Any,
         **kwargs: Any,
     ) -> StreamingFinetuningDataset:
-        return StreamingFinetuningDataset(*args, **kwargs)
+        return self.streaming_dataset_class(*args, **kwargs)
 
 
 dataset_constructor = DatasetConstructor()
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index 8e92658194..1a43e12536 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -7,7 +7,7 @@
 import shutil
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
-from typing import ContextManager, Literal, Optional, Union
+from typing import Any, Callable, ContextManager, Dict, Literal, Optional, Union
 from unittest.mock import MagicMock, patch
 
 import catalogue
@@ -1220,6 +1220,21 @@ def test_token_counting_func_dataloader_setting(
         'timeout': 0,
     }
 
+    def build_from_hf(
+        self,  # type: ignore
+        dataset_name: str,
+        split: str,
+        safe_load: bool = False,
+        max_seq_len: int = 2048,
+        preprocessing_fn: Optional[Callable] = None,
+        tokenizer: transformers.PreTrainedTokenizerBase = None,
+        target_prompts: str = 'last',
+        target_responses: str = 'none',
+        decoder_only_format: bool = True,
+        hf_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        return []
+
     if dataloader_type == 'finetuning-hf':
         cfg = DictConfig({
             'dataset': {
@@ -1235,8 +1250,7 @@ def test_token_counting_func_dataloader_setting(
         })
         monkeypatch.setattr(
             'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf',
-            lambda *args,
-            **kwargs: [],
+            build_from_hf,
         )
         dl = build_finetuning_dataloader(
             tokenizer=gptt,
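
Note on the refactor, as a minimal self-contained sketch (the class and names below are hypothetical stand-ins, not llmfoundry objects): the diff replaces the hand-maintained allow-list of dataset config keys with a set derived from the target constructor's signature via `inspect.signature`, unions it with a small set of dataloader-only keys, and then forwards only the keys the constructor actually accepts. The pattern looks roughly like this:

import inspect
from typing import Any, Dict


class ExampleStreamingDataset:
    """Hypothetical stand-in for a dataset class whose __init__ defines the accepted keys."""

    def __init__(self, local: str, remote: str = '', shuffle_seed: int = 9176):
        self.local, self.remote, self.shuffle_seed = local, remote, shuffle_seed


# Keys that are legal in the config but consumed by the dataloader itself,
# mirroring the role of _ALLOWED_DATASET_KEYS in the diff above.
EXTRA_KEYS = {'shuffle', 'packing_ratio'}


def build_dataset(dataset_cfg: Dict[str, Any]) -> ExampleStreamingDataset:
    # Derive accepted keys from the constructor signature instead of a hand-maintained list.
    constructor_keys = set(inspect.signature(ExampleStreamingDataset).parameters)
    allowed_keys = constructor_keys | EXTRA_KEYS

    extraneous = set(dataset_cfg) - allowed_keys
    if extraneous:
        raise ValueError(
            'The dataset config contains the following extraneous keys: ' +
            ', '.join(sorted(extraneous)),
        )

    # Forward only the keys the constructor actually accepts.
    constructor_args = {k: v for k, v in dataset_cfg.items() if k in constructor_keys}
    return ExampleStreamingDataset(**constructor_args)


ds = build_dataset({'local': '/tmp/data', 'shuffle': True, 'shuffle_seed': 7})

The upshot is that adding a parameter to the dataset constructor automatically makes the matching config key valid, without touching the validation code.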
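
A related note on the test change (an inference from the diff, not something it states): since the dataloader now introspects `dataset_constructor.build_from_hf` to derive the accepted keys, a monkeypatched replacement must expose the real parameter names. A bare `lambda *args, **kwargs: []` exposes only `args` and `kwargs`, so config keys that are legitimized solely by the real signature would be reported as extraneous; hence the full-signature `build_from_hf` stub. The difference is easy to see (the stub names here are illustrative only):

import inspect

stub_lambda = lambda *args, **kwargs: []

def stub_full(dataset_name: str, split: str, safe_load: bool = False):
    return []

print(list(inspect.signature(stub_lambda).parameters))  # ['args', 'kwargs']
print(list(inspect.signature(stub_full).parameters))    # ['dataset_name', 'split', 'safe_load']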