diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 661b1e808d..5d4dfdbf85 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -293,11 +293,14 @@ def _build_hf_dataset_from_remote( FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions. """ supported_extensions = ['jsonl', 'csv', 'parquet'] + # HF datasets does not support a split with dashes, so we replace dashes + # with underscores in the destination split. + destination_split = cfg.dataset.split.replace('-', '_') finetune_dir = os.path.join( os.path.dirname( os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), 'downloaded_finetuning', - cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not', + destination_split if destination_split != 'data' else 'data_not', ) os.makedirs(finetune_dir, exist_ok=True) for extension in supported_extensions: @@ -306,7 +309,8 @@ def _build_hf_dataset_from_remote( os.path.abspath( os.path.join( finetune_dir, 'data', - f'{cfg.dataset.split}-00000-of-00001.{extension}'))) + f'{destination_split}-00000-of-00001.{extension}'))) + # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed') diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 0a2b386048..f2bd0239c8 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -319,7 +319,9 @@ def build_from_hf( Dataset: The tokenized dataset. """ dataset_name = cfg.hf_name - split = cfg.split + # HF datasets does not support a split with dashes,so we replace split + # dashes with underscore. + split = cfg.split.replace('-', '_') kwargs = cfg.get('hf_kwargs', {}) proto_preprocessing_fn = cfg.get('preprocessing_fn') if isinstance(proto_preprocessing_fn, dict) or isinstance( diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index eea887d663..6495eccf65 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -398,7 +398,7 @@ def mock_get_file(path: str, destination: str, overwrite: bool = False): make_tiny_ft_dataset(path=destination, size=16) -@pytest.mark.parametrize('split', ['train', 'custom', 'data']) +@pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data']) def test_finetuning_dataloader_custom_split_remote( tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch): tokenizer_name = 'gpt2'