Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: Phillip Kuznetsov <[email protected]>
  • Loading branch information
philkuz committed Oct 29, 2024
1 parent 03495a8 commit 4c78eca
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/models/mask2former/test_modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

import numpy as np
from packaging import version

from tests.test_modeling_common import floats_tensor
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
Expand Down Expand Up @@ -481,3 +482,29 @@ def test_with_segmentation_maps_and_loss(self):
outputs = model(**inputs)

self.assertTrue(outputs.loss is not None)

@slow
def test_export(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(image, return_tensors="pt").to(torch_device)

exported_program = torch.export.export(
model,
args=(inputs["pixel_values"], inputs["pixel_mask"]),
strict=True,
)
with torch.no_grad():
eager_outputs = model(**inputs)
exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
)
self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
)

0 comments on commit 4c78eca

Please sign in to comment.