From a42e72371fdd2e98ce97d6b98252edba7dbc2467 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:10:56 -0700 Subject: [PATCH 01/14] try it --- llmfoundry/data/finetuning/dataloader.py | 128 +++++++++-------------- llmfoundry/data/finetuning/tasks.py | 31 ++++-- 2 files changed, 72 insertions(+), 87 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index d9450bc657..037a60f1bd 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -1,5 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import inspect import logging import os from typing import Any, Dict, Optional, Tuple, Union @@ -17,6 +18,8 @@ validate_target_settings, ) from llmfoundry.data.finetuning.tasks import ( + DEFAULT_TARGET_PROMPTS, + DEFAULT_TARGET_RESPONSES, DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, dataset_constructor, @@ -39,10 +42,6 @@ # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 -# Default settings to use for target responses and target prompts -_DEFAULT_TARGET_RESPONSES = 'last' -_DEFAULT_TARGET_PROMPTS = 'none' - def build_finetuning_dataloader( tokenizer: PreTrainedTokenizerBase, @@ -171,7 +170,25 @@ def build_finetuning_dataloader( given a starting workload YAML. """ dataset_cfg = dataset - _validate_config(**dataset_cfg) + is_streaming = dataset_cfg.get('remote') is not None or dataset_cfg.get( + 'streams', + ) is not None + if is_streaming: + dataset_constructor_keys = inspect.signature( + dataset_constructor.streaming_dataset_class, + ).parameters.keys() + else: + dataset_constructor_keys = inspect.signature( + dataset_constructor.build_from_hf, + ).parameters.keys() + + allowed_dataset_config_keys = set( + dataset_constructor_keys, + ).union(_ALLOWED_DATASET_KEYS) + _validate_config( + **dataset_cfg, + allowed_dataset_keys=allowed_dataset_config_keys, + ) # Use EOS as the pad token if none exists if tokenizer.pad_token is None: # type: ignore (sometimes it's none and that's ok) @@ -213,9 +230,7 @@ def build_finetuning_dataloader( streaming_dataset = None # for pyright sampler = None - if dataset_cfg.get( - 'remote', - ) is not None or dataset_cfg.get('streams') is not None: + if is_streaming: # Build streaming dataloader streams_cfg = dataset_cfg.get('streams', None) streams_cfg = to_dict_container( @@ -229,30 +244,14 @@ def build_finetuning_dataloader( streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, - local=dataset_cfg.get('local', None), - remote=dataset_cfg.get('remote', None), - split=dataset_cfg.get('split', None), - download_retry=dataset_cfg.get('download_retry', 2), - download_timeout=dataset_cfg.get('download_timeout', 60), - validate_hash=dataset_cfg.get('validate_hash', None), - keep_zip=dataset_cfg.get('keep_zip', False), - epoch_size=dataset_cfg.get('epoch_size', None), - predownload=dataset_cfg.get('predownload', None), - cache_limit=dataset_cfg.get('cache_limit', None), - partition_algo=dataset_cfg.get('partition_algo', 'relaxed'), - num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None), batch_size=dataloader_batch_size, - shuffle=dataset_cfg.get('shuffle', False), - shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'), - shuffle_seed=dataset_cfg.get('shuffle_seed', 9176), - shuffle_block_size=dataset_cfg.get('shuffle_block_size', None), - sampling_method=dataset_cfg.get('sampling_method', 'balanced'), - 
sampling_granularity=dataset_cfg.get('sampling_granularity', 1), - batching_method=dataset_cfg.get('batching_method', 'random'), - max_seq_len=dataset_cfg['max_seq_len'], - allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False), replication=replication_factor, packing_ratio=dataloader_batch_size / dataset_batch_size, + **{ + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys + }, ) else: @@ -287,20 +286,13 @@ def build_finetuning_dataloader( streaming_dataset = dataset_constructor.build_from_hf( dataset_name=dataset_name_or_path, split=split, - safe_load=dataset_cfg.get('safe_load', False), - max_seq_len=dataset_cfg['max_seq_len'], preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - target_prompts=dataset_cfg.get( - 'target_prompts', - _DEFAULT_TARGET_PROMPTS, - ), - target_responses=dataset_cfg.get( - 'target_responses', - _DEFAULT_TARGET_RESPONSES, - ), - decoder_only_format=dataset_cfg['decoder_only_format'], - hf_kwargs=dataset_cfg.get('hf_kwargs', {}), + **{ + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys + }, ) # Ensure dataset is large enough. @@ -355,6 +347,15 @@ def build_finetuning_dataloader( ) +_ALLOWED_DATASET_KEYS = { + 'packing_ratio', + 'allow_pad_trimming', + 'seq_parallel_replication', + 'auto_packing_replication', + 'max_leftover_bins_to_keep', +} + + def _validate_config( max_seq_len: int, decoder_only_format: bool = False, @@ -367,6 +368,7 @@ def _validate_config( streams: Optional[Dict[str, Any]] = None, target_prompts: Optional[str] = None, target_responses: Optional[str] = None, + allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS, **kwargs: Dict[str, Any], ) -> None: """Validates the dataset configuration. @@ -417,6 +419,7 @@ def _validate_config( Defaults to "last", meaning only the final response in multi-turn examples will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for details. + allowed_dataset_keys (set[str], optional): The set of allowed keys for the dataset config. kwargs (DictConfig, optional): Additional kwargs to pass to `datasets.load_dataset`, which can be used to load a dataset from local files. @@ -424,41 +427,10 @@ def _validate_config( Raises: ValueError: If the dataset configuration does not meet the requirements. 
""" - # Check for extraneous keys in the dataset config - allowed_additional_kwargs = { - 'local', - 'remote', - 'split', - 'download_retry', - 'download_timeout', - 'validate_hash', - 'keep_zip', - 'epoch_size', - 'predownload', - 'cache_limit', - 'partition_algo', - 'num_canonical_nodes', - 'batch_size', - 'shuffle', - 'shuffle_algo', - 'shuffle_seed', - 'shuffle_block_size', - 'sampling_method', - 'sampling_granularity', - 'batching_method', - 'max_seq_len', - 'allow_unsafe_types', - 'replication', - 'packing_ratio', - 'allow_pad_trimming', - 'seq_parallel_replication', - 'auto_packing_replication', - 'max_leftover_bins_to_keep', - } - if not set(kwargs.keys()).issubset(allowed_additional_kwargs): + if not set(kwargs.keys()).issubset(allowed_dataset_keys): raise ValueError( 'The dataset config contains the following extraneous keys: ' +\ - ', '.join(set(kwargs.keys()) - allowed_additional_kwargs), + ', '.join(set(kwargs.keys()) - allowed_dataset_keys), ) if hf_name is not None: @@ -542,9 +514,9 @@ def _validate_config( # Raise an error if the target_prompts + target_responses + decoder_only_format settings # are invalid if target_prompts is None: - target_prompts = _DEFAULT_TARGET_PROMPTS + target_prompts = DEFAULT_TARGET_PROMPTS if target_responses is None: - target_responses = _DEFAULT_TARGET_RESPONSES + target_responses = DEFAULT_TARGET_RESPONSES target_prompts, target_responses = target_prompts.lower( ), target_responses.lower() validate_target_settings( @@ -646,9 +618,9 @@ def build_collate_fn( dataset_cfg = dataloader_cfg['dataset'] target_responses = dataset_cfg.get( 'target_responses', - _DEFAULT_TARGET_RESPONSES, + DEFAULT_TARGET_RESPONSES, ) - target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS) + target_prompts = dataset_cfg.get('target_prompts', DEFAULT_TARGET_PROMPTS) max_seq_len = dataset_cfg['max_seq_len'] decoder_only_format = dataset_cfg['decoder_only_format'] allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 397b619e73..74dad48939 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -47,6 +47,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -115,6 +116,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ) SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet'] HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata'] +DEFAULT_TARGET_RESPONSES = 'all' +DEFAULT_TARGET_PROMPTS = 'none' PromptResponseDict = Mapping[str, str] ChatFormattedDict = Mapping[str, List[Dict[str, str]]] @@ -805,14 +808,14 @@ def build_from_hf( self, dataset_name: str, split: str, - safe_load: bool, - max_seq_len: int, - preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]], - tokenizer: PreTrainedTokenizerBase, - target_prompts: str, - target_responses: str, - decoder_only_format: bool, - hf_kwargs: Dict[str, Any], + safe_load: bool = False, + max_seq_len: int = 2048, + preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + target_prompts: str = DEFAULT_TARGET_PROMPTS, + target_responses: str = DEFAULT_TARGET_RESPONSES, + decoder_only_format: bool = True, + hf_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: """Load a HuggingFace Datasets, preprocess, and 
tokenize. @@ -851,6 +854,12 @@ def build_from_hf( Returns: Dataset: The tokenized dataset. """ + if hf_kwargs is None: + hf_kwargs = {} + + if tokenizer is None: + raise ValueError('A tokenizer must be provided.') + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. @@ -999,12 +1008,16 @@ def dataset_mapper(example: Dict): assert filtered_dataset is not None return filtered_dataset + @property + def streaming_dataset_class(self) -> Type[StreamingFinetuningDataset]: + return StreamingFinetuningDataset + def build_from_streaming( self, *args: Any, **kwargs: Any, ) -> StreamingFinetuningDataset: - return StreamingFinetuningDataset(*args, **kwargs) + return self.streaming_dataset_class(*args, **kwargs) dataset_constructor = DatasetConstructor() From d0517b7762aafa718959f378e66bb2a4ee0f8c64 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:14:28 -0700 Subject: [PATCH 02/14] put default back --- llmfoundry/data/finetuning/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 74dad48939..e4887c28a5 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -116,7 +116,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ) SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet'] HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata'] -DEFAULT_TARGET_RESPONSES = 'all' +DEFAULT_TARGET_RESPONSES = 'last' DEFAULT_TARGET_PROMPTS = 'none' PromptResponseDict = Mapping[str, str] From b7b264e520527a28c52be598eecfaf4c1122077e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:16:44 -0700 Subject: [PATCH 03/14] add comment --- llmfoundry/data/finetuning/tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e4887c28a5..dd9b495ce4 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -857,6 +857,8 @@ def build_from_hf( if hf_kwargs is None: hf_kwargs = {} + # None is checked in the function, because argument defaults were added after the function was written and we want + # to preserve the ordering of the arguments for backwards compatibility. 
if tokenizer is None: raise ValueError('A tokenizer must be provided.') From d92ba17a2464084aa97f96d0da4a12591b53a2ac Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:19:05 -0700 Subject: [PATCH 04/14] move const to top --- llmfoundry/data/finetuning/dataloader.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 037a60f1bd..abfacd4f33 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -42,6 +42,15 @@ # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 +# Extra keys present in the dataset config dictionary beyond the constructor keys +_ALLOWED_DATASET_KEYS = { + 'packing_ratio', + 'allow_pad_trimming', + 'seq_parallel_replication', + 'auto_packing_replication', + 'max_leftover_bins_to_keep', +} + def build_finetuning_dataloader( tokenizer: PreTrainedTokenizerBase, @@ -347,15 +356,6 @@ def build_finetuning_dataloader( ) -_ALLOWED_DATASET_KEYS = { - 'packing_ratio', - 'allow_pad_trimming', - 'seq_parallel_replication', - 'auto_packing_replication', - 'max_leftover_bins_to_keep', -} - - def _validate_config( max_seq_len: int, decoder_only_format: bool = False, From e8c4937c44907f66e57fb289c5e9e494b4e3894c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:27:11 -0700 Subject: [PATCH 05/14] put shuffle back --- llmfoundry/data/finetuning/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index abfacd4f33..0f8fb4d050 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -44,6 +44,7 @@ # Extra keys present in the dataset config dictionary beyond the constructor keys _ALLOWED_DATASET_KEYS = { + 'shuffle', 'packing_ratio', 'allow_pad_trimming', 'seq_parallel_replication', From 319a7f167e592dae195a05d839c1f8e94e3e7b2f Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:36:17 -0700 Subject: [PATCH 06/14] special case split --- llmfoundry/data/finetuning/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 0f8fb4d050..c4b194aeb7 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -301,7 +301,7 @@ def build_finetuning_dataloader( **{ k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys + if k in dataset_constructor_keys and k != 'split' }, ) From e5bc5b7819d2781d6299a7a3de419380eddc08e5 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:52:08 -0700 Subject: [PATCH 07/14] try again --- llmfoundry/data/finetuning/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index c4b194aeb7..26e48d0b96 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -301,7 +301,7 @@ def build_finetuning_dataloader( **{ k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and k != 'split' + if k in dataset_constructor_keys and k not in {'split', 'preprocessing_fn'} }, ) From 3da3aea77a1a6f5cdbd7d5466fa73539edb4cd25 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 17:53:22 -0700 Subject: [PATCH 08/14] pc --- 
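Note: this commit is a formatting-only cleanup of the dict comprehension introduced in PATCH 01. For readers following the series, the underlying pattern is: discover which keyword arguments a dataset constructor accepts via inspect.signature, then forward only the matching keys from the dataset config, skipping any that are passed explicitly. A minimal, self-contained sketch of that pattern follows; ExampleDataset, example_cfg, and the key names are illustrative stand-ins, not llm-foundry APIs.

    import inspect

    class ExampleDataset:
        # Stand-in for a dataset class such as the streaming dataset in this PR.
        def __init__(self, local=None, remote=None, shuffle=False, max_seq_len=2048):
            self.local = local
            self.remote = remote
            self.shuffle = shuffle
            self.max_seq_len = max_seq_len

    # Keys the constructor actually accepts, discovered instead of hard-coded.
    constructor_keys = inspect.signature(ExampleDataset).parameters.keys()

    example_cfg = {
        'remote': 's3://bucket/dataset',   # illustrative value
        'max_seq_len': 4096,
        'packing_ratio': 3.0,              # not a constructor arg, so filtered out below
    }

    # Forward only keys the constructor knows about, skipping any supplied
    # explicitly elsewhere (mirrors the 'split' / 'preprocessing_fn' exclusions
    # in the build_from_hf call).
    explicitly_passed = {'split', 'preprocessing_fn'}
    filtered = {
        k: v
        for k, v in example_cfg.items()
        if k in constructor_keys and k not in explicitly_passed
    }

    dataset = ExampleDataset(**filtered)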
llmfoundry/data/finetuning/dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 26e48d0b96..1fa4675e32 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -301,7 +301,8 @@ def build_finetuning_dataloader( **{ k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and k not in {'split', 'preprocessing_fn'} + if k in dataset_constructor_keys and + k not in {'split', 'preprocessing_fn'} }, ) From cc664a02b89425ad60ca5cc5d40c7f3280909d56 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 18:14:41 -0700 Subject: [PATCH 09/14] pr --- llmfoundry/data/finetuning/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 1fa4675e32..66b90b809e 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -260,7 +260,7 @@ def build_finetuning_dataloader( **{ k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys + if k in dataset_constructor_keys and k not in {'packing_ratio'} }, ) From 3a76ea205aee83ceacc571d5cf7eeb2b9d9d7ffc Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 18:39:35 -0700 Subject: [PATCH 10/14] fix patch --- llmfoundry/data/finetuning/dataloader.py | 6 ++++++ tests/data/test_dataloader.py | 20 +++++++++++++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 66b90b809e..b5dad58506 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -184,17 +184,23 @@ def build_finetuning_dataloader( 'streams', ) is not None if is_streaming: + print('is_streaming') dataset_constructor_keys = inspect.signature( dataset_constructor.streaming_dataset_class, ).parameters.keys() else: + print('is not streaming') dataset_constructor_keys = inspect.signature( dataset_constructor.build_from_hf, ).parameters.keys() + print(inspect.signature(dataset_constructor.build_from_hf)) + print(dataset_constructor_keys) allowed_dataset_config_keys = set( dataset_constructor_keys, ).union(_ALLOWED_DATASET_KEYS) + print(allowed_dataset_config_keys) + print(dataset_constructor_keys) _validate_config( **dataset_cfg, allowed_dataset_keys=allowed_dataset_config_keys, diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 8e92658194..76a6ebebe3 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -7,7 +7,7 @@ import shutil from contextlib import nullcontext as does_not_raise from pathlib import Path -from typing import ContextManager, Literal, Optional, Union +from typing import ContextManager, Literal, Optional, Union, Dict, Any, Callable from unittest.mock import MagicMock, patch import catalogue @@ -1220,6 +1220,21 @@ def test_token_counting_func_dataloader_setting( 'timeout': 0, } + def build_from_hf( + self, # type: ignore + dataset_name: str, + split: str, + safe_load: bool = False, + max_seq_len: int = 2048, + preprocessing_fn: Callable = None, + tokenizer: transformers.PreTrainedTokenizerBase = None, + target_prompts: str = 'last', + target_responses: str = 'none', + decoder_only_format: bool = True, + hf_kwargs: Optional[Dict[str, Any]] = None, + ): + return [] + if dataloader_type == 'finetuning-hf': cfg = DictConfig({ 
'dataset': { @@ -1235,8 +1250,7 @@ def test_token_counting_func_dataloader_setting( }) monkeypatch.setattr( 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', - lambda *args, - **kwargs: [], + build_from_hf ) dl = build_finetuning_dataloader( tokenizer=gptt, From e73f1866be83a0aa8fba3ea130c2b8c9f838b02e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 18:41:53 -0700 Subject: [PATCH 11/14] pc --- tests/data/test_dataloader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 76a6ebebe3..1a43e12536 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -7,7 +7,7 @@ import shutil from contextlib import nullcontext as does_not_raise from pathlib import Path -from typing import ContextManager, Literal, Optional, Union, Dict, Any, Callable +from typing import Any, Callable, ContextManager, Dict, Literal, Optional, Union from unittest.mock import MagicMock, patch import catalogue @@ -1226,7 +1226,7 @@ def build_from_hf( split: str, safe_load: bool = False, max_seq_len: int = 2048, - preprocessing_fn: Callable = None, + preprocessing_fn: Optional[Callable] = None, tokenizer: transformers.PreTrainedTokenizerBase = None, target_prompts: str = 'last', target_responses: str = 'none', @@ -1250,7 +1250,7 @@ def build_from_hf( }) monkeypatch.setattr( 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', - build_from_hf + build_from_hf, ) dl = build_finetuning_dataloader( tokenizer=gptt, From 8c9f90a798d16d79731caf7aa1cd86f3fb156652 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 6 Aug 2024 19:03:25 -0700 Subject: [PATCH 12/14] fix again --- llmfoundry/data/finetuning/dataloader.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index b5dad58506..e1ba405efe 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -184,23 +184,17 @@ def build_finetuning_dataloader( 'streams', ) is not None if is_streaming: - print('is_streaming') dataset_constructor_keys = inspect.signature( dataset_constructor.streaming_dataset_class, ).parameters.keys() else: - print('is not streaming') dataset_constructor_keys = inspect.signature( dataset_constructor.build_from_hf, ).parameters.keys() - print(inspect.signature(dataset_constructor.build_from_hf)) - print(dataset_constructor_keys) allowed_dataset_config_keys = set( dataset_constructor_keys, ).union(_ALLOWED_DATASET_KEYS) - print(allowed_dataset_config_keys) - print(dataset_constructor_keys) _validate_config( **dataset_cfg, allowed_dataset_keys=allowed_dataset_config_keys, @@ -266,7 +260,8 @@ def build_finetuning_dataloader( **{ k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and k not in {'packing_ratio'} + if k in dataset_constructor_keys and + k not in {'packing_ratio', 'streams'} }, ) From 021947e1dc26dfd5fcbbddbdf8e45d7a64b73a29 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 7 Aug 2024 10:15:05 -0700 Subject: [PATCH 13/14] un inline --- llmfoundry/data/finetuning/dataloader.py | 30 +++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index e1ba405efe..c15a7fce83 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -250,19 +250,20 @@ def 
build_finetuning_dataloader( streams_cfg, ) if streams_cfg is not None else None - # note: we don't need to use ** here because we're setting default values for almost all arguments + # Take the constructor args from above, minus args that have been created separately + dataset_constructor_args = { + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys and + k not in {'streams', 'packing_ratio'} + } streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, batch_size=dataloader_batch_size, replication=replication_factor, packing_ratio=dataloader_batch_size / dataset_batch_size, - **{ - k: v - for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and - k not in {'packing_ratio', 'streams'} - }, + **dataset_constructor_args, ) else: @@ -293,18 +294,19 @@ def build_finetuning_dataloader( dataset_name_or_path, ) - # Build dataset from HF. + # Take the constructor args from above, minus args that have been created separately + dataset_constructor_args = { + k: v + for k, v in dataset_cfg.items() + if k in dataset_constructor_keys and + k not in {'split', 'preprocessing_fn'} + } streaming_dataset = dataset_constructor.build_from_hf( dataset_name=dataset_name_or_path, split=split, preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - **{ - k: v - for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and - k not in {'split', 'preprocessing_fn'} - }, + **dataset_constructor_args, ) # Ensure dataset is large enough. From f479ed92db86dbb4c6655f3a7c9e14453b189813 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 7 Aug 2024 12:44:28 -0700 Subject: [PATCH 14/14] format --- llmfoundry/data/finetuning/dataloader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index c15a7fce83..771033a703 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -180,9 +180,10 @@ def build_finetuning_dataloader( given a starting workload YAML. """ dataset_cfg = dataset - is_streaming = dataset_cfg.get('remote') is not None or dataset_cfg.get( - 'streams', - ) is not None + is_streaming = ( + dataset_cfg.get('remote') is not None or + dataset_cfg.get('streams') is not None + ) if is_streaming: dataset_constructor_keys = inspect.signature( dataset_constructor.streaming_dataset_class,