From 16f92ef5153f3d4339bb8a758c50d6044f37be41 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Tue, 3 Dec 2024 15:28:06 -0800 Subject: [PATCH] Bugfix auto packing with streams + no remote path (#1679) --- llmfoundry/data/packing.py | 7 ++-- tests/data/test_packing.py | 65 ++++++++++++++++++++++++++++++++++---- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 5eacced549..6fed96a13d 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -474,7 +474,9 @@ def profile_packing( # If streaming dataset, use a temporary local folder for profiling local_rank_zero = dist.get_global_rank() - dist.get_local_rank() - if dataset_cfg.get('remote') is not None: + if dataset_cfg.get( + 'remote', + ) is not None and dataset_cfg.get('local') is None: tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] @@ -485,7 +487,8 @@ def profile_packing( tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] - stream_config['local'] = tmp_path + if stream_config.get('local') is None: + stream_config['local'] = tmp_path # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index 48713f8a19..8402694672 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from typing import Any +from typing import Any, Callable from unittest.mock import Mock, patch import pytest @@ -161,27 +161,73 @@ def test_dist_auto_packing(profile_packing: Mock): assert packing_ratio == 2 +def get_remote_config( + base_cfg: dict, + remote_dir: str, + local_dir: str, +) -> DictConfig: + return DictConfig({ + **base_cfg, + 'dataset': { + **base_cfg['dataset'], + 'remote': remote_dir, + 'local': local_dir, + }, + }) + + +def get_streams_config( + base_cfg: dict, + remote_dir: str, + local_dir: str, +) -> DictConfig: + return DictConfig({ + **base_cfg, + 'dataset': { + **base_cfg['dataset'], + 'streams': { + 'stream_with_remote': { + 'remote': remote_dir, + 'local': local_dir, + }, + 'stream_without_remote': { + 'local': remote_dir, + }, + }, + }, + }) + + def patched_packing_ratio(*args: Any, **kwargs: Any): from llmfoundry.data.packing import auto_packing_ratio return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4) +@pytest.mark.parametrize( + 'get_config', + [ + get_remote_config, + get_streams_config, + ], +) @patch( 'llmfoundry.data.finetuning.dataloader.auto_packing_ratio', patched_packing_ratio, ) -def test_auto_packing_with_streaming_dataloader(tmp_path: Path): +def test_auto_packing_with_streaming_dataloader( + get_config: Callable[[dict, str, str], DictConfig], + tmp_path: Path, +): columns = {'prompt': 'str', 'response': 'str'} tokenizer = build_tokenizer('gpt2', {}) remote_dir = str(tmp_path / 'remote') local_dir = str(tmp_path / 'local') with MDSWriter(out=remote_dir, columns=columns, compression=None) as out: out.write({'prompt': 'HELLO', 'response': 'WORLD'}) - cfg = DictConfig({ + + base_cfg = { 'dataset': { - 'remote': remote_dir, - 'local': local_dir, 'packing_ratio': 'auto', 'max_seq_len': 200, 'decoder_only_format': True, @@ -194,7 +240,9 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): 'prefetch_factor': None, 'persistent_workers': False, 'timeout': 0, - }) + } + + cfg = get_config(base_cfg, remote_dir, local_dir) loader = build_finetuning_dataloader( **cfg, @@ -214,7 +262,10 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): assert isinstance(loader.batch_size, int) assert loader.dataset.packing_ratio == int(loader.batch_size / 6) - state_dict = loader.dataset.state_dict(num_samples=2, from_beginning=False) + state_dict = loader.dataset.state_dict( + num_samples=2, + from_beginning=False, + ) assert state_dict['sample_in_epoch'] == 2 * loader.dataset.packing_ratio