diff --git a/llmfoundry/models/llm_embed/modeling_llm_embed.py b/llmfoundry/models/llm_embed/modeling_llm_embed.py
index 694d46d184..eba4863c3a 100644
--- a/llmfoundry/models/llm_embed/modeling_llm_embed.py
+++ b/llmfoundry/models/llm_embed/modeling_llm_embed.py
@@ -62,7 +62,6 @@ class ContrastiveConfig:
         temperature (Union[int, float], optional): Temperature for InfoNCE Loss. Defaults to 1.
         vector_representation (str, optional): The vector representation to use. Defaults to 'avg'.
         normalize_output (bool, optional): Whether to normalize the output. Defaults to True.
-        pos_step_size (int, optional): The step size for positive samples. Defaults to 2.
         gather_in_batch_negatives (bool, optional): Whether to call all_gather on all samples in global batch
         use_legacy_gradient_passthrough (bool, optional): Whether to use the legacy gradient passthrough. Defaults to False.
         infonce_process_group_size (int, optional): The size of the process group for InfoNCE loss. Defaults to None.
@@ -70,7 +69,7 @@ class ContrastiveConfig:
     temperature: Union[int, float] = 1
     vector_representation: str = 'avg'
     normalize_output: bool = True
-    pos_step_size: int = 2
+    pos_step_size: int = -1  # Keep for backwards compatibility
     gather_in_batch_negatives: bool = False
     use_legacy_gradient_passthrough: bool = False
     infonce_process_group_size: Optional[int] = None
@@ -157,7 +156,7 @@ def __init__(
         self.vector_representation = contrastive_config_obj.vector_representation
         self.normalize_output = contrastive_config_obj.normalize_output
-        self.step_size = contrastive_config_obj.pos_step_size
+        self.step_size = None
         self.gather_in_batch_negatives = contrastive_config_obj.gather_in_batch_negatives
         self.use_legacy_gradient_passthrough = contrastive_config_obj.use_legacy_gradient_passthrough
         self.n_active_params = sum(p.numel() for p in self.parameters())
@@ -209,6 +208,21 @@ def construct_model(self):
             self.is_mpt = True
         return model
 
+    def _update_step_size_if_needed(self, batch: MutableMapping) -> None:
+        """Update step size on the first batch if we detect hard negatives."""
+        if self.step_size:
+            return
+
+        input_shape = batch['input_ids'].shape
+        if input_shape[1] > 2:
+            # We have hard negatives; batch shape is [batch, query + positive passage + negative passages, tokens].
+            self.step_size = input_shape[1]
+            log.info(
+                f'Detected hard negatives, updated step_size to {self.step_size}',
+            )
+        else:
+            self.step_size = 2
+
     def format_queries_batch(
         self,
         batch: MutableMapping,
@@ -219,13 +233,12 @@ def format_queries_batch(
         Here ``n`` is the step size, which represents the number of hard
         negatives per passage.
         """
+        assert self.step_size
         queries = {}
+        indices = list(range(0, batch['input_ids'].size(0), self.step_size))
         for key in batch:
-            # Select every `step_size`-th entry from the batch for the given key
-            queries[key] = batch[key][0::self.step_size, :]
-
-        # Select every `step_size`-th entry from `last_hidden_state` along the batch dimension
-        return queries, last_hidden_state[0::self.step_size, :, :]
+            queries[key] = batch[key][indices, :]
+        return queries, last_hidden_state[indices, :, :]
 
     def format_passages_batch(
         self,
@@ -237,24 +250,22 @@ def format_passages_batch(
         Here ``n`` is the step size, which represents the number of hard
         negatives per passage.
         """
+        assert self.step_size
         passages = {}
-
-        # Index on a variable step size
-        index = 0
+        num_blocks = batch['input_ids'].size(0) // self.step_size
+        index = torch.arange(
+            1,
+            num_blocks * self.step_size + 1,
+            device=last_hidden_state.device,
+        ).view(num_blocks, self.step_size)
+        index = index[:, :self.step_size - 1].reshape(-1)
         for key in batch:
-            num_blocks = batch[key].size(0) // self.step_size
-            index = torch.arange(
-                1,
-                num_blocks * self.step_size + 1,
-                device=last_hidden_state.device,
-            ).view(num_blocks, self.step_size)
-            index = index[:, :self.step_size - 1].reshape(-1)
             passages[key] = batch[key][index]
-
         return passages, last_hidden_state[index, :, :]
 
     def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
         # Collapse pairs into the batch dimension
+        self._update_step_size_if_needed(batch)
         collapse_dims = lambda x: rearrange(x, 'b p d -> (b p) d') if \
             len(x.shape) > 2 else x