diff --git a/src/slam_llm/models/slam_model.py b/src/slam_llm/models/slam_model.py
index 6aa310a9..b0547278 100644
--- a/src/slam_llm/models/slam_model.py
+++ b/src/slam_llm/models/slam_model.py
@@ -367,13 +367,14 @@ def forward(self,
                     inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
 
         if modality_mask is not None:
-            modality_unmask_start = (modality_mask == True).float().argmax(dim=1)
-
+            modality_unmask_start_indices = (modality_mask == True).float().argmax(dim=1)
+            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
+
             encoder_outs_pad = torch.zeros_like(inputs_embeds)
             for i in range(encoder_outs.shape[0]):
                 encoder_outs_pad[
-                    i, modality_unmask_start[i]:modality_unmask_start[i]+modality_mask[i].sum().item()
-                ] = encoder_outs[i]
+                    i, modality_unmask_start_indices[i] : modality_unmask_start_indices[i] + modality_lengths[i]
+                ] = encoder_outs[i][:modality_lengths[i]]
 
             inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
 
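
For reviewers, here is a minimal standalone sketch of the patched splice logic. It is not code from the repository: the tensor shapes, the sample data, and local names such as `start_indices` and `lengths` are illustrative. The point it demonstrates is the one the patch makes: per-sample mask lengths are clamped to the encoder's actual output length, so a modality mask longer than `encoder_outs` now fills only the frames that exist instead of raising a shape-mismatch error on the slice assignment.

```python
import torch

batch, seq_len, dim = 2, 10, 4
enc_len = 3  # the encoder produced fewer frames than one sample's mask claims

inputs_embeds = torch.randn(batch, seq_len, dim)
encoder_outs = torch.randn(batch, enc_len, dim)

# Boolean mask marking where modality (e.g. audio) embeddings should be spliced in.
modality_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
modality_mask[0, 2:5] = True   # length 3: matches enc_len exactly
modality_mask[1, 1:6] = True   # length 5: longer than enc_len, must be clamped

# First True position per sample, and mask length clamped to the encoder output length.
start_indices = (modality_mask == True).float().argmax(dim=1)
lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()

# Scatter (at most) `lengths[i]` encoder frames into a zero tensor at the masked span.
encoder_outs_pad = torch.zeros_like(inputs_embeds)
for i in range(encoder_outs.shape[0]):
    encoder_outs_pad[i, start_indices[i] : start_indices[i] + lengths[i]] = \
        encoder_outs[i][: lengths[i]]

# Masked positions take the encoder output; unmasked positions keep the text embeddings.
inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
```

One behavioral note on the clamping: for a sample whose mask is longer than the encoder output (sample 1 above), the tail of the masked span beyond `start + length` receives zeros rather than encoder content, since the text embeddings there are still zeroed out by `~modality_mask`. That is the trade-off the patch makes in exchange for never indexing past the end of `encoder_outs`.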