
Commit

Fix streaming dataset setup for fine-tuning
Currently, when a StreamingFinetuningDataset is created via the
build_finetuning_dataloader function, the call fails because some
parameters that the StreamingFinetuningDataset constructor does not
accept are passed through to it. This patch fixes the parameter
mismatch and adds test coverage for this case.
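
For context, here is a minimal, self-contained sketch of the failure mode described above; it is not code from this repository, and ExampleDataset is a hypothetical stand-in for StreamingFinetuningDataset. Forwarding a keyword argument that a constructor does not accept raises a TypeError.

class ExampleDataset:
    # Hypothetical constructor that, like the real one at the time, accepts
    # only a subset of the keyword arguments the dataloader builder forwards.
    def __init__(self, remote: str, local: str, shuffle: bool = False):
        self.remote = remote
        self.local = local
        self.shuffle = shuffle

try:
    # 'epoch_size' is not a parameter of ExampleDataset.__init__, so this call fails.
    ExampleDataset(remote='s3://bucket/data', local='/tmp/data', epoch_size=None)
except TypeError as err:
    print(err)  # e.g. "__init__() got an unexpected keyword argument 'epoch_size'"
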
boomanaiden154 committed Sep 15, 2023
1 parent 7ec2fe0 commit d3d2716
Showing 2 changed files with 59 additions and 6 deletions.
6 changes: 0 additions & 6 deletions llmfoundry/data/finetuning/dataloader.py
@@ -125,17 +125,11 @@ def build_finetuning_dataloader(cfg: DictConfig,
download_timeout=cfg.dataset.get('download_timeout', 60),
validate_hash=cfg.dataset.get('validate_hash', None),
keep_zip=cfg.dataset.get('keep_zip', False),
epoch_size=cfg.dataset.get('epoch_size', None),
predownload=cfg.dataset.get('predownload', None),
cache_limit=cfg.dataset.get('cache_limit', None),
partition_algo=cfg.dataset.get('partition_algo', 'orig'),
num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
batch_size=device_batch_size,
shuffle=cfg.dataset.get('shuffle', False),
shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
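
The hunk above removes several keyword arguments that were being forwarded to the StreamingFinetuningDataset constructor. As a general defensive pattern (not what this commit does, and not a helper that exists in llm-foundry), a config dict can also be filtered against a constructor's signature before the call; a hypothetical sketch:

import inspect
from typing import Any, Dict

def filter_kwargs_for(cls: type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only the keyword arguments that cls.__init__ explicitly declares.
    accepted = set(inspect.signature(cls.__init__).parameters) - {'self'}
    return {k: v for k, v in kwargs.items() if k in accepted}

# Hypothetical usage:
#   dataset = StreamingFinetuningDataset(**filter_kwargs_for(StreamingFinetuningDataset, dataset_kwargs))
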
59 changes: 59 additions & 0 deletions tests/test_dataloader.py
@@ -13,6 +13,7 @@
import torch
from composer.utils import dist, using_torch_2
from omegaconf import OmegaConf as om
from streaming import MDSWriter

from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
@@ -42,6 +43,25 @@ def get_abs_data_path(data_local: str):
return os.path.join(os.getcwd(), data_local)


def build_mock_ft_streaming_dataset(data_path, split):
columns = {'prompt': 'str', 'response': 'str'}

dataset = [{
'prompt': 'This is just a test1',
'response': 'Hello World1'
}, {
'prompt': 'This is just a test2',
'response': 'Hello world2'
}]

output_path = os.path.join(data_path, split)

with MDSWriter(columns=columns, out=output_path,
compression=None) as output_writer:
for sample in dataset:
output_writer.write(sample)


@pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@pytest.mark.parametrize('pretokenize', [False, True])
def test_correct_padding(tokenizer_name: str,
@@ -414,6 +434,45 @@ def test_finetuning_dataloader_custom_split_remote(
_ = build_finetuning_dataloader(cfg, tokenizer, 4)


def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
tokenizer_name = 'gpt2'
max_seq_len = 2048

remote_path = os.path.join(tmp_path, 'remote')
local_path = os.path.join(tmp_path, 'local')

build_mock_ft_streaming_dataset(remote_path, 'train')

cfg = {
'name': 'finetuning',
'dataset': {
'remote': remote_path,
'local': local_path,
'split': 'train',
'max_seq_len': 2048,
'decoder_only_format': True,
'allow_pad_trimming': False,
'packing_ratio': None,
'shuffle': True,
},
'drop_last': False,
'num_workers': 4,
'pin_memory': False,
'prefetch_factor': 2,
'persistent_workers': False,
'timeout': 0
}

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

_ = build_finetuning_dataloader(cfg, tokenizer, 4)


@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
@pytest.mark.parametrize('add_bad_data_error', [True, False])
def test_malformed_data(
