Skip to content

Commit

Permalink
[Bugfix] Remove xformers requirement for Pixtral (vllm-project#9597)
Browse files — browse the repository at this point in the history
Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin authored Oct 24, 2024
1 parent 5944909 commit c91ed47
Showing 1 changed file with 46 additions and 19 deletions.
65 changes: 46 additions & 19 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
Expand All @@ -38,6 +36,12 @@
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model

# xformers is treated as an optional dependency: attempt the import once at
# module load and record the outcome in USE_XFORMERS_OPS, so a missing
# package does not fail at import time and code in this module can branch
# on the flag instead.
try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    # `xops` is left undefined in this case; only reference it when
    # USE_XFORMERS_OPS is True.
    USE_XFORMERS_OPS = False


def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
Expand Down Expand Up @@ -416,7 +420,7 @@ def __init__(self, args: VisionEncoderArgs):
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
mask: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
batch, patches, _ = x.shape
Expand All @@ -427,7 +431,7 @@ def forward(
v = v.reshape(batch, patches, self.n_heads, self.head_dim)

q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
out = memory_efficient_attention(q, k, v, attn_bias=mask)
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)

Expand All @@ -444,7 +448,7 @@ def __init__(self, args: VisionEncoderArgs):
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
mask: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(x),
Expand All @@ -467,7 +471,7 @@ def __init__(self, args: VisionEncoderArgs):
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
mask: torch.Tensor,
freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
for layer in self.layers:
Expand Down Expand Up @@ -562,8 +566,12 @@ def forward(
freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

# pass through Transformer with a block diagonal mask delimiting images
mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
if USE_XFORMERS_OPS:
mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
else:
raise ImportError("Xformers is required for Pixtral inference "
"with the Mistral format")
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

# remove batch dimension of the single sequence
Expand Down Expand Up @@ -828,7 +836,7 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: BlockDiagonalMask,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size()
Expand All @@ -843,12 +851,23 @@ def forward(
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

# Transpose q and k back for attention
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
if USE_XFORMERS_OPS:
# Transpose q and k back for attention
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.reshape(batch, patches, self.n_heads, self.head_dim)

out = xops.memory_efficient_attention(q,
k,
v,
attn_bias=attention_mask)
else:
v = v.reshape(batch, patches, self.n_heads,
self.head_dim).transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2)

out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)

return self.o_proj(out)
Expand Down Expand Up @@ -877,7 +896,7 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: BlockDiagonalMask,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states),
Expand Down Expand Up @@ -916,7 +935,7 @@ def __init__(
def forward(
self,
x: torch.Tensor,
attention_mask: BlockDiagonalMask,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
for layer in self.layers:
Expand Down Expand Up @@ -1000,11 +1019,19 @@ def forward(
patch_embeds_list,
max_width=self.config.image_size // self.config.patch_size).to(
self.device)

position_embedding = self.patch_positional_embedding(
patch_embeds, position_ids)
attention_mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )

if USE_XFORMERS_OPS:
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
else:
from transformers.models.pixtral.modeling_pixtral import (
generate_block_attention_mask)
attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)

out = self.transformer(patch_embeds, attention_mask,
position_embedding)

Expand Down

0 comments on commit c91ed47

Please sign in to comment.