Paligemma - fix slow tests, add bf16 and f16 slow tests (#30851)
* fix slow tests, add bf16 and f16 slow tests

* few fixes

* [run-slow]paligemma

* add gate decorator

* [run-slow]paligemma

* add missing gating

* [run-slow]paligemma

* [run-slow]paligemma
molbap authored May 22, 2024
1 parent ada86f9 commit 250ae9f
Showing 1 changed file with 89 additions and 61 deletions.
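For reference, a condensed sketch of the gating and bf16 pattern this commit introduces, read directly from the diff below: slow integration tests are decorated with require_read_token (the google/paligemma checkpoints are gated) and load the bf16 weights from the "bfloat16" repo revision. The standalone class name PaliGemmaBf16SmokeTest and the trimmed layout are illustrative only, not part of the actual test file.

import unittest

import requests
import torch
from PIL import Image

from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from transformers.testing_utils import require_read_token, require_torch, slow, torch_device


@slow
@require_torch
@require_read_token  # gated checkpoint: needs an authenticated HF read token
class PaliGemmaBf16SmokeTest(unittest.TestCase):
    """Illustrative, condensed version of the bf16 slow test added in this commit."""

    def test_generate_bf16(self):
        model_id = "google/paligemma-3b-pt-224"
        processor = PaliGemmaProcessor.from_pretrained(model_id)
        # bf16 weights are published on a dedicated "bfloat16" revision of the repo
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, revision="bfloat16", torch_dtype=torch.bfloat16
        ).to(torch_device)

        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = processor(text="", images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=20)
        self.assertEqual(processor.decode(output[0], skip_special_tokens=True), "\ncow on the beach")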
150 changes: 89 additions & 61 deletions tests/models/paligemma/test_modeling_paligemma.py
@@ -28,7 +28,7 @@
is_vision_available,
)
from transformers.testing_utils import (
require_bitsandbytes,
require_read_token,
require_torch,
require_torch_sdpa,
slow,
@@ -260,98 +260,88 @@ def test_save_load_low_cpu_mem_usage_no_safetensors(self):

@slow
@require_torch
@require_read_token
class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = PaliGemmaProcessor.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
self.processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

@slow
@require_bitsandbytes
@require_read_token
def test_small_model_integration_test(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
prompt = ""
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt")
# fmt: off
EXPECTED_INPUT_IDS = torch.tensor([[256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 2, 108]])
# fmt: on
EXPECTED_INPUT_IDS = torch.tensor([[257152] * 256 + [2, 108]])
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\ncow standing on the beach" # fmt: skip
EXPECTED_DECODED_TEXT = "\ncow on the beach" # fmt: skip

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_bitsandbytes
def test_small_model_integration_test_paligemma(self):
@require_read_token
def test_small_model_integration_test_paligemma_VQA(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "gv-hf/PaliGemma-test-224px-hf"

model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
processor = PaliGemmaProcessor.from_pretrained(model_id)

model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
prompt = "answer en Where is the cow standing?"
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)
inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "answer en Where is the cow standing?\nbeach" # fmt: skip

self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_read_token
def test_small_model_integration_test_paligemma_empty_prompt(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

prompt = ""
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "\ncow on the beach" # fmt: skip

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_bitsandbytes
@require_read_token
def test_small_model_integration_test_paligemma_batched(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "gv-hf/PaliGemma-test-224px-hf"
model_id = "google/paligemma-3b-pt-224"

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompts = [
"answer en Where is the cow standing?",
@@ -365,19 +355,23 @@ def test_small_model_integration_test_paligemma_batched(self):
)
image2 = image1

inputs = processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
inputs = self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow standing on the beach"] # fmt: skip
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_bitsandbytes
def test_small_model_integration_test_batch(self):
@require_torch
@require_read_token
def test_small_model_integration_test_paligemma_batched_bf16(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
).to(torch_device)
# The first batch is longer in terms of text, the second will be padded.
prompts = [
"answer en Where is the cow standing?",
@@ -391,32 +385,66 @@ def test_small_model_integration_test_batch(self):
)
image2 = image1

inputs = self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
inputs = (
self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
.to(torch.bfloat16)
.to(torch_device)
)
output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_torch
@require_read_token
def test_small_model_integration_test_paligemma_batched_f16(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, revision="float16", torch_dtype=torch.float16
).to(torch_device)
# The first batch is longer in terms of text, the second will be padded.
prompts = [
"answer en Where is the cow standing?",
"",
]
image1 = Image.open(
requests.get(
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
stream=True,
).raw
)
image2 = image1

inputs = (
self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
.to(torch.float16)
.to(torch_device)
)

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow standing on the beach"] # fmt: skip
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_bitsandbytes
@require_read_token
def test_paligemma_index_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
# more details
model_id = "gv-hf/PaliGemma-test-224px-hf"
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

processor = PaliGemmaProcessor.from_pretrained(model_id)

# Simulate a super long prompt
prompt = "\n" * 200
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(
inputs = self.processor(
text=prompt,
images=raw_image,
return_tensors="pt",