Gemma 2: Update slow tests (#31759)
gemma 2 slow tests
gante authored Jul 3, 2024
1 parent c1fe125 commit ddfaf11
Showing 2 changed files with 46 additions and 25 deletions.
20 changes: 10 additions & 10 deletions src/transformers/pipelines/text_generation.py
@@ -272,12 +272,17 @@ def preprocess(
         max_length=None,
         **generate_kwargs,
     ):
+        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+        tokenizer_kwargs = {
+            "add_special_tokens": add_special_tokens,
+            "truncation": truncation,
+            "padding": padding,
+            "max_length": max_length,
+        }
+        tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
+
         if isinstance(prompt_text, Chat):
-            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
-            tokenizer_kwargs = {}
-            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
-                if locals()[tokenizer_kwarg_name] is not None:
-                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+            tokenizer_kwargs.pop("add_special_tokens", None)  # ignore add_special_tokens on chats
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
                 add_generation_prompt=True,
@@ -286,11 +291,6 @@ def preprocess(
                 **tokenizer_kwargs,
             )
         else:
-            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
-            tokenizer_kwargs = {}
-            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
-                if locals()[tokenizer_kwarg_name] is not None:
-                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
 
         inputs["prompt_text"] = prompt_text
51 changes: 36 additions & 15 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -16,7 +16,7 @@
 
 import unittest
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available
+from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
 from transformers.testing_utils import (
     require_read_token,
     require_torch,
@@ -102,41 +102,62 @@ def setUpClass(cls):
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
     @require_read_token
-    def test_model_2b_bf16(self):
+    def test_model_9b_bf16(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
-            "<bos>Hello I am doing a project for a class and I am trying to use the <code><a-image></code>",
-            "<pad><pad><bos>Hi today. So, I'm going to show you how to do a problem from the textbook. So",
+            "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
         ]
 
-        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
-            torch_device
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
+        ).to(torch_device)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     @require_read_token
-    def test_model_2b_fp16(self):
+    def test_model_9b_fp16(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
-            "<bos>Hello I am doing a project on the effect of the temperature on the rate of a reaction. I am using a ",
-            "<pad><pad><bos>Hi today I'm going to be talking about the 1000-4000-",
+            "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
         ]
 
-        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
-            torch_device
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, attn_implementation="eager"
+        ).to(torch_device)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
+    @require_read_token
+    def test_model_9b_pipeline_bf16(self):
+        # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
+        model_id = "google/gemma-2-9b"
+        # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hi today I'm going to be talking about the history of the United States. The United States of America",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
+            torch_device
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+        output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
+
+        self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
+        self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
+
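Note on the model loading change: both updated bf16/fp16 slow tests now pin attn_implementation="eager" when loading the 9B checkpoint, likely because Gemma 2's attention logit soft-capping was only handled by the eager attention path at the time (an assumption added here for context; the diff itself states no rationale). A minimal loading sketch mirroring the test setup, assuming access to the gated google/gemma-2-9b weights and a CUDA device:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # gated checkpoint; requires an authorized Hugging Face token
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # same setting the slow tests pin above
).to("cuda")

inputs = tokenizer(["Hello I am doing", "Hi today"], return_tensors="pt", padding=True).to("cuda")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)  # greedy, as in the tests
print(tokenizer.batch_decode(output, skip_special_tokens=False))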
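The new test_model_9b_pipeline_bf16 exercises the text-generation pipeline path fixed by #31747. For reference, a hedged sketch of the equivalent user-facing call, built directly from the hub id rather than from a preloaded model as the test does; both routes go through the same preprocess() shown in the first file:

import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="google/gemma-2-9b",  # gated checkpoint; requires an authorized Hugging Face token
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

outputs = pipe(
    ["Hello I am doing", "Hi today"],
    max_new_tokens=20,
    do_sample=False,  # greedy decoding, as in the slow test
    padding=True,
)
print(outputs[0][0]["generated_text"])
print(outputs[1][0]["generated_text"])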