From 879fe3e4f2bb668bfb7b360852966549561a3c56 Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Thu, 5 Dec 2024 18:15:52 +0000
Subject: [PATCH] add torch_device to integration tests

---
 .../models/got_ocr2/test_modeling_got_ocr2.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py
index 7163cb4633f260..beccbab74fae9c 100644
--- a/tests/models/got_ocr2/test_modeling_got_ocr2.py
+++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py
@@ -245,12 +245,12 @@ def tearDown(self):
     @slow
     def test_small_model_integration_test_got_ocr_stop_strings(self):
         model_id = "yonigozlan/GOT-OCR-2.0-hf"
-        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
+        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
         image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/iam_picture.jpeg"
         )
 
-        inputs = self.processor(image, return_tensors="pt")
+        inputs = self.processor(image, return_tensors="pt").to(torch_device)
         generate_ids = model.generate(
             **inputs,
             do_sample=False,
@@ -268,12 +268,12 @@ def test_small_model_integration_test_got_ocr_stop_strings(self):
     @slow
     def test_small_model_integration_test_got_ocr_format(self):
         model_id = "yonigozlan/GOT-OCR-2.0-hf"
-        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
+        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
         image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
         )
 
-        inputs = self.processor(image, return_tensors="pt", format=True)
+        inputs = self.processor(image, return_tensors="pt", format=True).to(torch_device)
         generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
         decoded_output = self.processor.decode(
             generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
@@ -284,12 +284,12 @@ def test_small_model_integration_test_got_ocr_format(self):
     @slow
     def test_small_model_integration_test_got_ocr_fine_grained(self):
         model_id = "yonigozlan/GOT-OCR-2.0-hf"
-        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
+        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
         image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
         )
 
-        inputs = self.processor(image, return_tensors="pt", color="green")
+        inputs = self.processor(image, return_tensors="pt", color="green").to(torch_device)
         generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
         decoded_output = self.processor.decode(
             generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
@@ -300,12 +300,12 @@ def test_small_model_integration_test_got_ocr_fine_grained(self):
     @slow
     def test_small_model_integration_test_got_ocr_crop_to_patches(self):
         model_id = "yonigozlan/GOT-OCR-2.0-hf"
-        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
+        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
         image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
         )
 
-        inputs = self.processor(image, return_tensors="pt", crop_to_patches=True)
+        inputs = self.processor(image, return_tensors="pt", crop_to_patches=True).to(torch_device)
         generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
         decoded_output = self.processor.decode(
             generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
@@ -316,7 +316,7 @@ def test_small_model_integration_test_got_ocr_crop_to_patches(self):
     @slow
     def test_small_model_integration_test_got_ocr_multi_pages(self):
         model_id = "yonigozlan/GOT-OCR-2.0-hf"
-        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
+        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
         image1 = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
         )
@@ -324,7 +324,7 @@ def test_small_model_integration_test_got_ocr_multi_pages(self):
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
         )
 
-        inputs = self.processor([image1, image2], return_tensors="pt", multi_page=True)
+        inputs = self.processor([image1, image2], return_tensors="pt", multi_page=True).to(torch_device)
         generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
         decoded_output = self.processor.decode(
             generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
@@ -335,7 +335,7 @@ def test_small_model_integration_test_got_ocr_multi_pages(self):
     @slow
     def test_small_model_integration_test_got_ocr_batched(self):
         model_id = "yonigozlan/GOT-OCR-2.0-hf"
-        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
+        model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
         image1 = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
         )
@@ -343,7 +343,7 @@ def test_small_model_integration_test_got_ocr_batched(self):
             "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
         )
 
-        inputs = self.processor([image1, image2], return_tensors="pt")
+        inputs = self.processor([image1, image2], return_tensors="pt").to(torch_device)
         generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
         decoded_output = self.processor.batch_decode(
             generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True