Use a temporary directory for downloading finetuning dataset files #1608

Merged (15 commits) on Oct 24, 2024
9 changes: 7 additions & 2 deletions llmfoundry/data/finetuning/dataloader.py
```diff
@@ -20,7 +20,6 @@
 from llmfoundry.data.finetuning.tasks import (
     DEFAULT_TARGET_PROMPTS,
     DEFAULT_TARGET_RESPONSES,
-    DOWNLOADED_FT_DATASETS_DIRPATH,
     SUPPORTED_EXTENSIONS,
     dataset_constructor,
 )
```
```diff
@@ -32,6 +31,7 @@
     MissingHuggingFaceURLSplitError,
     NotEnoughDatasetSamplesError,
 )
+from llmfoundry.utils.file_utils import dist_mkdtemp
 from llmfoundry.utils.registry_utils import construct_from_registry
 
 log = logging.getLogger(__name__)
```
```diff
@@ -571,7 +571,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
     # HF datasets does not support a split with dashes, so we replace dashes with underscores.
     hf_formatted_split = split.replace('-', '_')
     finetune_dir = os.path.join(
-        DOWNLOADED_FT_DATASETS_DIRPATH,
+        dist_mkdtemp(),
         hf_formatted_split if hf_formatted_split != 'data' else 'data_not',
     )
     os.makedirs(finetune_dir, exist_ok=True)
```
```diff
@@ -593,6 +593,8 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
             finetune_dir,
             f'.node_{dist.get_node_rank()}_local_rank0_completed',
         )
+
+        log.debug(f'Downloading dataset {name} to {destination}.')
         if dist.get_local_rank() == 0:
             try:
                 get_file(path=name, destination=destination, overwrite=True)
```
```diff
@@ -615,10 +617,13 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
             with open(signal_file_path, 'wb') as f:
                 f.write(b'local_rank0_completed_download')
 
+        print(dist.get_local_rank(), f'signal_file_path: {signal_file_path}')
+
         # Avoid the collective call until the local rank zero has finished trying to download the dataset
         # so that we don't timeout for large downloads. This syncs all processes on the node
         with dist.local_rank_zero_download_and_wait(signal_file_path):
             # Then, wait to ensure every node has finished trying to download the dataset
+            print('GOT TO BARRIER')
             dist.barrier()
 
         # clean up signal file
```
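The synchronization idiom in the hunk above is worth spelling out: local rank 0 downloads the file and writes a signal file, the other local ranks block until that file exists, and only then does everyone enter the barrier, so a long download cannot time out the collective. A minimal sketch of the pattern using Composer's `dist` utilities (the helper name and paths are illustrative, not from this PR):

```python
import os

from composer.utils import dist, get_file


def download_once_per_node(remote_path: str, destination: str) -> None:
    # Illustrative helper, not part of llm-foundry.
    signal_file = os.path.join(
        os.path.dirname(destination),
        f'.node_{dist.get_node_rank()}_local_rank0_completed',
    )

    if dist.get_local_rank() == 0:
        get_file(path=remote_path, destination=destination, overwrite=True)
        # Tell the waiting local ranks that the download finished.
        with open(signal_file, 'wb') as f:
            f.write(b'local_rank0_completed_download')

    # Non-zero local ranks wait here for the signal file, so the
    # barrier is not entered until the download is done.
    with dist.local_rank_zero_download_and_wait(signal_file):
        dist.barrier()

    # Clean up the signal file once every rank has passed the barrier.
    if dist.get_local_rank() == 0:
        os.remove(signal_file)
```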
16 changes: 6 additions & 10 deletions llmfoundry/data/finetuning/tasks.py
```diff
@@ -34,6 +34,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import importlib
 import logging
 import os
+import tempfile
 import warnings
 from collections.abc import Mapping
 from functools import partial
```
```diff
@@ -107,15 +108,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 _ALLOWED_CONTENT_KEYS = {'content'}
 _ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
 _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
-DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
-    os.path.join(
-        os.path.realpath(__file__),
-        os.pardir,
-        os.pardir,
-        os.pardir,
-        '.downloaded_finetuning',
-    ),
-)
 SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']
 HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata']
 DEFAULT_TARGET_RESPONSES = 'last'
```
```diff
@@ -920,10 +912,14 @@ def build_from_hf(
     if not os.path.isdir(dataset_name):
         # dataset_name is not a local dir path, download if needed.
         local_dataset_dir = os.path.join(
-            DOWNLOADED_FT_DATASETS_DIRPATH,
+            tempfile.mkdtemp(),
             dataset_name,
         )
 
+        log.debug(
+            f'Downloading dataset {dataset_name} to {local_dataset_dir}.',
+        )
+
         if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
             # Safely load a dataset from HF Hub with restricted file types.
             hf_hub.snapshot_download(
```
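For context on the "safe load" the hunk above feeds into: `hf_hub.snapshot_download` is called with `allow_patterns` built from `SUPPORTED_EXTENSIONS`, so only recognized data files are pulled down, and the destination is now a fresh temp directory rather than a fixed path inside the package tree. A sketch of that call (the repo id is a placeholder; the keyword arguments are standard `huggingface_hub` parameters):

```python
import os
import tempfile

import huggingface_hub as hf_hub

SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']

# Placeholder dataset repo id, for illustration only.
dataset_name = 'some-org/some-finetuning-dataset'

# Download into a fresh temp dir instead of a fixed in-package path.
local_dataset_dir = os.path.join(tempfile.mkdtemp(), dataset_name)

hf_hub.snapshot_download(
    dataset_name,
    repo_type='dataset',
    # Restrict the snapshot to known-safe data file types.
    allow_patterns=['*' + ext for ext in SUPPORTED_EXTENSIONS],
    local_dir=local_dataset_dir,
)
```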
21 changes: 21 additions & 0 deletions llmfoundry/utils/file_utils.py
```diff
@@ -0,0 +1,21 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import tempfile
+
+from composer.utils import dist
+
+
+def dist_mkdtemp() -> str:
+    """Creates a temp directory on local rank 0 to use for other ranks.
+
+    Returns:
+        str: The path to the temporary directory.
+    """
+    tempdir = None
+    if dist.get_local_rank() == 0:
+        tempdir = tempfile.mkdtemp()
+    tempdir = dist.all_gather_object(tempdir)[0]
+    if tempdir is None:
+        raise RuntimeError('Dist operation to get tempdir failed.')
+    return tempdir
```
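One caveat in the new helper (my reading of the code, not something stated in the PR): `dist.all_gather_object` returns one entry per global rank, so indexing with `[0]` gives every rank the directory created by global rank 0. On a single node that is exactly local rank 0's directory, but in a multi-node run the other nodes would receive a path that only exists on node 0's filesystem. A hedged sketch of a node-local variant, assuming Composer's `dist` helpers:

```python
import tempfile

from composer.utils import dist


def dist_mkdtemp_node_local() -> str:
    """Hypothetical variant: each rank uses the temp dir created by its
    own node's local rank 0, rather than by global rank 0."""
    tempdir = None
    local_rank = dist.get_local_rank()
    global_rank = dist.get_global_rank()
    if local_rank == 0:
        tempdir = tempfile.mkdtemp()
    # all_gather_object returns a list indexed by global rank, and
    # global_rank - local_rank is this node's local-rank-0 position.
    tempdir = dist.all_gather_object(tempdir)[global_rank - local_rank]
    if tempdir is None:
        raise RuntimeError('Dist operation to get tempdir failed.')
    return tempdir
```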
19 changes: 9 additions & 10 deletions tests/data/test_dataloader.py
```diff
@@ -30,7 +30,6 @@
     validate_target_settings,
 )
 from llmfoundry.data.finetuning.tasks import (
-    DOWNLOADED_FT_DATASETS_DIRPATH,
     HUGGINGFACE_FOLDER_EXTENSIONS,
     SUPPORTED_EXTENSIONS,
     dataset_constructor,
```
```diff
@@ -430,9 +429,9 @@ def test_finetuning_dataloader_safe_load(
     hf_name: str,
     hf_revision: Optional[str],
     expectation: ContextManager,
+    tmp_path: pathlib.Path,
 ):
-    # Clear the folder
-    shutil.rmtree(DOWNLOADED_FT_DATASETS_DIRPATH, ignore_errors=True)
     cfg = DictConfig({
         'dataset': {
             'hf_name': hf_name,
```
```diff
@@ -455,18 +454,18 @@ def test_finetuning_dataloader_safe_load(
 
     tokenizer = build_tokenizer('gpt2', {})
 
-    with expectation:
-        _ = build_finetuning_dataloader(
-            tokenizer=tokenizer,
-            device_batch_size=1,
-            **cfg,
-        )
+    with patch('llmfoundry.data.finetuning.tasks.tempfile.mkdtemp', return_value=str(tmp_path)):
+        with expectation:
+            _ = build_finetuning_dataloader(
+                tokenizer=tokenizer,
+                device_batch_size=1,
+                **cfg,
+            )
 
     # If no raised errors, we should expect downloaded files with only safe file types.
     if isinstance(expectation, does_not_raise):
-        download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name)
         downloaded_files = [
-            file for _, _, files in os.walk(download_dir) for file in files
+            file for _, _, files in os.walk(tmp_path) for file in files
         ]
         assert len(downloaded_files) > 0
         assert all(
```
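A note on the test change above: since `tasks.py` now calls `tempfile.mkdtemp()` directly, the test patches it to return pytest's `tmp_path`, which both isolates each test run and makes the downloaded files easy to walk and assert on. A standalone sketch of the same technique (the test body is illustrative, not from this PR):

```python
import os
import tempfile
from unittest.mock import patch


def test_downloads_land_in_tmp_path(tmp_path):
    # Redirect mkdtemp so code under test writes into pytest's
    # per-test directory instead of an unpredictable system location.
    with patch('tempfile.mkdtemp', return_value=str(tmp_path)):
        workdir = tempfile.mkdtemp()  # stand-in for the code under test
        with open(os.path.join(workdir, 'split.jsonl'), 'w') as f:
            f.write('{"prompt": "hi", "response": "there"}\n')

    # Walk the tree and assert on what was written, as the real test does.
    files = [f for _, _, names in os.walk(tmp_path) for f in names]
    assert files == ['split.jsonl']
```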