Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically get the portion of the dataset config that is constructor args #1434

Merged
merged 14 commits into from
Aug 7, 2024
136 changes: 57 additions & 79 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import inspect
import logging
import os
from typing import Any, Dict, Optional, Tuple, Union
Expand All @@ -17,6 +18,8 @@
validate_target_settings,
)
from llmfoundry.data.finetuning.tasks import (
DEFAULT_TARGET_PROMPTS,
DEFAULT_TARGET_RESPONSES,
DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
dataset_constructor,
Expand All @@ -39,9 +42,15 @@
# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

# Default settings to use for target responses and target prompts
_DEFAULT_TARGET_RESPONSES = 'last'
_DEFAULT_TARGET_PROMPTS = 'none'
# Extra keys present in the dataset config dictionary beyond the constructor keys
_ALLOWED_DATASET_KEYS = {
'shuffle',
'packing_ratio',
'allow_pad_trimming',
'seq_parallel_replication',
'auto_packing_replication',
'max_leftover_bins_to_keep',
}


def build_finetuning_dataloader(
Expand Down Expand Up @@ -171,7 +180,26 @@ def build_finetuning_dataloader(
given a starting workload YAML.
"""
dataset_cfg = dataset
_validate_config(**dataset_cfg)
is_streaming = (
dataset_cfg.get('remote') is not None or
dataset_cfg.get('streams') is not None
)
if is_streaming:
dataset_constructor_keys = inspect.signature(
b-chu marked this conversation as resolved.
Show resolved Hide resolved
dataset_constructor.streaming_dataset_class,
).parameters.keys()
else:
dataset_constructor_keys = inspect.signature(
dataset_constructor.build_from_hf,
).parameters.keys()

allowed_dataset_config_keys = set(
dataset_constructor_keys,
).union(_ALLOWED_DATASET_KEYS)
_validate_config(
**dataset_cfg,
allowed_dataset_keys=allowed_dataset_config_keys,
)

# Use EOS as the pad token if none exists
if tokenizer.pad_token is None: # type: ignore (sometimes it's none and that's ok)
Expand Down Expand Up @@ -213,9 +241,7 @@ def build_finetuning_dataloader(

streaming_dataset = None # for pyright
sampler = None
if dataset_cfg.get(
'remote',
) is not None or dataset_cfg.get('streams') is not None:
if is_streaming:
# Build streaming dataloader
streams_cfg = dataset_cfg.get('streams', None)
streams_cfg = to_dict_container(
Expand All @@ -225,34 +251,20 @@ def build_finetuning_dataloader(
streams_cfg,
) if streams_cfg is not None else None

# note: we don't need to use ** here because we're setting default values for almost all arguments
# Take the constructor args from above, minus args that have been created separately
dataset_constructor_args = {
k: v
for k, v in dataset_cfg.items()
if k in dataset_constructor_keys and
k not in {'streams', 'packing_ratio'}
}
streaming_dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
streams=streams,
local=dataset_cfg.get('local', None),
remote=dataset_cfg.get('remote', None),
split=dataset_cfg.get('split', None),
download_retry=dataset_cfg.get('download_retry', 2),
download_timeout=dataset_cfg.get('download_timeout', 60),
validate_hash=dataset_cfg.get('validate_hash', None),
keep_zip=dataset_cfg.get('keep_zip', False),
epoch_size=dataset_cfg.get('epoch_size', None),
predownload=dataset_cfg.get('predownload', None),
cache_limit=dataset_cfg.get('cache_limit', None),
partition_algo=dataset_cfg.get('partition_algo', 'relaxed'),
num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None),
batch_size=dataloader_batch_size,
shuffle=dataset_cfg.get('shuffle', False),
shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'),
shuffle_seed=dataset_cfg.get('shuffle_seed', 9176),
shuffle_block_size=dataset_cfg.get('shuffle_block_size', None),
sampling_method=dataset_cfg.get('sampling_method', 'balanced'),
sampling_granularity=dataset_cfg.get('sampling_granularity', 1),
batching_method=dataset_cfg.get('batching_method', 'random'),
max_seq_len=dataset_cfg['max_seq_len'],
allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False),
replication=replication_factor,
packing_ratio=dataloader_batch_size / dataset_batch_size,
**dataset_constructor_args,
)

else:
Expand Down Expand Up @@ -283,24 +295,19 @@ def build_finetuning_dataloader(
dataset_name_or_path,
)

# Build dataset from HF.
# Take the constructor args from above, minus args that have been created separately
dataset_constructor_args = {
k: v
for k, v in dataset_cfg.items()
if k in dataset_constructor_keys and
k not in {'split', 'preprocessing_fn'}
}
streaming_dataset = dataset_constructor.build_from_hf(
dataset_name=dataset_name_or_path,
split=split,
safe_load=dataset_cfg.get('safe_load', False),
max_seq_len=dataset_cfg['max_seq_len'],
preprocessing_fn=preprocessing_fn,
tokenizer=tokenizer,
target_prompts=dataset_cfg.get(
'target_prompts',
_DEFAULT_TARGET_PROMPTS,
),
target_responses=dataset_cfg.get(
'target_responses',
_DEFAULT_TARGET_RESPONSES,
),
decoder_only_format=dataset_cfg['decoder_only_format'],
hf_kwargs=dataset_cfg.get('hf_kwargs', {}),
**dataset_constructor_args,
)

# Ensure dataset is large enough.
Expand Down Expand Up @@ -367,6 +374,7 @@ def _validate_config(
streams: Optional[Dict[str, Any]] = None,
target_prompts: Optional[str] = None,
target_responses: Optional[str] = None,
allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS,
**kwargs: Dict[str, Any],
) -> None:
"""Validates the dataset configuration.
Expand Down Expand Up @@ -417,48 +425,18 @@ def _validate_config(
Defaults to "last", meaning only the final response in multi-turn examples
will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for
details.
allowed_dataset_keys (set[str], optional): The set of allowed keys for the dataset config.
kwargs (DictConfig, optional): Additional kwargs to
pass to `datasets.load_dataset`, which can be used to load
a dataset from local files.

Raises:
ValueError: If the dataset configuration does not meet the requirements.
"""
# Check for extraneous keys in the dataset config
allowed_additional_kwargs = {
'local',
'remote',
'split',
'download_retry',
'download_timeout',
'validate_hash',
'keep_zip',
'epoch_size',
'predownload',
'cache_limit',
'partition_algo',
'num_canonical_nodes',
'batch_size',
'shuffle',
'shuffle_algo',
'shuffle_seed',
'shuffle_block_size',
'sampling_method',
'sampling_granularity',
'batching_method',
'max_seq_len',
'allow_unsafe_types',
'replication',
'packing_ratio',
'allow_pad_trimming',
'seq_parallel_replication',
'auto_packing_replication',
'max_leftover_bins_to_keep',
}
if not set(kwargs.keys()).issubset(allowed_additional_kwargs):
if not set(kwargs.keys()).issubset(allowed_dataset_keys):
raise ValueError(
'The dataset config contains the following extraneous keys: ' +\
', '.join(set(kwargs.keys()) - allowed_additional_kwargs),
', '.join(set(kwargs.keys()) - allowed_dataset_keys),
)

if hf_name is not None:
Expand Down Expand Up @@ -542,9 +520,9 @@ def _validate_config(
# Raise an error if the target_prompts + target_responses + decoder_only_format settings
# are invalid
if target_prompts is None:
target_prompts = _DEFAULT_TARGET_PROMPTS
target_prompts = DEFAULT_TARGET_PROMPTS
if target_responses is None:
target_responses = _DEFAULT_TARGET_RESPONSES
target_responses = DEFAULT_TARGET_RESPONSES
target_prompts, target_responses = target_prompts.lower(
), target_responses.lower()
validate_target_settings(
Expand Down Expand Up @@ -646,9 +624,9 @@ def build_collate_fn(
dataset_cfg = dataloader_cfg['dataset']
target_responses = dataset_cfg.get(
'target_responses',
_DEFAULT_TARGET_RESPONSES,
DEFAULT_TARGET_RESPONSES,
)
target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)
target_prompts = dataset_cfg.get('target_prompts', DEFAULT_TARGET_PROMPTS)
max_seq_len = dataset_cfg['max_seq_len']
decoder_only_format = dataset_cfg['decoder_only_format']
allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False)
Expand Down
33 changes: 24 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
Expand Down Expand Up @@ -115,6 +116,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
)
SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']
HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata']
DEFAULT_TARGET_RESPONSES = 'last'
DEFAULT_TARGET_PROMPTS = 'none'

PromptResponseDict = Mapping[str, str]
ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
Expand Down Expand Up @@ -805,14 +808,14 @@ def build_from_hf(
self,
dataset_name: str,
split: str,
safe_load: bool,
max_seq_len: int,
preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
tokenizer: PreTrainedTokenizerBase,
target_prompts: str,
target_responses: str,
decoder_only_format: bool,
hf_kwargs: Dict[str, Any],
safe_load: bool = False,
max_seq_len: int = 2048,
preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
target_prompts: str = DEFAULT_TARGET_PROMPTS,
target_responses: str = DEFAULT_TARGET_RESPONSES,
decoder_only_format: bool = True,
hf_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Load a HuggingFace Datasets, preprocess, and tokenize.
Expand Down Expand Up @@ -851,6 +854,14 @@ def build_from_hf(
Returns:
Dataset: The tokenized dataset.
"""
if hf_kwargs is None:
hf_kwargs = {}
b-chu marked this conversation as resolved.
Show resolved Hide resolved

# None is checked in the function, because argument defaults were added after the function was written and we want
# to preserve the ordering of the arguments for backwards compatibility.
if tokenizer is None:
raise ValueError('A tokenizer must be provided.')

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

# Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
Expand Down Expand Up @@ -999,12 +1010,16 @@ def dataset_mapper(example: Dict):
assert filtered_dataset is not None
return filtered_dataset

@property
def streaming_dataset_class(self) -> Type[StreamingFinetuningDataset]:
return StreamingFinetuningDataset

def build_from_streaming(
self,
*args: Any,
**kwargs: Any,
) -> StreamingFinetuningDataset:
return StreamingFinetuningDataset(*args, **kwargs)
return self.streaming_dataset_class(*args, **kwargs)


dataset_constructor = DatasetConstructor()
Expand Down
20 changes: 17 additions & 3 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shutil
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import ContextManager, Literal, Optional, Union
from typing import Any, Callable, ContextManager, Dict, Literal, Optional, Union
from unittest.mock import MagicMock, patch

import catalogue
Expand Down Expand Up @@ -1220,6 +1220,21 @@ def test_token_counting_func_dataloader_setting(
'timeout': 0,
}

def build_from_hf(
self, # type: ignore
dataset_name: str,
split: str,
safe_load: bool = False,
max_seq_len: int = 2048,
preprocessing_fn: Optional[Callable] = None,
tokenizer: transformers.PreTrainedTokenizerBase = None,
target_prompts: str = 'last',
target_responses: str = 'none',
decoder_only_format: bool = True,
hf_kwargs: Optional[Dict[str, Any]] = None,
):
return []

if dataloader_type == 'finetuning-hf':
cfg = DictConfig({
'dataset': {
Expand All @@ -1235,8 +1250,7 @@ def test_token_counting_func_dataloader_setting(
})
monkeypatch.setattr(
'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf',
lambda *args,
**kwargs: [],
build_from_hf,
)
dl = build_finetuning_dataloader(
tokenizer=gptt,
Expand Down
Loading