From 7f7584a2fb20a62af65be1046b292b3d9e8cedb6 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 25 Sep 2023 10:35:14 -0700 Subject: [PATCH] Replace dashses with underscores to appease HF --- llmfoundry/data/finetuning/dataloader.py | 7 ++++++- llmfoundry/data/finetuning/tasks.py | 4 +++- tests/test_dataloader.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 661b1e808d..9688093714 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -299,6 +299,9 @@ def _build_hf_dataset_from_remote( 'downloaded_finetuning', cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not', ) + # HF datasets does not support a split with dashes, so we replace dashes + # with underscores in the destination split. + destination_split = cfg.dataset.split.replace('-', '_') os.makedirs(finetune_dir, exist_ok=True) for extension in supported_extensions: name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}' @@ -306,7 +309,9 @@ def _build_hf_dataset_from_remote( os.path.abspath( os.path.join( finetune_dir, 'data', - f'{cfg.dataset.split}-00000-of-00001.{extension}'))) + f'{destination_split}-00000-of-00001.{extension}'))) + print('HERE!!!', destination_split) + # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed') diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 0a2b386048..f2bd0239c8 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -319,7 +319,9 @@ def build_from_hf( Dataset: The tokenized dataset. """ dataset_name = cfg.hf_name - split = cfg.split + # HF datasets does not support a split with dashes,so we replace split + # dashes with underscore. + split = cfg.split.replace('-', '_') kwargs = cfg.get('hf_kwargs', {}) proto_preprocessing_fn = cfg.get('preprocessing_fn') if isinstance(proto_preprocessing_fn, dict) or isinstance( diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index eea887d663..6495eccf65 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -398,7 +398,7 @@ def mock_get_file(path: str, destination: str, overwrite: bool = False): make_tiny_ft_dataset(path=destination, size=16) -@pytest.mark.parametrize('split', ['train', 'custom', 'data']) +@pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data']) def test_finetuning_dataloader_custom_split_remote( tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch): tokenizer_name = 'gpt2'