Skip to content

Commit

Permalink
Replace dashses with underscores to appease HF
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Sep 25, 2023
1 parent 6883562 commit 7f7584a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
7 changes: 6 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,19 @@ def _build_hf_dataset_from_remote(
'downloaded_finetuning',
cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not',
)
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
destination = str(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{cfg.dataset.split}-00000-of-00001.{extension}')))
f'{destination_split}-00000-of-00001.{extension}')))
print('HERE!!!', destination_split)

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def build_from_hf(
Dataset: The tokenized dataset.
"""
dataset_name = cfg.hf_name
split = cfg.split
# HF datasets does not support a split with dashes,so we replace split
# dashes with underscore.
split = cfg.split.replace('-', '_')
kwargs = cfg.get('hf_kwargs', {})
proto_preprocessing_fn = cfg.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, dict) or isinstance(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def mock_get_file(path: str, destination: str, overwrite: bool = False):
make_tiny_ft_dataset(path=destination, size=16)


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
@pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data'])
def test_finetuning_dataloader_custom_split_remote(
tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch):
tokenizer_name = 'gpt2'
Expand Down

0 comments on commit 7f7584a

Please sign in to comment.