Skip to content

Commit

Permalink
Merge branch 'main' into expand_gqa
Browse files Browse the repository at this point in the history
  • Loading branch information
sashaDoubov authored Sep 26, 2023
2 parents c2b10fd + f488ad5 commit 9d56430
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
8 changes: 6 additions & 2 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,14 @@ def _build_hf_dataset_from_remote(
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not',
destination_split if destination_split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
Expand All @@ -306,7 +309,8 @@ def _build_hf_dataset_from_remote(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{cfg.dataset.split}-00000-of-00001.{extension}')))
f'{destination_split}-00000-of-00001.{extension}')))

# We don't know in advance which of the supported extensions the file will
# have, so wait on a signal file instead of the desired file.
signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def build_from_hf(
Dataset: The tokenized dataset.
"""
dataset_name = cfg.hf_name
split = cfg.split
# HF datasets does not support a split with dashes, so we replace dashes
# in the split name with underscores.
split = cfg.split.replace('-', '_')
kwargs = cfg.get('hf_kwargs', {})
proto_preprocessing_fn = cfg.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, dict) or isinstance(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def mock_get_file(path: str, destination: str, overwrite: bool = False):
make_tiny_ft_dataset(path=destination, size=16)


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
@pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data'])
def test_finetuning_dataloader_custom_split_remote(
tmp_path: pathlib.Path, split: str, monkeypatch: pytest.MonkeyPatch):
tokenizer_name = 'gpt2'
Expand Down

0 comments on commit 9d56430

Please sign in to comment.