diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 639beba6f0..bbde8c4629 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -233,6 +233,7 @@ def build_finetuning_dataloader(
             max_seq_len=dataset_cfg['max_seq_len'],
             allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False),
             replication=replication_factor,
+            packing_ratio=dataloader_batch_size / dataset_batch_size,
         )
 
     else:
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 40f178fb6e..67102200cb 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -592,6 +592,7 @@ def __init__(
         max_seq_len: int = 2048,
         allow_unsafe_types: bool = False,
         replication: Optional[int] = None,
+        packing_ratio: Optional[float] = None,
         **kwargs: Any,
     ):
 
@@ -644,6 +645,7 @@ def __init__(
 
         self.tokenizer = tokenizer
         self.max_seq_len = max_seq_len
+        self.packing_ratio = packing_ratio
 
     # How to process a sample
     def __getitem__(self, idx: int) -> Dict[str, Any]:
@@ -675,6 +677,15 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
             return {'turns': [sample]}
         return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
 
+    def state_dict(self, num_samples: int,
+                   from_beginning: bool) -> Dict[str, Any]:
+        if self.packing_ratio is not None:
+            num_samples = int(self.packing_ratio * num_samples)
+
+        return super().state_dict(
+            num_samples=num_samples, from_beginning=from_beginning
+        )
+
 
 class DatasetConstructor:
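
Context for the change above: with sequence packing, each sample the dataloader yields is assembled from roughly packing_ratio raw dataset samples, so on checkpoint resumption the streaming dataset has to be fast-forwarded by the raw-sample count rather than the packed-sample count that checkpointing reports. Below is a minimal sketch of that arithmetic, not part of the diff; it assumes (per the first hunk) that dataloader_batch_size is the raw-sample batch size handed to the DataLoader and dataset_batch_size is the packed per-device batch size, and the concrete numbers are purely illustrative.

    # Sketch only; all values here are assumptions for illustration.
    dataloader_batch_size = 32   # raw samples pulled per batch, before packing
    dataset_batch_size = 8       # packed samples the model sees per batch
    packing_ratio = dataloader_batch_size / dataset_batch_size  # 4.0 raw samples per packed sample

    packed_samples_seen = 1_000  # sample count recorded at checkpoint time (packed)
    raw_samples_seen = int(packing_ratio * packed_samples_seen)  # 4_000 raw samples to skip on resume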