Gemma 2: Update slow tests #31759

Merged · 1 commit · Jul 3, 2024
src/transformers/pipelines/text_generation.py (20 changes: 10 additions & 10 deletions)
```diff
@@ -272,12 +272,17 @@ def preprocess(
         max_length=None,
         **generate_kwargs,
     ):
+        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+        tokenizer_kwargs = {
+            "add_special_tokens": add_special_tokens,
+            "truncation": truncation,
+            "padding": padding,
+            "max_length": max_length,
+        }
+        tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
+
         if isinstance(prompt_text, Chat):
-            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
-            tokenizer_kwargs = {}
-            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
-                if locals()[tokenizer_kwarg_name] is not None:
-                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+            tokenizer_kwargs.pop("add_special_tokens", None)  # ignore add_special_tokens on chats
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
                 add_generation_prompt=True,
@@ -286,11 +291,6 @@ def preprocess(
                 **tokenizer_kwargs,
             )
         else:
-            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
-            tokenizer_kwargs = {}
-            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
-                if locals()[tokenizer_kwarg_name] is not None:
-                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
 
         inputs["prompt_text"] = prompt_text
```
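The intent of the refactor is easier to see outside the class: options left unset fall through to the tokenizer's own defaults, and `add_special_tokens` is dropped on the chat path, where `apply_chat_template` inserts special tokens itself. Below is a minimal, self-contained sketch of that logic (the function name is illustrative, not part of the library):

```python
# Standalone sketch of the refactored kwarg handling (illustrative, not library API).
def build_tokenizer_kwargs(add_special_tokens=None, truncation=None, padding=None, max_length=None, is_chat=False):
    tokenizer_kwargs = {
        "add_special_tokens": add_special_tokens,
        "truncation": truncation,
        "padding": padding,
        "max_length": max_length,
    }
    # Keep only explicitly set options so the tokenizer's defaults apply otherwise.
    tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
    if is_chat:
        # Chat templates add special tokens themselves, so the flag is ignored for chats.
        tokenizer_kwargs.pop("add_special_tokens", None)
    return tokenizer_kwargs


assert build_tokenizer_kwargs() == {}
assert build_tokenizer_kwargs(padding=True, max_length=32) == {"padding": True, "max_length": 32}
assert build_tokenizer_kwargs(add_special_tokens=False, is_chat=True) == {}
```

This also replaces the earlier `locals()` lookups, which were fragile and duplicated across the two branches, with a single dict built once before the branch.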
tests/models/gemma2/test_modeling_gemma2.py (51 changes: 36 additions & 15 deletions)
```diff
@@ -16,7 +16,7 @@
 
 import unittest
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available
+from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
 from transformers.testing_utils import (
     require_read_token,
     require_torch,
@@ -102,41 +102,62 @@ def setUpClass(cls):
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
     @require_read_token
-    def test_model_2b_bf16(self):
+    def test_model_9b_bf16(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
-            "<bos>Hello I am doing a project for a class and I am trying to use the <code><a-image></code>",
-            "<pad><pad><bos>Hi today. So, I'm going to show you how to do a problem from the textbook. So",
+            "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
         ]
 
-        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
-            torch_device
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
```
Review comment from @gante (Member, Author), Jul 2, 2024, on the line above:

Note: BF16 + attn_implementation!="eager" results in the model not being able to handle left-padding 🤯
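For context, left-padding robustness can be checked directly: with greedy decoding and a correct attention mask, a prompt padded on the left inside a batch should continue the same way as the same prompt run alone. A rough sketch, assuming a CUDA device and read access to the gated checkpoint (not part of this PR):

```python
# Sketch: left-padding sanity check for Gemma 2 (assumes GPU + gated-model access).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(model_id)  # pads on the left, as the <pad><pad><bos> expectations above show
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, attn_implementation="eager"
).to("cuda")

# The shorter prompt receives <pad> tokens on the left when batched with the longer one.
batch = tokenizer(["Hello I am doing", "Hi today"], return_tensors="pt", padding=True).to("cuda")
single = tokenizer(["Hi today"], return_tensors="pt").to("cuda")

batched_out = model.generate(**batch, max_new_tokens=20, do_sample=False)
single_out = model.generate(**single, max_new_tokens=20, do_sample=False)

# If left-padding is handled, these two continuations should match.
print(tokenizer.decode(batched_out[1], skip_special_tokens=True))
print(tokenizer.decode(single_out[0], skip_special_tokens=True))
```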

```diff
+        ).to(torch_device)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     @require_read_token
-    def test_model_2b_fp16(self):
+    def test_model_9b_fp16(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
-            "<bos>Hello I am doing a project on the effect of the temperature on the rate of a reaction. I am using a ",
-            "<pad><pad><bos>Hi today I'm going to be talking about the 1000-4000-",
+            "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
         ]
 
-        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
-            torch_device
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, attn_implementation="eager"
+        ).to(torch_device)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
+    @require_read_token
+    def test_model_9b_pipeline_bf16(self):
+        # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
+        model_id = "google/gemma-2-9b"
+        # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hi today I'm going to be talking about the history of the United States. The United States of America",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
+            torch_device
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+        output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
+
+        self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
+        self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
```
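For readers unfamiliar with the output shape the new test indexes into (`output[i][0]["generated_text"]`): a batched text-generation pipeline call returns one inner list per input prompt, each entry a dict keyed by "generated_text". A rough usage sketch, with illustrative prompts and loading options (running it requires access to the gated checkpoint and a GPU):

```python
# Sketch: batched pipeline call and its output structure (illustrative).
import torch
from transformers import pipeline

pipe = pipeline("text-generation", model="google/gemma-2-9b", torch_dtype=torch.bfloat16, device=0)

prompts = ["Hello I am doing", "Hi today"]  # illustrative prompts
output = pipe(prompts, max_new_tokens=20, do_sample=False, padding=True)

# One inner list per prompt; each entry holds a "generated_text" string.
for prompt_output in output:
    print(prompt_output[0]["generated_text"])
```

Note that `padding=True` is forwarded to `preprocess`, which is exactly the path the text_generation.py change above fixes for batched prompts.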