Fix packing + streaming + resumption (#1277)
dakinggg authored Jun 14, 2024
1 parent 4350990 commit dbd798e
Showing 3 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -222,7 +222,7 @@ def build_finetuning_dataloader(
cache_limit=dataset_cfg.get('cache_limit', None),
partition_algo=dataset_cfg.get('partition_algo', 'relaxed'),
num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None),
- batch_size=dataset_batch_size,
+ batch_size=dataloader_batch_size,
shuffle=dataset_cfg.get('shuffle', False),
shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'),
shuffle_seed=dataset_cfg.get('shuffle_seed', 9176),
@@ -233,6 +233,7 @@
max_seq_len=dataset_cfg['max_seq_len'],
allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False),
replication=replication_factor,
+ packing_ratio=dataloader_batch_size / dataset_batch_size,
)

else:
@@ -390,6 +391,7 @@ def _validate_config(
'allow_pad_trimming',
'seq_parallel_replication',
'auto_packing_replication',
+ 'max_leftover_bins_to_keep',
}
if not set(kwargs.keys()).issubset(allowed_additional_kwargs):
raise ValueError(
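Note on the dataloader change above: when packing is enabled, the DataLoader reads dataloader_batch_size raw samples per batch and the collator packs them down to dataset_batch_size packed samples, so the ratio of the two recovers the packing ratio that this commit now forwards to the streaming dataset. A rough sketch of that arithmetic with made-up numbers (the values below are illustrative only, not taken from this diff):

    # Illustrative numbers only: with a packing ratio of 4, the DataLoader must
    # read 4x more raw samples per batch than the number of packed samples the
    # model ultimately sees.
    dataset_batch_size = 8       # packed samples per batch seen by the model
    packing_ratio = 4.0          # raw samples folded into one packed sample
    dataloader_batch_size = int(dataset_batch_size * packing_ratio)  # 32 raw samples read per batch

    # The new kwarg recovers the ratio from the two batch sizes:
    assert dataloader_batch_size / dataset_batch_size == packing_ratio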
12 changes: 12 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
@@ -592,6 +592,7 @@ def __init__(
max_seq_len: int = 2048,
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
+ packing_ratio: Optional[float] = None,
**kwargs: Any,
):

@@ -644,6 +645,7 @@ def __init__(

self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
+ self.packing_ratio = packing_ratio

# How to process a sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
@@ -675,6 +677,16 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
return {'turns': [sample]}
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)

+ def state_dict(self, num_samples: int,
+                from_beginning: bool) -> Dict[str, Any]:
+     if self.packing_ratio is not None:
+         num_samples = int(self.packing_ratio * num_samples)
+
+     return super().state_dict(
+         num_samples=num_samples,
+         from_beginning=from_beginning,
+     )


class DatasetConstructor:

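The state_dict override above is the resumption fix: the trainer counts progress in packed samples, while the underlying StreamingDataset counts raw samples, so the saved sample count is scaled up by packing_ratio before being handed to the parent class. A minimal, self-contained sketch of the same scaling (toy classes standing in for streaming.StreamingDataset and the finetuning dataset, not the llm-foundry implementation):

    from typing import Any, Dict, Optional

    class ToyStreamingDataset:
        """Stand-in for streaming.StreamingDataset: just echoes the request."""

        def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
            return {'sample_in_epoch': num_samples, 'from_beginning': from_beginning}

    class ToyFinetuningDataset(ToyStreamingDataset):
        def __init__(self, packing_ratio: Optional[float] = None):
            self.packing_ratio = packing_ratio

        def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
            # Each packed sample the trainer has consumed corresponds to
            # packing_ratio raw samples pulled from the stream.
            if self.packing_ratio is not None:
                num_samples = int(self.packing_ratio * num_samples)
            return super().state_dict(num_samples=num_samples, from_beginning=from_beginning)

    # With a packing ratio of 4, checkpointing after 2 packed samples should
    # fast-forward the stream by 8 raw samples on resumption.
    assert ToyFinetuningDataset(packing_ratio=4.0).state_dict(
        num_samples=2, from_beginning=False)['sample_in_epoch'] == 8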
10 changes: 10 additions & 0 deletions tests/data/test_packing.py
@@ -14,6 +14,7 @@
from torch.utils.data import DataLoader

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
+ from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.utils.builders import build_tokenizer

@@ -206,6 +207,15 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
if batch_ix >= 3:
break

+ assert isinstance(loader, DataLoader)
+ assert isinstance(loader.dataset, StreamingFinetuningDataset)
+ assert loader.dataset.packing_ratio is not None
+ assert isinstance(loader.batch_size, int)
+ assert loader.dataset.packing_ratio == int(loader.batch_size / 6)
+
+ state_dict = loader.dataset.state_dict(num_samples=2, from_beginning=False)
+ assert state_dict['sample_in_epoch'] == 2 * loader.dataset.packing_ratio


@pytest.mark.parametrize('packing_ratio', ['auto', 2.0])
@patch(