From 9083c2b85ecf6936ded66cb9f159c9da02501a74 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 25 Oct 2024 10:11:29 +0200 Subject: [PATCH] add test --- tests/onnxruntime/test_modeling.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 33243da278a..da450b8e31c 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2192,6 +2192,18 @@ def test_compare_to_io_binding(self, model_arch): gc.collect() + def test_default_token_type_ids(self): + model_id = MODEL_NAMES["bert"] + model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("this is a simple input", return_tensors="np") + self.assertTrue("token_type_ids" in model.input_names) + token_type_ids = tokens.pop("token_type_ids") + outs = model(token_type_ids=token_type_ids, **tokens) + outs_without_token_type_ids = model(**tokens) + self.assertTrue(np.allclose(outs.last_hidden_state, outs_without_token_type_ids.last_hidden_state)) + gc.collect() + class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin): # Multiple Choice tests are conducted on different models due to mismatch size in model's classifier