Skip to content

Commit

Permalink
add torch_device to integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Dec 5, 2024
1 parent df94db3 commit 879fe3e
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/models/got_ocr2/test_modeling_got_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,12 @@ def tearDown(self):
@slow
def test_small_model_integration_test_got_ocr_stop_strings(self):
model_id = "yonigozlan/GOT-OCR-2.0-hf"
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/iam_picture.jpeg"
)

inputs = self.processor(image, return_tensors="pt")
inputs = self.processor(image, return_tensors="pt").to(torch_device)
generate_ids = model.generate(
**inputs,
do_sample=False,
Expand All @@ -268,12 +268,12 @@ def test_small_model_integration_test_got_ocr_stop_strings(self):
@slow
def test_small_model_integration_test_got_ocr_format(self):
model_id = "yonigozlan/GOT-OCR-2.0-hf"
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
)

inputs = self.processor(image, return_tensors="pt", format=True)
inputs = self.processor(image, return_tensors="pt", format=True).to(torch_device)
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
decoded_output = self.processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
Expand All @@ -284,12 +284,12 @@ def test_small_model_integration_test_got_ocr_format(self):
@slow
def test_small_model_integration_test_got_ocr_fine_grained(self):
model_id = "yonigozlan/GOT-OCR-2.0-hf"
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
)

inputs = self.processor(image, return_tensors="pt", color="green")
inputs = self.processor(image, return_tensors="pt", color="green").to(torch_device)
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
decoded_output = self.processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
Expand All @@ -300,12 +300,12 @@ def test_small_model_integration_test_got_ocr_fine_grained(self):
@slow
def test_small_model_integration_test_got_ocr_crop_to_patches(self):
model_id = "yonigozlan/GOT-OCR-2.0-hf"
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
)

inputs = self.processor(image, return_tensors="pt", crop_to_patches=True)
inputs = self.processor(image, return_tensors="pt", crop_to_patches=True).to(torch_device)
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
decoded_output = self.processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
Expand All @@ -316,15 +316,15 @@ def test_small_model_integration_test_got_ocr_crop_to_patches(self):
@slow
def test_small_model_integration_test_got_ocr_multi_pages(self):
model_id = "yonigozlan/GOT-OCR-2.0-hf"
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
image1 = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
)
image2 = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
)

inputs = self.processor([image1, image2], return_tensors="pt", multi_page=True)
inputs = self.processor([image1, image2], return_tensors="pt", multi_page=True).to(torch_device)
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
decoded_output = self.processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
Expand All @@ -335,15 +335,15 @@ def test_small_model_integration_test_got_ocr_multi_pages(self):
@slow
def test_small_model_integration_test_got_ocr_batched(self):
model_id = "yonigozlan/GOT-OCR-2.0-hf"
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id)
model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
image1 = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
)
image2 = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
)

inputs = self.processor([image1, image2], return_tensors="pt")
inputs = self.processor([image1, image2], return_tensors="pt").to(torch_device)
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
decoded_output = self.processor.batch_decode(
generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True
Expand Down

0 comments on commit 879fe3e

Please sign in to comment.