From d3d2716387f9a30292dd0df9cc4eb54482691afd Mon Sep 17 00:00:00 2001
From: Aiden Grossman
Date: Fri, 15 Sep 2023 12:03:01 -0700
Subject: [PATCH] Fix streaming dataset setup for fine-tuning

Currently, when a StreamingFinetuningDataset is created using the
build_finetuning_dataloader method, the call fails because several
parameters that the StreamingFinetuningDataset constructor does not
accept are passed through to it. This patch fixes the parameter
mismatch and adds test coverage for this case.
---
 llmfoundry/data/finetuning/dataloader.py |  6 ---
 tests/test_dataloader.py                 | 59 ++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index a009f13660..6b1562c37f 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -125,17 +125,11 @@ def build_finetuning_dataloader(cfg: DictConfig,
             download_timeout=cfg.dataset.get('download_timeout', 60),
             validate_hash=cfg.dataset.get('validate_hash', None),
             keep_zip=cfg.dataset.get('keep_zip', False),
-            epoch_size=cfg.dataset.get('epoch_size', None),
             predownload=cfg.dataset.get('predownload', None),
-            cache_limit=cfg.dataset.get('cache_limit', None),
-            partition_algo=cfg.dataset.get('partition_algo', 'orig'),
             num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
             batch_size=device_batch_size,
             shuffle=cfg.dataset.get('shuffle', False),
-            shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
             shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
-            shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
-            sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
         )
 
         collate_fn, dataloader_batch_size = _build_collate_fn(
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 53549ccfe1..314c904676 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -13,6 +13,7 @@
 import torch
 from composer.utils import dist, using_torch_2
 from omegaconf import OmegaConf as om
+from streaming import MDSWriter
 
 from llmfoundry import (build_finetuning_dataloader,
                         build_text_denoising_dataloader)
@@ -42,6 +43,25 @@ def get_abs_data_path(data_local: str):
     return os.path.join(os.getcwd(), data_local)
 
 
+def build_mock_ft_streaming_dataset(data_path, split):
+    columns = {'prompt': 'str', 'response': 'str'}
+
+    dataset = [{
+        'prompt': 'This is just a test1',
+        'response': 'Hello World1'
+    }, {
+        'prompt': 'This is just a test2',
+        'response': 'Hello world2'
+    }]
+
+    output_path = os.path.join(data_path, split)
+
+    with MDSWriter(columns=columns, out=output_path,
+                   compression=None) as output_writer:
+        for sample in dataset:
+            output_writer.write(sample)
+
+
 @pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
 @pytest.mark.parametrize('pretokenize', [False, True])
 def test_correct_padding(tokenizer_name: str,
@@ -414,6 +434,45 @@ def test_finetuning_dataloader_custom_split_remote(
     _ = build_finetuning_dataloader(cfg, tokenizer, 4)
 
 
+def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
+    tokenizer_name = 'gpt2'
+    max_seq_len = 2048
+
+    remote_path = os.path.join(tmp_path, 'remote')
+    local_path = os.path.join(tmp_path, 'local')
+
+    build_mock_ft_streaming_dataset(remote_path, 'train')
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'remote': remote_path,
+            'local': local_path,
+            'split': 'train',
+            'max_seq_len': 2048,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name='gpt2',
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    _ = build_finetuning_dataloader(cfg, tokenizer, 4)
+
+
 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
 @pytest.mark.parametrize('add_bad_data_error', [True, False])
 def test_malformed_data(