enable QA bf16 pipeline (#34483)
* enable QA bf16 pipeline

* add tests
jiqing-feng authored Oct 31, 2024
1 parent 405b562 commit f385316
Showing 2 changed files with 41 additions and 2 deletions.
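As a quick usage sketch of what this commit enables (it mirrors the new tests below and uses the same tiny SQuAD checkpoint), the question-answering pipeline can now be constructed directly in bfloat16:

import torch
from transformers import pipeline

# Build the QA pipeline in bfloat16; postprocessing now casts the bf16
# start/end logits to float32 before the numpy-based answer decoding.
question_answerer = pipeline(
    "question-answering",
    model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
    torch_dtype=torch.bfloat16,
)
print(question_answerer(question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."))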
10 changes: 8 additions & 2 deletions src/transformers/pipelines/question_answering.py
@@ -540,8 +540,14 @@ def postprocess(
         min_null_score = 1000000  # large and positive
         answers = []
         for output in model_outputs:
-            start_ = output["start"]
-            end_ = output["end"]
+            if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
+                start_ = output["start"].to(torch.float32)
+            else:
+                start_ = output["start"]
+            if self.framework == "pt" and output["end"].dtype == torch.bfloat16:
+                end_ = output["end"].to(torch.float32)
+            else:
+                end_ = output["end"]
             example = output["example"]
             p_mask = output["p_mask"]
             attention_mask = (
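Why the cast is needed: NumPy has no bfloat16 dtype, so bf16 start/end logits cannot be converted to numpy arrays during answer decoding. A minimal repro sketch (not part of the commit) showing the failure and the float32 workaround:

import torch

logits = torch.randn(1, 8, dtype=torch.bfloat16)
try:
    logits.numpy()  # fails: numpy has no bfloat16 dtype
except TypeError as err:
    print(err)  # e.g. "Got unsupported ScalarType BFloat16"
print(logits.to(torch.float32).numpy())  # works once cast to float32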
33 changes: 33 additions & 0 deletions tests/pipelines/test_pipelines_question_answering.py
@@ -27,13 +27,18 @@
 from transformers.testing_utils import (
     compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
+    is_torch_available,
     nested_simplify,
     require_tf,
     require_torch,
     require_torch_or_tf,
     slow,
 )


+if is_torch_available():
+    import torch
+
 from .test_pipelines_common import ANY


@@ -165,6 +170,34 @@ def test_small_model_pt(self):

         self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})

+    @require_torch
+    def test_small_model_pt_fp16(self):
+        question_answerer = pipeline(
+            "question-answering",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            torch_dtype=torch.float16,
+        )
+
+        outputs = question_answerer(
+            question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
+        )
+
+        self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+
+    @require_torch
+    def test_small_model_pt_bf16(self):
+        question_answerer = pipeline(
+            "question-answering",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            torch_dtype=torch.bfloat16,
+        )
+
+        outputs = question_answerer(
+            question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
+        )
+
+        self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+
     @require_torch
     def test_small_model_pt_iterator(self):
         # https://github.com/huggingface/transformers/issues/18510
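To exercise just the new dtype tests, something like python -m pytest tests/pipelines/test_pipelines_question_answering.py -k "fp16 or bf16" should select them (the -k expression matches the new test names; a PyTorch install is assumed).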
