[Model][Bugfix] Support TP for PixtralHF ViT (vllm-project#10405)
Signed-off-by: mgoin <[email protected]>
mgoin authored and prashantgupta24 committed Dec 3, 2024
1 parent e19e15a commit a00d8c1
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions vllm/model_executor/models/pixtral.py
@@ -17,6 +17,7 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
@@ -843,17 +844,20 @@ def __init__(
 
         self.config = config
         assert not config.hidden_size % config.num_attention_heads
-        self.n_heads = config.num_attention_heads
+        self.total_num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.n_heads = divide(config.num_attention_heads, tp_size)
         self.head_dim = config.hidden_size // config.num_attention_heads
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size=config.hidden_size,
             head_size=self.head_dim,
-            total_num_heads=self.n_heads,
+            total_num_heads=self.total_num_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+        assert self.total_num_heads * self.head_dim == config.hidden_size
         self.o_proj = RowParallelLinear(
             input_size=config.hidden_size,
             output_size=config.hidden_size,
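To illustrate the change, below is a minimal sketch in plain Python (not vLLM's actual code, and with hypothetical Pixtral-HF-like dimensions) of the head partitioning the commit introduces: each tensor-parallel rank keeps n_heads = total_num_heads / tp_size attention heads, the column-parallel qkv_proj only materializes that rank's slice, and the row-parallel o_proj reduces the partial outputs back to hidden_size.

# Minimal sketch of the head partitioning added in this commit.
# "divide" mirrors the behavior of vLLM's helper, which checks even
# divisibility before integer division. Dimensions are hypothetical.

def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, (
        f"{numerator} is not evenly divisible by {denominator}")
    return numerator // denominator

hidden_size = 1024          # config.hidden_size
total_num_heads = 16        # config.num_attention_heads
tp_size = 4                 # tensor-parallel world size

n_heads = divide(total_num_heads, tp_size)      # heads per TP rank -> 4
head_dim = hidden_size // total_num_heads       # 64

# Invariant the commit asserts: all heads together span hidden_size,
# so each rank's QKV slice is n_heads * head_dim = 256 columns wide.
assert total_num_heads * head_dim == hidden_size
print(n_heads, head_dim, n_heads * head_dim)    # 4 64 256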