Fix modality padding mask bug (Question #82)
zszheng147 authored May 21, 2024
1 parent ba10359 commit c1b9e5d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/slam_llm/models/slam_model.py
@@ -367,13 +367,14 @@ def forward(self,
         inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
 
         if modality_mask is not None:
-            modality_unmask_start = (modality_mask == True).float().argmax(dim=1)
-
+            modality_unmask_start_indices = (modality_mask == True).float().argmax(dim=1)
+            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
+
             encoder_outs_pad = torch.zeros_like(inputs_embeds)
             for i in range(encoder_outs.shape[0]):
                 encoder_outs_pad[
-                    i, modality_unmask_start[i]:modality_unmask_start[i]+modality_mask[i].sum().item()
-                ] = encoder_outs[i]
+                    i, modality_unmask_start_indices[i] : modality_unmask_start_indices[i] + modality_lengths[i]
+                ] = encoder_outs[i][:modality_lengths[i]]
 
             inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
 

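For context, here is a minimal standalone sketch of the padding step this commit fixes. The shapes, the toy modality_mask values, and the helper variables (batch, seq_len, enc_len, dim, start, length) are illustrative assumptions, not taken from the repository. The idea shown: clamping each sample's modality length to encoder_outs.shape[1] and truncating encoder_outs[i] to that length keeps the destination slice and the source the same size, whereas the old code assigned the full encoder_outs[i] into a slice whose length came from the mask alone, which breaks when the mask span and the encoder output length disagree.

# Minimal sketch of the fixed padding logic, with toy shapes (assumed, not from the repo).
import torch

batch, seq_len, enc_len, dim = 2, 10, 6, 4            # hypothetical sizes
inputs_embeds = torch.zeros(batch, seq_len, dim)       # stand-in for embedded input_ids
encoder_outs = torch.randn(batch, enc_len, dim)        # stand-in for encoder features

# modality_mask marks where encoder features go in the token sequence.
# Sample 1 deliberately reserves more positions (8) than enc_len (6).
modality_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
modality_mask[0, 1:7] = True
modality_mask[1, 0:8] = True

modality_unmask_start_indices = (modality_mask == True).float().argmax(dim=1)
# Clamp so the destination slice never exceeds the available encoder frames.
modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()

encoder_outs_pad = torch.zeros_like(inputs_embeds)
for i in range(encoder_outs.shape[0]):
    start = modality_unmask_start_indices[i]
    length = modality_lengths[i]
    # Source and destination both cover `length` frames, so shapes always match.
    encoder_outs_pad[i, start : start + length] = encoder_outs[i][:length]

inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
print(inputs_embeds.shape)  # torch.Size([2, 10, 4])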
