forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Add PaliGemma (vllm-project#5189)
Co-authored-by: Woosuk Kwon <[email protected]>
- Loading branch information
1 parent
9389380
commit 6206dcb
Showing
6 changed files
with
557 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
import subprocess | ||
|
||
from PIL import Image | ||
|
||
from vllm import LLM | ||
|
||
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`. | ||
# You can use `.buildkite/download-images.sh` to download them | ||
|
||
|
||
def run_paligemma(): | ||
llm = LLM(model="google/paligemma-3b-mix-224") | ||
|
||
prompt = "caption es" | ||
|
||
image = Image.open("images/stop_sign.jpg") | ||
|
||
outputs = llm.generate({ | ||
"prompt": prompt, | ||
"multi_modal_data": { | ||
"image": image | ||
}, | ||
}) | ||
|
||
for o in outputs: | ||
generated_text = o.outputs[0].text | ||
print(generated_text) | ||
|
||
|
||
def main(): | ||
run_paligemma() | ||
|
||
|
||
if __name__ == "__main__": | ||
# Download from s3 | ||
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/" | ||
local_directory = "images" | ||
|
||
# Make sure the local directory exists or create it | ||
os.makedirs(local_directory, exist_ok=True) | ||
|
||
# Use AWS CLI to sync the directory, assume anonymous access | ||
subprocess.check_call([ | ||
"aws", | ||
"s3", | ||
"sync", | ||
s3_bucket_path, | ||
local_directory, | ||
"--no-sign-request", | ||
]) | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from typing import List, Optional, Tuple, Type | ||
|
||
import pytest | ||
from transformers import AutoTokenizer | ||
|
||
from vllm.multimodal.utils import rescale_image_size | ||
from vllm.sequence import SampleLogprobs | ||
|
||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets | ||
from .utils import check_logprobs_close | ||
|
||
pytestmark = pytest.mark.vlm | ||
|
||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ | ||
"stop_sign": "caption es", | ||
"cherry_blossom": "What is in the picture?", | ||
"boardwalk": "What is in the picture?", | ||
}) | ||
|
||
IMAGE_TOKEN_ID = 257152 | ||
|
||
models = ["google/paligemma-3b-mix-224"] | ||
|
||
|
||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str, | ||
Optional[SampleLogprobs]], | ||
model: str): | ||
"""Sanitize vllm output to be comparable with hf output.""" | ||
output_ids, output_str, out_logprobs = vllm_output | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(model) | ||
eos_token_id = tokenizer.eos_token_id | ||
|
||
hf_output_ids = [ | ||
token_id for idx, token_id in enumerate(output_ids) | ||
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID | ||
] | ||
|
||
hf_output_str = output_str | ||
|
||
if hf_output_ids[-1] == eos_token_id: | ||
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) | ||
|
||
return hf_output_ids, hf_output_str, out_logprobs | ||
|
||
|
||
def run_test( | ||
hf_runner: Type[HfRunner], | ||
vllm_runner: Type[VllmRunner], | ||
image_assets: _ImageAssets, | ||
model: str, | ||
*, | ||
size_factors: List[float], | ||
dtype: str, | ||
max_tokens: int, | ||
num_logprobs: int, | ||
tensor_parallel_size: int, | ||
distributed_executor_backend: Optional[str] = None, | ||
): | ||
"""Inference result should be the same between hf and vllm. | ||
All the image fixtures for the test is under tests/images. | ||
For huggingface runner, we provide the PIL images as input. | ||
For vllm runner, we provide MultiModalDataDict objects | ||
and corresponding vision language config as input. | ||
Note, the text input is also adjusted to abide by vllm contract. | ||
The text output is sanitized to be able to compare with hf. | ||
""" | ||
images = [asset.pil_image for asset in image_assets] | ||
|
||
inputs_per_image = [( | ||
[prompt for _ in size_factors], | ||
[rescale_image_size(image, factor) for factor in size_factors], | ||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] | ||
|
||
# NOTE: take care of the order. run vLLM first, and then run HF. | ||
# vLLM needs a fresh new process without cuda initialization. | ||
# if we run HF first, the cuda initialization will be done and it | ||
# will hurt multiprocessing backend with fork method (the default method). | ||
|
||
# max_model_len should be greater than image_feature_size | ||
with vllm_runner(model, | ||
dtype=dtype, | ||
tensor_parallel_size=tensor_parallel_size, | ||
distributed_executor_backend=distributed_executor_backend, | ||
enforce_eager=True) as vllm_model: | ||
vllm_outputs_per_image = [ | ||
vllm_model.generate_greedy_logprobs(prompts, | ||
max_tokens, | ||
num_logprobs=num_logprobs, | ||
images=images) | ||
for prompts, images in inputs_per_image | ||
] | ||
|
||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: | ||
hf_outputs_per_image = [ | ||
hf_model.generate_greedy_logprobs_limit(prompts, | ||
max_tokens, | ||
num_logprobs=num_logprobs, | ||
images=images) | ||
for prompts, images in inputs_per_image | ||
] | ||
|
||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, | ||
vllm_outputs_per_image): | ||
|
||
check_logprobs_close( | ||
outputs_0_lst=hf_outputs, | ||
outputs_1_lst=[ | ||
vllm_to_hf_output(vllm_output, model) | ||
for vllm_output in vllm_outputs | ||
], | ||
name_0="hf", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", models) | ||
@pytest.mark.parametrize( | ||
"size_factors", | ||
[ | ||
# No image | ||
[], | ||
# Single-scale | ||
[1.0], | ||
# Single-scale, batched | ||
[1.0, 1.0, 1.0], | ||
# Multi-scale | ||
[0.25, 0.5, 1.0], | ||
], | ||
) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [128]) | ||
@pytest.mark.parametrize("num_logprobs", [5]) | ||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, | ||
dtype: str, max_tokens: int, num_logprobs: int) -> None: | ||
run_test( | ||
hf_runner, | ||
vllm_runner, | ||
image_assets, | ||
model, | ||
size_factors=size_factors, | ||
dtype=dtype, | ||
max_tokens=max_tokens, | ||
num_logprobs=num_logprobs, | ||
tensor_parallel_size=1, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.