Skip to content

Commit

Permalink
Enable paligemma2 (#2807)
Browse files Browse the repository at this point in the history
* feat: support loading gemma2 as vlm text model

* feat: add test for paligemma2
  • Loading branch information
drbh authored Dec 6, 2024
1 parent 08f6fa0 commit 9f5c9a5
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 108,
"logprob": -0.73046875,
"special": false,
"text": "\n"
},
{
"id": 30234,
"logprob": -2.328125,
"special": false,
"text": "Brown"
},
{
"id": 108,
"logprob": -0.12060547,
"special": false,
"text": "\n"
},
{
"id": 3726,
"logprob": -1.7734375,
"special": false,
"text": "Car"
},
{
"id": 108,
"logprob": -0.041503906,
"special": false,
"text": "\n"
},
{
"id": 2915,
"logprob": -1.796875,
"special": false,
"text": "Color"
},
{
"id": 108,
"logprob": -0.039794922,
"special": false,
"text": "\n"
},
{
"id": 19178,
"logprob": -1.96875,
"special": false,
"text": "Cool"
},
{
"id": 108,
"logprob": -0.080566406,
"special": false,
"text": "\n"
},
{
"id": 40544,
"logprob": -2.1875,
"special": false,
"text": "Decor"
},
{
"id": 108,
"logprob": -0.033935547,
"special": false,
"text": "\n"
},
{
"id": 13936,
"logprob": -1.6328125,
"special": false,
"text": "Green"
},
{
"id": 108,
"logprob": -0.16210938,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -2.015625,
"special": false,
"text": "..."
},
{
"id": 108,
"logprob": -0.14746094,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -0.73828125,
"special": false,
"text": "..."
},
{
"id": 108,
"logprob": -0.051513672,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -0.34765625,
"special": false,
"text": "..."
},
{
"id": 108,
"logprob": -0.020141602,
"special": false,
"text": "\n"
},
{
"id": 955,
"logprob": -0.11767578,
"special": false,
"text": "..."
}
],
"top_tokens": null
},
"generated_text": "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..."
}
29 changes: 29 additions & 0 deletions integration-tests/models/test_flash_pali_gemma2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    """Launch a text-generation server for PaliGemma 2 once per test module.

    Yields the launcher handle so dependent fixtures/tests can reach the
    running server; the context manager tears the server down afterwards.
    """
    server = launcher("google/paligemma2-3b-pt-224")
    with server as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    """Wait (up to 300s) for the launched server to become healthy, then
    expose its client for the tests."""
    handle = flash_pali_gemma_handle
    await handle.health(300)
    return handle.client


async def test_flash_pali_gemma_image(flash_pali_gemma, response_snapshot):
    """Greedy image captioning with PaliGemma 2 on a sample car image.

    Sends an image-only prompt and checks the exact generated text, the
    number of generated tokens, and the full response snapshot.
    """
    car_image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
    response = await flash_pali_gemma.generate(
        f"![]({car_image})",
        max_new_tokens=20,
    )
    assert (
        response.generated_text
        == "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..."
    )
    # The snapshot records exactly 20 tokens ("finish_reason": "length");
    # assert it explicitly so truncation/early-stop regressions are caught
    # even if the snapshot file is regenerated.
    assert response.details.generated_tokens == 20

    assert response == response_snapshot
6 changes: 6 additions & 0 deletions server/text_generation_server/models/custom_modeling/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None):
)

return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
elif config.model_type == "gemma2":
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)

return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
Expand Down

0 comments on commit 9f5c9a5

Please sign in to comment.