diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py index 42af7e9375..264e79a2e1 100644 --- a/llmfoundry/data/finetuning/collator.py +++ b/llmfoundry/data/finetuning/collator.py @@ -224,6 +224,10 @@ class Seq2SeqFinetuningCollator: sizes. Default: ``False`` ensures that all sequences are max_seq_len. batch_metadata (dict, optional): A dictionary of metadata which will be added to the batch. + pad_to_longest (bool, optional): Whether to pad to the longest sequence, + which may result in smaller but inconsistent batch sizes. This is + primarily used to profile packing. + Default: ``False`` ensures that all sequences are max_seq_len. """ def __init__( @@ -235,6 +239,7 @@ def __init__( target_prompts: str = 'none', allow_pad_trimming: bool = False, batch_metadata: Optional[Dict[str, Any]] = None, + pad_to_longest: bool = False, ): self.tokenizer = tokenizer self.max_seq_len = max_seq_len @@ -247,6 +252,8 @@ def __init__( self._allow_pad_trimming = allow_pad_trimming self._seen_first_batch = False + self._pad_to_longest = pad_to_longest + illegal_keys = [ 'input_ids', 'labels', @@ -320,24 +327,29 @@ def _process_and_batch_decoder_only( ) -> Dict[str, torch.Tensor]: # Steps explained in comments processed_examples = [] - for example in examples: - input_ids, labels = stitch_turns_decoder_only( - example_turns=example['turns'], - target_prompts=self.target_prompts, - target_responses=self.target_responses, - eos_token_id=self.tokenizer.eos_token_id, - ) + input_ids_and_labels = [stitch_turns_decoder_only( + example_turns=example['turns'], + target_prompts=self.target_prompts, + target_responses=self.target_responses, + eos_token_id=self.tokenizer.eos_token_id, + ) for example in examples] + + if self._pad_to_longest: + max_seq_len = max([len(input_ids) for input_ids, _ in input_ids_and_labels]) + else: + max_seq_len = self.max_seq_len + for input_ids, labels in input_ids_and_labels: orig_size = len(input_ids) # We may need to truncate the input_ids / labels in order to maintain max_seq_len - if orig_size > self.max_seq_len: - input_ids = input_ids[:self.max_seq_len] - labels = labels[:self.max_seq_len] + if orig_size > max_seq_len: + input_ids = input_ids[:max_seq_len] + labels = labels[:max_seq_len] # Check to make sure there are still loss-generating tokens. Error if not. if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0: raise ValueError( - f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' +\ + f'Truncating to max_seq_len={max_seq_len} has removed all loss-generating tokens. ' +\ f'Pre-truncation sequence length was {orig_size}. ' +\ 'This sample should have been filtered out before reaching the collator. If using ' +\ 'pre-tokenized streaming data, this may have resulted from using different ' +\ @@ -348,7 +360,7 @@ def _process_and_batch_decoder_only( # Still issue a warning when truncating if not self._warned_truncated: warnings.warn( - f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' +\ + f'Truncating sequence of length={orig_size} to fit max_seq_len={max_seq_len}. ' +\ f'If truncation is a problem, consider increasing max_seq_len.', ) self._warned_truncated = True @@ -358,7 +370,7 @@ def _process_and_batch_decoder_only( # Annoyingly, we need to pad everything but input_ids # and attention_mask ourselves n_total = len(input_ids) - i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total) + i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - n_total) if self.tokenizer.padding_side == 'left': labels = i_pad + labels else: @@ -376,7 +388,7 @@ def _process_and_batch_decoder_only( batch = self.tokenizer.pad( processed_examples, padding='max_length', - max_length=self.max_seq_len, + max_length=max_seq_len, return_tensors='pt', ) @@ -410,35 +422,42 @@ def _process_and_batch_encoder_decoder( # The encoder-decoder case is has some gotchas. # Steps are explained in comments. processed_examples = [] - for example in examples: - context, target = stitch_turns_encoder_decoder( - example_turns=example['turns'], - eos_token_id=self.tokenizer.eos_token_id, - ) + contexts_and_targets = [stitch_turns_encoder_decoder( + example_turns=example['turns'], + eos_token_id=self.tokenizer.eos_token_id, + ) for example in examples] + + if self._pad_to_longest: + max_seq_len = 0 + for context, target in contexts_and_targets: + max_seq_len = max(max_seq_len, len(context), len(target)) + else: + max_seq_len = self.max_seq_len + for context, target in contexts_and_targets: # We need to pad labels ourselves. Because HF. - if len(target) < self.max_seq_len: - i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target)) + if len(target) < max_seq_len: + i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - len(target)) target = target + i_pad else: if not self._warned_target: warnings.warn( f'Truncating TARGET sequence of length={len(target)} ' +\ - f'to max_seq_len={self.max_seq_len}. If truncation is ' +\ + f'to max_seq_len={max_seq_len}. If truncation is ' +\ f'a problem, consider increasing max_seq_len.') self._warned_target = True - target = target[:self.max_seq_len - + target = target[:max_seq_len - 1] + [self.tokenizer.eos_token_id] # We might need to truncate the context. Preserve the beginning. - if len(context) > self.max_seq_len: + if len(context) > max_seq_len: if not self._warned_context: warnings.warn( f'Truncating CONTEXT sequence of length={len(context)} ' +\ - f'to max_seq_len={self.max_seq_len}. If truncation is ' +\ + f'to max_seq_len={max_seq_len}. If truncation is ' +\ f'a problem, consider increasing max_seq_len.') self._warned_context = True - context = context[:self.max_seq_len - + context = context[:max_seq_len - 1] + [self.tokenizer.eos_token_id] # Back into the example @@ -454,7 +473,7 @@ def _process_and_batch_encoder_decoder( batch = self.tokenizer.pad( processed_examples, padding='max_length', - max_length=self.max_seq_len, + max_length=max_seq_len, return_tensors='pt', ) # We're still missing decoder_input_ids and decoder_attention_mask diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 11104ac706..41fba1e607 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -392,6 +392,7 @@ def _validate_config( 'seq_parallel_replication', 'auto_packing_replication', 'max_leftover_bins_to_keep', + 'pad_to_longest', } if not set(kwargs.keys()).issubset(allowed_additional_kwargs): raise ValueError( @@ -590,6 +591,7 @@ def build_collate_fn( max_seq_len = dataset_cfg['max_seq_len'] decoder_only_format = dataset_cfg['decoder_only_format'] allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False) + pad_to_longest = dataset_cfg.get('pad_to_longest', False) collate_fn = Seq2SeqFinetuningCollator( tokenizer=tokenizer, @@ -598,6 +600,7 @@ def build_collate_fn( target_responses=target_responses, target_prompts=target_prompts, allow_pad_trimming=allow_pad_trimming, + pad_to_longest=pad_to_longest, ) packing_ratio = dataset_cfg.get('packing_ratio') diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index a6fdf34953..2552e96386 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -20,7 +20,17 @@ class BinPackCollator: - """Utility collator for packing to reduce padding.""" + """Utility collator for packing to reduce padding. + + Args: + collator (Callable): The base collator to use. + target_batch_size (int): The number of bins. + max_seq_len(int): The maximum sequence length of a bin. + pad_token_id (int): The padding token id. + padding_side (Literal['left', 'right']): The side to pad on. + max_leftover_bins_to_keep (Optional[int]): The number of leftover bins to keep. + is_profiling (bool): Whether the collator is being used for profiling. + """ def __init__( self, @@ -30,6 +40,7 @@ def __init__( pad_token_id: int, padding_side: Literal['left', 'right'], max_leftover_bins_to_keep: Optional[int] = None, + is_profiling: bool = False, ): self.base_collator = collator self.out_size = int(target_batch_size) @@ -56,6 +67,8 @@ def __init__( self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = [] + self._is_profiling = is_profiling + @property def waste(self) -> float: return 1 - (self.n_packed_tokens / self.n_total_tokens) @@ -86,13 +99,15 @@ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ] # Cut everything down to size sizes, trimmed_examples = _trim_batch(batch) - return self._pack_trimmed_examples(trimmed_examples, sizes) + batch = self._pack_trimmed_examples(trimmed_examples, sizes) + assert batch is not None + return batch def _pack_trimmed_examples( self, trimmed_examples: List[Dict[str, torch.Tensor]], sizes: List[int], - ) -> Dict[str, torch.Tensor]: + ) -> Optional[Dict[str, torch.Tensor]]: """Packs trimmed examples into fixed-size bins and repads them. Args: @@ -103,7 +118,7 @@ def _pack_trimmed_examples( Dict[str, torch.Tensor]: A batch of repadded examples ready for processing """ # Apply our CS 101 bin packing algorithm. - packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing( + packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = self._first_fit_bin_packing( sizes=sizes, examples=trimmed_examples, num_bins=self.out_size, @@ -115,15 +130,136 @@ def _pack_trimmed_examples( self.n_packed_examples += self.out_size self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep] + if self._is_profiling: + return None + # Re-pad to max_seq_len and batch - batch = _repad( - packed_examples, - max_seq_len=self.max_seq_len, - pad_token_id=self.pad_token_id, - padding_side=self.padding_side, - ) + batch = self._convert_to_batch(packed_examples) return batch + def _convert_to_batch( + self, + packed_examples: List[Dict[str, torch.Tensor]], + ) -> Dict[str, torch.Tensor]: + + pad_vals = { + 'input_ids': self.pad_token_id, + 'labels': -100, + 'attention_mask': 0, + 'sequence_id': -1, + } + keys = packed_examples[0].keys() + batch = {} + for key in keys: + batch[key] = torch.stack([ + _pad_tensor(example[key], pad_vals[key], self.max_seq_len, self.padding_side) + for example in packed_examples + ]) + return batch + + def _first_fit_bin_packing( + self, + sizes: List[int], + examples: List[Dict[str, torch.Tensor]], + num_bins: int, + max_bin_size: int, + existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]], + ) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[ + str, torch.Tensor]]]]: + + # Will contain tuples (bin_size_size, packed_example) + bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins + + starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins]) + + sizes_and_examples = list(zip(sizes, examples)) + sorted_sizes_and_examples = sorted( + sizes_and_examples, + key=lambda x: x[0], + reverse=True, + ) + + required_num_examples = max(0, num_bins - len(bins)) + num_examples = len(sizes) + if num_examples < required_num_examples: + for size, example in sorted_sizes_and_examples: + # Can't keep packing. All remaining items get their own bin. + bins.append((size, example)) + + total_bin_sizes = sum([bin_size for bin_size, _ in bins]) + total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes + total_example_sizes = sum(sizes) + if total_new_bin_sizes != total_example_sizes: + raise AssertionError( + f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.', + ) + + sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) + bin_sizes, packed_examples = [], [] + for bin_size, packed_example in sorted_bins: + bin_sizes.append(bin_size) + packed_examples.append(packed_example) + + # Return: + # - the num_bins largest packed examples + # - the total tokens in those examples + # - the total size of all new examples + # - leftover bins + return packed_examples[:num_bins], sum( + bin_sizes[:num_bins], + ), sum(sizes), sorted_bins[num_bins:] + + # Go through each item from longest to shortest. + # Note: all items will either go into an existing or new bin. + for i, (size, example) in enumerate(sorted_sizes_and_examples): + # If we can't keep packing, all remaining items get their own bin. + required_num_examples = max(0, num_bins - len(bins)) + n_remaining = num_examples - i + assert n_remaining >= required_num_examples + if n_remaining == required_num_examples: + # Can't keep packing. All remaining items get their own bin. + bins.append((size, example)) + continue + + # Add it to the first bin it fits in + added = False + for bidx in range(len(bins)): + if bins[bidx][0] + size <= max_bin_size: + bin_size, packed_example = bins.pop(bidx) + bin_size = bin_size + size + if not self._is_profiling: + packed_example = _combine_in_place(packed_example, example) + + bins.append((bin_size, packed_example)) + added = True + break + # If it didn't fit anywhere, open a new bin + if not added: + bins.append((size, example)) + + total_bin_sizes = sum([bin_size for bin_size, _ in bins]) + total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes + total_example_sizes = sum(sizes) + if total_new_bin_sizes != total_example_sizes: + raise AssertionError( + f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.', + ) + + sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) + bin_sizes, packed_examples = [], [] + for bin_size, packed_example in sorted_bins: + bin_sizes.append(bin_size) + packed_examples.append(packed_example) + + # Return: + # - the num_bins largest packed examples + # - the total tokens in those examples + # - the total size of all new examples + # - leftover bins + return packed_examples[:num_bins], sum( + bin_sizes[:num_bins], + ), sum(sizes), sorted_bins[num_bins:] + def _trim_batch( batch: Dict[str, torch.Tensor], @@ -176,144 +312,20 @@ def _combine_in_place( example[k] = torch.cat([example[k], add_on[k]]) return example - -def _first_fit_bin_packing( - sizes: List[int], - examples: List[Dict[str, torch.Tensor]], - num_bins: int, - max_bin_size: int, - existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]], -) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[ - str, torch.Tensor]]]]: - - # Will contain tuples (bin_size_size, packed_example) - bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins - - starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins]) - - sizes_and_examples = list(zip(sizes, examples)) - sorted_sizes_and_examples = sorted( - sizes_and_examples, - key=lambda x: x[0], - reverse=True, - ) - - required_num_examples = max(0, num_bins - len(bins)) - num_examples = len(sizes) - if num_examples < required_num_examples: - for size, example in sorted_sizes_and_examples: - # Can't keep packing. All remaining items get their own bin. - bins.append((size, example)) - - total_bin_sizes = sum([bin_size for bin_size, _ in bins]) - total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes - total_example_sizes = sum(sizes) - if total_new_bin_sizes != total_example_sizes: - raise AssertionError( - f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.', - ) - - sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) - bin_sizes, packed_examples = [], [] - for bin_size, packed_example in sorted_bins: - bin_sizes.append(bin_size) - packed_examples.append(packed_example) - - # Return: - # - the num_bins largest packed examples - # - the total tokens in those examples - # - the total size of all new examples - # - leftover bins - return packed_examples[:num_bins], sum( - bin_sizes[:num_bins], - ), sum(sizes), sorted_bins[num_bins:] - - # Go through each item from longest to shortest. - # Note: all items will either go into an existing or new bin. - for i, (size, example) in enumerate(sorted_sizes_and_examples): - # If we can't keep packing, all remaining items get their own bin. - required_num_examples = max(0, num_bins - len(bins)) - n_remaining = num_examples - i - assert n_remaining >= required_num_examples - if n_remaining == required_num_examples: - # Can't keep packing. All remaining items get their own bin. - bins.append((size, example)) - continue - - # Add it to the first bin it fits in - added = False - for bidx in range(len(bins)): - if bins[bidx][0] + size <= max_bin_size: - bin_size, packed_example = bins.pop(bidx) - bin_size = bin_size + size - packed_example = _combine_in_place(packed_example, example) - bins.append((bin_size, packed_example)) - added = True - break - # If it didn't fit anywhere, open a new bin - if not added: - bins.append((size, example)) - - total_bin_sizes = sum([bin_size for bin_size, _ in bins]) - total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes - total_example_sizes = sum(sizes) - if total_new_bin_sizes != total_example_sizes: - raise AssertionError( - f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.', - ) - - sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) - bin_sizes, packed_examples = [], [] - for bin_size, packed_example in sorted_bins: - bin_sizes.append(bin_size) - packed_examples.append(packed_example) - - # Return: - # - the num_bins largest packed examples - # - the total tokens in those examples - # - the total size of all new examples - # - leftover bins - return packed_examples[:num_bins], sum( - bin_sizes[:num_bins], - ), sum(sizes), sorted_bins[num_bins:] - - -def _repad( - packed_examples: List[Dict[str, torch.Tensor]], - max_seq_len: int, - pad_token_id: int, - padding_side: str, -) -> Dict[str, torch.Tensor]: - - def pad_tensor(tensor: torch.Tensor, pad_value: int): - if len(tensor) == max_seq_len: - return tensor - t = torch.full((max_seq_len,), - pad_value, - dtype=tensor.dtype, - device=tensor.device) - if padding_side == 'left': - t[-len(tensor):] = tensor - elif padding_side == 'right': - t[:len(tensor)] = tensor - else: - raise ValueError(f'Unknown {padding_side=}') - return t - - pad_vals = { - 'input_ids': pad_token_id, - 'labels': -100, - 'attention_mask': 0, - 'sequence_id': -1, - } - keys = packed_examples[0].keys() - batch = {} - for key in keys: - batch[key] = torch.stack([ - pad_tensor(example[key], pad_vals[key]) - for example in packed_examples - ]) - return batch +def _pad_tensor(tensor: torch.Tensor, pad_value: int, max_seq_len: int, padding_side: str): + if len(tensor) == max_seq_len: + return tensor + t = torch.full((max_seq_len,), + pad_value, + dtype=tensor.dtype, + device=tensor.device) + if padding_side == 'left': + t[-len(tensor):] = tensor + elif padding_side == 'right': + t[:len(tensor)] = tensor + else: + raise ValueError(f'Unknown {padding_side=}') + return t def auto_packing_ratio( @@ -428,24 +440,23 @@ def profile_packing( 'prefetch_factor': None, 'persistent_workers': False, }) - dataloader_cfg['dataset']['packing_ratio'] = 1.0 - dataloader_cfg['dataset']['auto_packing_replication' - ] = dataloader_cfg['dataset'].get( - 'seq_parallel_replication', - 1, - ) or 1 - dataloader_cfg['dataset']['seq_parallel_replication'] = 1 + dataset_cfg = dataloader_cfg['dataset'] + dataset_cfg['packing_ratio'] = 1.0 + seq_parallel_replication = dataset_cfg.get('seq_parallel_replication', 1) + dataset_cfg['auto_packing_replication'] = seq_parallel_replication or 1 + dataset_cfg['seq_parallel_replication'] = 1 + dataset_cfg['pad_to_longest'] = True # If streaming dataset, use a temporary local folder for profiling local_rank_zero = dist.get_global_rank() - dist.get_local_rank() - if dataloader_cfg['dataset'].get('remote') is not None: + if dataset_cfg.get('remote') is not None: tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] - dataloader_cfg['dataset']['local'] = tmp_path + dataset_cfg['local'] = tmp_path - if dataloader_cfg['dataset'].get('streams') is not None: - for stream_config in dataloader_cfg['dataset']['streams'].values(): + if dataset_cfg.get('streams') is not None: + for stream_config in dataset_cfg['streams'].values(): tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] @@ -492,6 +503,7 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]: pad_token_id=0, # <-- Doesn't need to be correct for profiling padding_side='left', # <-- Doesn't need to be correct for profiling max_leftover_bins_to_keep=max_leftovers_to_keep, + is_profiling=False, ) # Simulate feeding the packing collator a bunch of data