Merge branch 'main' into milo/fix-errors
milocress authored Aug 8, 2024
2 parents 29ea855 + f006d07 commit fc2aa31
Showing 7 changed files with 151 additions and 111 deletions.
136 changes: 57 additions & 79 deletions llmfoundry/data/finetuning/dataloader.py
@@ -1,5 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import inspect
import logging
import os
from typing import Any, Dict, Optional, Tuple, Union
@@ -17,6 +18,8 @@
validate_target_settings,
)
from llmfoundry.data.finetuning.tasks import (
DEFAULT_TARGET_PROMPTS,
DEFAULT_TARGET_RESPONSES,
DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
dataset_constructor,
@@ -39,9 +42,15 @@
# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

# Default settings to use for target responses and target prompts
_DEFAULT_TARGET_RESPONSES = 'last'
_DEFAULT_TARGET_PROMPTS = 'none'
# Extra keys present in the dataset config dictionary beyond the constructor keys
_ALLOWED_DATASET_KEYS = {
'shuffle',
'packing_ratio',
'allow_pad_trimming',
'seq_parallel_replication',
'auto_packing_replication',
'max_leftover_bins_to_keep',
}


def build_finetuning_dataloader(
@@ -171,7 +180,26 @@ def build_finetuning_dataloader(
given a starting workload YAML.
"""
dataset_cfg = dataset
_validate_config(**dataset_cfg)
is_streaming = (
dataset_cfg.get('remote') is not None or
dataset_cfg.get('streams') is not None
)
if is_streaming:
dataset_constructor_keys = inspect.signature(
dataset_constructor.streaming_dataset_class,
).parameters.keys()
else:
dataset_constructor_keys = inspect.signature(
dataset_constructor.build_from_hf,
).parameters.keys()

allowed_dataset_config_keys = set(
dataset_constructor_keys,
).union(_ALLOWED_DATASET_KEYS)
_validate_config(
**dataset_cfg,
allowed_dataset_keys=allowed_dataset_config_keys,
)

# Use EOS as the pad token if none exists
if tokenizer.pad_token is None: # type: ignore (sometimes it's none and that's ok)
@@ -213,9 +241,7 @@ def build_finetuning_dataloader(

streaming_dataset = None # for pyright
sampler = None
if dataset_cfg.get(
'remote',
) is not None or dataset_cfg.get('streams') is not None:
if is_streaming:
# Build streaming dataloader
streams_cfg = dataset_cfg.get('streams', None)
streams_cfg = to_dict_container(
@@ -225,34 +251,20 @@
streams_cfg,
) if streams_cfg is not None else None

# note: we don't need to use ** here because we're setting default values for almost all arguments
# Take the constructor args from above, minus args that have been created separately
dataset_constructor_args = {
k: v
for k, v in dataset_cfg.items()
if k in dataset_constructor_keys and
k not in {'streams', 'packing_ratio'}
}
streaming_dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
streams=streams,
local=dataset_cfg.get('local', None),
remote=dataset_cfg.get('remote', None),
split=dataset_cfg.get('split', None),
download_retry=dataset_cfg.get('download_retry', 2),
download_timeout=dataset_cfg.get('download_timeout', 60),
validate_hash=dataset_cfg.get('validate_hash', None),
keep_zip=dataset_cfg.get('keep_zip', False),
epoch_size=dataset_cfg.get('epoch_size', None),
predownload=dataset_cfg.get('predownload', None),
cache_limit=dataset_cfg.get('cache_limit', None),
partition_algo=dataset_cfg.get('partition_algo', 'relaxed'),
num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None),
batch_size=dataloader_batch_size,
shuffle=dataset_cfg.get('shuffle', False),
shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'),
shuffle_seed=dataset_cfg.get('shuffle_seed', 9176),
shuffle_block_size=dataset_cfg.get('shuffle_block_size', None),
sampling_method=dataset_cfg.get('sampling_method', 'balanced'),
sampling_granularity=dataset_cfg.get('sampling_granularity', 1),
batching_method=dataset_cfg.get('batching_method', 'random'),
max_seq_len=dataset_cfg['max_seq_len'],
allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False),
replication=replication_factor,
packing_ratio=dataloader_batch_size / dataset_batch_size,
**dataset_constructor_args,
)

else:
@@ -283,24 +295,19 @@ def build_finetuning_dataloader(
dataset_name_or_path,
)

# Build dataset from HF.
# Take the constructor args from above, minus args that have been created separately
dataset_constructor_args = {
k: v
for k, v in dataset_cfg.items()
if k in dataset_constructor_keys and
k not in {'split', 'preprocessing_fn'}
}
streaming_dataset = dataset_constructor.build_from_hf(
dataset_name=dataset_name_or_path,
split=split,
safe_load=dataset_cfg.get('safe_load', False),
max_seq_len=dataset_cfg['max_seq_len'],
preprocessing_fn=preprocessing_fn,
tokenizer=tokenizer,
target_prompts=dataset_cfg.get(
'target_prompts',
_DEFAULT_TARGET_PROMPTS,
),
target_responses=dataset_cfg.get(
'target_responses',
_DEFAULT_TARGET_RESPONSES,
),
decoder_only_format=dataset_cfg['decoder_only_format'],
hf_kwargs=dataset_cfg.get('hf_kwargs', {}),
**dataset_constructor_args,
)

# Ensure dataset is large enough.
@@ -367,6 +374,7 @@ def _validate_config(
streams: Optional[Dict[str, Any]] = None,
target_prompts: Optional[str] = None,
target_responses: Optional[str] = None,
allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS,
**kwargs: Dict[str, Any],
) -> None:
"""Validates the dataset configuration.
@@ -417,48 +425,18 @@ def _validate_config(
Defaults to "last", meaning only the final response in multi-turn examples
will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for
details.
allowed_dataset_keys (set[str], optional): The set of allowed keys for the dataset config.
kwargs (DictConfig, optional): Additional kwargs to
pass to `datasets.load_dataset`, which can be used to load
a dataset from local files.
Raises:
ValueError: If the dataset configuration does not meet the requirements.
"""
# Check for extraneous keys in the dataset config
allowed_additional_kwargs = {
'local',
'remote',
'split',
'download_retry',
'download_timeout',
'validate_hash',
'keep_zip',
'epoch_size',
'predownload',
'cache_limit',
'partition_algo',
'num_canonical_nodes',
'batch_size',
'shuffle',
'shuffle_algo',
'shuffle_seed',
'shuffle_block_size',
'sampling_method',
'sampling_granularity',
'batching_method',
'max_seq_len',
'allow_unsafe_types',
'replication',
'packing_ratio',
'allow_pad_trimming',
'seq_parallel_replication',
'auto_packing_replication',
'max_leftover_bins_to_keep',
}
if not set(kwargs.keys()).issubset(allowed_additional_kwargs):
if not set(kwargs.keys()).issubset(allowed_dataset_keys):
raise ValueError(
'The dataset config contains the following extraneous keys: ' +\
', '.join(set(kwargs.keys()) - allowed_additional_kwargs),
', '.join(set(kwargs.keys()) - allowed_dataset_keys),
)

if hf_name is not None:
@@ -542,9 +520,9 @@ def _validate_config(
# Raise an error if the target_prompts + target_responses + decoder_only_format settings
# are invalid
if target_prompts is None:
target_prompts = _DEFAULT_TARGET_PROMPTS
target_prompts = DEFAULT_TARGET_PROMPTS
if target_responses is None:
target_responses = _DEFAULT_TARGET_RESPONSES
target_responses = DEFAULT_TARGET_RESPONSES
target_prompts, target_responses = target_prompts.lower(
), target_responses.lower()
validate_target_settings(
@@ -646,9 +624,9 @@ def build_collate_fn(
dataset_cfg = dataloader_cfg['dataset']
target_responses = dataset_cfg.get(
'target_responses',
_DEFAULT_TARGET_RESPONSES,
DEFAULT_TARGET_RESPONSES,
)
target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)
target_prompts = dataset_cfg.get('target_prompts', DEFAULT_TARGET_PROMPTS)
max_seq_len = dataset_cfg['max_seq_len']
decoder_only_format = dataset_cfg['decoder_only_format']
allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False)
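
Note: the dataloader change above replaces a long hand-maintained list of allowed dataset-config keys with a set derived from the dataset constructor signatures via inspect.signature, unioned with a few dataloader-only keys. A minimal, illustrative sketch of that pattern follows; the constructor, key names, and function names are placeholders, not the Foundry implementation.

import inspect
from typing import Any, Dict, Set


def example_constructor(
    local: str,
    remote: str,
    shuffle_seed: int = 9176,
) -> None:
    """Stands in for the real streaming/HF dataset constructor being introspected."""


# Keys consumed by the dataloader itself rather than by the dataset constructor.
EXTRA_DATALOADER_KEYS: Set[str] = {'shuffle', 'packing_ratio', 'allow_pad_trimming'}


def validate_dataset_keys(dataset_cfg: Dict[str, Any]) -> None:
    # Allowed keys = constructor parameters plus the dataloader-only extras.
    constructor_keys = set(inspect.signature(example_constructor).parameters)
    allowed_keys = constructor_keys | EXTRA_DATALOADER_KEYS
    extraneous = set(dataset_cfg) - allowed_keys
    if extraneous:
        raise ValueError(
            'The dataset config contains the following extraneous keys: ' +
            ', '.join(sorted(extraneous)),
        )


validate_dataset_keys({'local': '/tmp/data', 'shuffle': True})  # passes
# validate_dataset_keys({'local': '/tmp/data', 'typo': 1})      # raises ValueError

Deriving the allowed keys from the signature means new constructor arguments are accepted automatically, without updating a separate whitelist.
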
33 changes: 24 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
@@ -47,6 +47,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -117,6 +118,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
)
SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']
HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata']
DEFAULT_TARGET_RESPONSES = 'last'
DEFAULT_TARGET_PROMPTS = 'none'

PromptResponseDict = Mapping[str, str]
ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
@@ -800,14 +803,14 @@ def build_from_hf(
self,
dataset_name: str,
split: str,
safe_load: bool,
max_seq_len: int,
preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
tokenizer: PreTrainedTokenizerBase,
target_prompts: str,
target_responses: str,
decoder_only_format: bool,
hf_kwargs: Dict[str, Any],
safe_load: bool = False,
max_seq_len: int = 2048,
preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
target_prompts: str = DEFAULT_TARGET_PROMPTS,
target_responses: str = DEFAULT_TARGET_RESPONSES,
decoder_only_format: bool = True,
hf_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Load a HuggingFace Datasets, preprocess, and tokenize.
@@ -846,6 +849,14 @@ def build_from_hf(
Returns:
Dataset: The tokenized dataset.
"""
if hf_kwargs is None:
hf_kwargs = {}

# None is checked in the function, because argument defaults were added after the function was written and we want
# to preserve the ordering of the arguments for backwards compatibility.
if tokenizer is None:
raise ValueError('A tokenizer must be provided.')

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

# Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
@@ -994,12 +1005,16 @@ def dataset_mapper(example: Dict):
assert filtered_dataset is not None
return filtered_dataset

@property
def streaming_dataset_class(self) -> Type[StreamingFinetuningDataset]:
return StreamingFinetuningDataset

def build_from_streaming(
self,
*args: Any,
**kwargs: Any,
) -> StreamingFinetuningDataset:
return StreamingFinetuningDataset(*args, **kwargs)
return self.streaming_dataset_class(*args, **kwargs)


dataset_constructor = DatasetConstructor()
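
Note: exposing streaming_dataset_class as a property, as the change above does, lets a subclass of the constructor return a different dataset class while build_from_streaming stays unchanged, and gives the dataloader a concrete signature to introspect. A minimal sketch of the pattern with placeholder class names (not the Foundry classes):

from typing import Any, Type


class BaseStreamingDataset:
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args, self.kwargs = args, kwargs


class CustomStreamingDataset(BaseStreamingDataset):
    pass


class DatasetConstructorSketch:
    @property
    def streaming_dataset_class(self) -> Type[BaseStreamingDataset]:
        return BaseStreamingDataset

    def build_from_streaming(self, *args: Any, **kwargs: Any) -> BaseStreamingDataset:
        # Instantiates whatever class the (possibly overridden) property returns.
        return self.streaming_dataset_class(*args, **kwargs)


class CustomDatasetConstructor(DatasetConstructorSketch):
    @property
    def streaming_dataset_class(self) -> Type[BaseStreamingDataset]:
        return CustomStreamingDataset


# A subclass only overrides the property; the builder picks up the new class.
assert isinstance(
    CustomDatasetConstructor().build_from_streaming(),
    CustomStreamingDataset,
)
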
23 changes: 6 additions & 17 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -256,23 +256,6 @@ def build_inner_model(
False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
)

# This is not ideal, however Hugging Face's _autoset_attn_implementation function
# forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
# the model and then casting it back to fp32, we are monkeypatching their check.
# https://github.com/huggingface/transformers/issues/28052
def _autoset_attn_implementation_monkeypatch(
cls, # type: ignore
config, # type: ignore
*args, # type: ignore
**kwargs, # type: ignore
): # type: ignore
config._attn_implementation = requested_attention_implementation
return config

PreTrainedModel._autoset_attn_implementation = classmethod(
_autoset_attn_implementation_monkeypatch,
)

set_config_overrides(config, config_overrides)

# We need to have all non-zero local ranks be not-pretrained
@@ -293,13 +276,16 @@ def _autoset_attn_implementation_monkeypatch(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
attn_implementation=
requested_attention_implementation,
config=config,
)
else:
with init_empty_weights(include_buffers=False):
AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
attn_implementation=requested_attention_implementation,
)

dist.barrier()
@@ -312,12 +298,14 @@ def _autoset_attn_implementation_monkeypatch(
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
load_in_8bit=load_in_8bit,
attn_implementation=requested_attention_implementation,
config=config,
)
else:
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
attn_implementation=requested_attention_implementation,
)
elif resolved_init_device == 'meta':
if pretrained:
@@ -328,6 +316,7 @@ def _autoset_attn_implementation_monkeypatch(
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
attn_implementation=requested_attention_implementation,
)
else:
raise ValueError(
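
Note: this file drops the monkeypatch of PreTrainedModel._autoset_attn_implementation and instead passes the requested attention backend straight through the attn_implementation argument, which recent transformers releases accept on from_pretrained and from_config. A hedged sketch of the call pattern, using an illustrative checkpoint name and backend:

from transformers import AutoConfig, AutoModelForCausalLM

model_name = 'gpt2'  # placeholder checkpoint, not the Foundry default
requested_attention_implementation = 'eager'  # or 'sdpa' / 'flash_attention_2'

config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    attn_implementation=requested_attention_implementation,
)

Passing the backend explicitly avoids patching Hugging Face internals and keeps the model in its requested dtype, which is what the removed monkeypatch was working around.
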
4 changes: 2 additions & 2 deletions scripts/eval/yamls/long_context_tasks.yaml
@@ -105,7 +105,7 @@ icl_tasks:
icl_task_type: generation_task_with_answers
hf_loading_vars:
name: wikiqa
context_length: 2048
context_length: 4096
split: test
-
label: wikiqa_8k
@@ -114,7 +114,7 @@
icl_task_type: generation_task_with_answers
hf_loading_vars:
name: wikiqa
context_length: 2048
context_length: 8192
split: test
-
label: hotpotqa_beginning_2k
2 changes: 1 addition & 1 deletion setup.py
@@ -53,7 +53,7 @@

install_requires = [
'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24',
'mlflow>=2.14.1,<2.15',
'mlflow>=2.14.1,<2.16',
'accelerate>=0.25,<0.34', # for HF inference `device_map`
'transformers>=4.43.2,<4.44',
'mosaicml-streaming>=0.8.0,<0.9',