From f38531619ddff23a510d5f7ccbc257a1bb1a3cb7 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 31 Oct 2024 20:55:53 +0800
Subject: [PATCH] enable QA bf16 pipeline (#34483)

* enable QA bf16 pipeline

* add tests
---
 .../pipelines/question_answering.py           | 10 ++++--
 .../test_pipelines_question_answering.py      | 33 +++++++++++++++++++
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py
index 6039e5ad1ee989..7b876eefc49279 100644
--- a/src/transformers/pipelines/question_answering.py
+++ b/src/transformers/pipelines/question_answering.py
@@ -540,8 +540,14 @@ def postprocess(
         min_null_score = 1000000  # large and positive
         answers = []
         for output in model_outputs:
-            start_ = output["start"]
-            end_ = output["end"]
+            if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
+                start_ = output["start"].to(torch.float32)
+            else:
+                start_ = output["start"]
+            if self.framework == "pt" and output["end"].dtype == torch.bfloat16:
+                end_ = output["end"].to(torch.float32)
+            else:
+                end_ = output["end"]
             example = output["example"]
             p_mask = output["p_mask"]
             attention_mask = (
diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py
index d06f88d1f08844..bf4fc7db1db6b5 100644
--- a/tests/pipelines/test_pipelines_question_answering.py
+++ b/tests/pipelines/test_pipelines_question_answering.py
@@ -27,6 +27,7 @@
 from transformers.testing_utils import (
     compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
+    is_torch_available,
     nested_simplify,
     require_tf,
     require_torch,
@@ -34,6 +35,10 @@
     slow,
 )
 
+
+if is_torch_available():
+    import torch
+
 from .test_pipelines_common import ANY
 
 
@@ -165,6 +170,34 @@ def test_small_model_pt(self):
 
         self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
 
+    @require_torch
+    def test_small_model_pt_fp16(self):
+        question_answerer = pipeline(
+            "question-answering",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            torch_dtype=torch.float16,
+        )
+
+        outputs = question_answerer(
+            question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
+        )
+
+        self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+
+    @require_torch
+    def test_small_model_pt_bf16(self):
+        question_answerer = pipeline(
+            "question-answering",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            torch_dtype=torch.bfloat16,
+        )
+
+        outputs = question_answerer(
+            question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
+        )
+
+        self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+
     @require_torch
     def test_small_model_pt_iterator(self):
         # https://github.com/huggingface/transformers/issues/18510
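
With this patch applied, the question-answering pipeline runs end to end in bfloat16, where previously the bf16 start/end logits could not be handled by the numpy-based postprocessing (numpy has no bfloat16 dtype). A minimal usage sketch, mirroring the new tests above (same tiny checkpoint; any PyTorch QA checkpoint should work the same way):

    import torch
    from transformers import pipeline

    # Load the QA pipeline with bfloat16 weights; the model then emits
    # bf16 start/end logits, which postprocess() now upcasts to float32.
    question_answerer = pipeline(
        "question-answering",
        model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
        torch_dtype=torch.bfloat16,
    )

    result = question_answerer(
        question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
    )
    print(result)  # {'score': ~0.01, 'start': 0, 'end': 11, 'answer': 'HuggingFace'}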