diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index f3d66c2313198..22d29f5bbc50c 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -45,13 +45,6 @@ except ImportError: USE_XFORMERS_OPS = False -# These token ids cannot be retrieved from model config -# so we hardcode them here. -PIXTRAL_12B_IMAGE_BREAK_ID = 12 -PIXTRAL_12B_IMAGE_END_ID = 13 -PIXTRAL_LARGE_IMAGE_BREAK_ID = 14 -PIXTRAL_LARGE_IMAGE_END_ID = 15 - def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer = cached_get_tokenizer( @@ -201,6 +194,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if key in dataclass_fields } + if not ("image_break_token_id" in vision_args + and "image_end_token_id" in vision_args): + raise ValueError( + "'image_break_token_id' and 'image_end_token_id' not found " + "in the vision_encoder arguments. Please download the latest " + "version of 'params.json' from the model repository.") + self.vision_args = VisionEncoderArgs(**vision_args) # init MistralForCausalLM @@ -240,9 +240,8 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # NOTE: Image embeddings are split into separate tensors for each image # by the indices of `[IMG_END]` token. - image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | ( - image_tokens == PIXTRAL_LARGE_IMAGE_END_ID) - split_indices = torch.where(image_end_condition)[0] + 1 + image_end_mask = image_tokens == self.vision_args.image_end_token_id + split_indices = torch.where(image_end_mask)[0] + 1 if len(split_indices) <= 1: # Do not split, return as tensor of shape [1, fs, hs] return image_embeds.unsqueeze(0) @@ -265,10 +264,8 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, [ self.vision_args.image_token_id, - PIXTRAL_12B_IMAGE_END_ID, - PIXTRAL_12B_IMAGE_BREAK_ID, - PIXTRAL_LARGE_IMAGE_BREAK_ID, - PIXTRAL_LARGE_IMAGE_END_ID, + self.vision_args.image_break_token_id, + self.vision_args.image_end_token_id, ]) return inputs_embeds @@ -409,6 +406,8 @@ class VisionEncoderArgs: num_attention_heads: int rope_theta: float # for rope-2D image_token_id: int + image_break_token_id: int + image_end_token_id: int adapter_bias: bool = True