Revert "Use a temporary directory for downloading finetuning dataset … #1637

Closed · wants to merge 3 commits
Changes from 1 commit
6 changes: 2 additions & 4 deletions llmfoundry/data/finetuning/dataloader.py
@@ -20,6 +20,7 @@
 from llmfoundry.data.finetuning.tasks import (
     DEFAULT_TARGET_PROMPTS,
     DEFAULT_TARGET_RESPONSES,
+    DOWNLOADED_FT_DATASETS_DIRPATH,
     SUPPORTED_EXTENSIONS,
     dataset_constructor,
 )
@@ -32,7 +33,6 @@
     MissingHuggingFaceURLSplitError,
     NotEnoughDatasetSamplesError,
 )
-from llmfoundry.utils.file_utils import dist_mkdtemp
 from llmfoundry.utils.registry_utils import construct_from_registry

 log = logging.getLogger(__name__)
@@ -569,7 +569,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
     # HF datasets does not support a split with dashes, so we replace dashes with underscores.
     hf_formatted_split = split.replace('-', '_')
     finetune_dir = os.path.join(
-        dist_mkdtemp(),
+        DOWNLOADED_FT_DATASETS_DIRPATH,
         hf_formatted_split if hf_formatted_split != 'data' else 'data_not',
     )
     os.makedirs(finetune_dir, exist_ok=True)
@@ -591,8 +591,6 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
         finetune_dir,
         f'.node_{dist.get_node_rank()}_local_rank0_completed',
     )
-
-    log.debug(f'Downloading dataset {name} to {destination}.')
     if dist.get_local_rank() == 0:
         try:
             get_file(path=name, destination=destination, overwrite=True)
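Note: the hunks above show only fragments of `_download_remote_hf_dataset`. For context, here is a minimal sketch of the local-rank-0 download pattern they belong to. `get_file`, the `dist` helpers, and the signal-file name appear in the diff; the helper name and exact control flow here are assumptions, not the function's actual body.

```python
import os

from composer.utils import dist, get_file


def download_once_per_node(name: str, finetune_dir: str) -> str:
    """Sketch: download `name` on each node's local rank 0; other ranks wait."""
    destination = os.path.join(finetune_dir, os.path.basename(name))
    signal_file_path = os.path.join(
        finetune_dir,
        f'.node_{dist.get_node_rank()}_local_rank0_completed',
    )
    if dist.get_local_rank() == 0:
        get_file(path=name, destination=destination, overwrite=True)
        with open(signal_file_path, 'w') as f:
            f.write('completed')  # signal the other local ranks
    # Non-zero local ranks block here until the signal file appears.
    with dist.local_rank_zero_download_and_wait(signal_file_path):
        dist.barrier()
    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)  # clean up the signal file
    return destination
```

With `DOWNLOADED_FT_DATASETS_DIRPATH` restored in place of `dist_mkdtemp()`, every run reuses the same on-disk cache directory instead of a fresh temp directory per invocation.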
16 changes: 10 additions & 6 deletions llmfoundry/data/finetuning/tasks.py
@@ -34,7 +34,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import importlib
 import logging
 import os
-import tempfile
 import warnings
 from collections.abc import Mapping
 from functools import partial
@@ -108,6 +107,15 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 _ALLOWED_CONTENT_KEYS = {'content'}
 _ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
 _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
+DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
+    os.path.join(
+        os.path.realpath(__file__),
+        os.pardir,
+        os.pardir,
+        os.pardir,
+        '.downloaded_finetuning',
+    ),
+)
 SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']
 HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata']
 DEFAULT_TARGET_RESPONSES = 'last'
@@ -913,14 +921,10 @@ def build_from_hf(
     if not os.path.isdir(dataset_name):
         # dataset_name is not a local dir path, download if needed.
         local_dataset_dir = os.path.join(
-            tempfile.mkdtemp(),
+            DOWNLOADED_FT_DATASETS_DIRPATH,
             dataset_name,
         )
-
-        log.debug(
-            f'Downloading dataset {dataset_name} to {local_dataset_dir}.',
-        )

         if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
             # Safely load a dataset from HF Hub with restricted file types.
             hf_hub.snapshot_download(
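Since tasks.py lives at llmfoundry/data/finetuning/tasks.py, the three `os.pardir` hops in the re-added constant place the cache directory inside the `llmfoundry` package itself. A quick illustration of the resolution (the `/repo` prefix is a placeholder, not a real path):

```python
import os

# Mirrors the DOWNLOADED_FT_DATASETS_DIRPATH expression above with the
# module path written out explicitly.
module_file = '/repo/llmfoundry/data/finetuning/tasks.py'
dirpath = os.path.abspath(
    os.path.join(
        module_file,
        os.pardir,  # -> /repo/llmfoundry/data/finetuning
        os.pardir,  # -> /repo/llmfoundry/data
        os.pardir,  # -> /repo/llmfoundry
        '.downloaded_finetuning',
    ),
)
print(dirpath)  # -> /repo/llmfoundry/.downloaded_finetuning
```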
24 changes: 0 additions & 24 deletions llmfoundry/utils/file_utils.py

This file was deleted.
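The deleted module held `dist_mkdtemp`, the helper whose import is removed from dataloader.py above. Its body is not shown in this view; one plausible shape for such a helper, sketched with composer's `dist` utilities, is below. This reconstruction is an assumption, not the deleted code.

```python
import tempfile

from composer.utils import dist


def dist_mkdtemp() -> str:
    """Sketch: create one temp dir per node, shared across its local ranks."""
    tempdir = None
    if dist.get_local_rank() == 0:
        tempdir = tempfile.mkdtemp()
    # Gather every rank's value, then take the one produced by this node's
    # local rank 0 (global ranks are contiguous within a node).
    all_tempdirs = dist.all_gather_object(tempdir)
    local_rank0_index = dist.get_global_rank() - dist.get_local_rank()
    return all_tempdirs[local_rank0_index]
```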

19 changes: 10 additions & 9 deletions tests/data/test_dataloader.py
@@ -30,6 +30,7 @@
     validate_target_settings,
 )
 from llmfoundry.data.finetuning.tasks import (
+    DOWNLOADED_FT_DATASETS_DIRPATH,
     HUGGINGFACE_FOLDER_EXTENSIONS,
     SUPPORTED_EXTENSIONS,
     dataset_constructor,
@@ -430,9 +431,9 @@ def test_finetuning_dataloader_safe_load(
     hf_name: str,
     hf_revision: Optional[str],
     expectation: ContextManager,
-    tmp_path: pathlib.Path,
 ):
+    # Clear the folder
+    shutil.rmtree(DOWNLOADED_FT_DATASETS_DIRPATH, ignore_errors=True)
     cfg = DictConfig({
         'dataset': {
             'hf_name': hf_name,

     tokenizer = build_tokenizer('gpt2', {})

-    with patch('llmfoundry.data.finetuning.tasks.tempfile.mkdtemp', return_value=str(tmp_path)):
-        with expectation:
-            _ = build_finetuning_dataloader(
-                tokenizer=tokenizer,
-                device_batch_size=1,
-                **cfg,
-            )
+    with expectation:
+        _ = build_finetuning_dataloader(
+            tokenizer=tokenizer,
+            device_batch_size=1,
+            **cfg,
+        )

     # If no raised errors, we should expect downloaded files with only safe file types.
     if isinstance(expectation, does_not_raise):
+        download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name)
         downloaded_files = [
-            file for _, _, files in os.walk(tmp_path) for file in files
+            file for _, _, files in os.walk(download_dir) for file in files
         ]
         assert len(downloaded_files) > 0
         assert all(
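Because `DOWNLOADED_FT_DATASETS_DIRPATH` is a fixed module-level path rather than a per-test `tempfile.mkdtemp` result, the test can no longer isolate itself by patching `mkdtemp`; instead it deletes the shared directory up front. If more tests come to share this cache, the same cleanup could live in a fixture. A hypothetical sketch (the fixture is an illustration, not part of this PR):

```python
import shutil

import pytest

from llmfoundry.data.finetuning.tasks import DOWNLOADED_FT_DATASETS_DIRPATH


@pytest.fixture
def clean_downloaded_ft_datasets():
    """Hypothetical fixture: run each test against an empty download cache."""
    shutil.rmtree(DOWNLOADED_FT_DATASETS_DIRPATH, ignore_errors=True)
    yield
    shutil.rmtree(DOWNLOADED_FT_DATASETS_DIRPATH, ignore_errors=True)
```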