Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 12, 2023
1 parent 84aa4d2 commit 68e727c
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
1 change: 1 addition & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
"OVModelForFeatureExtraction",
"OVModelForImageClassification",
"OVModelForMaskedLM",
"OVModelForPix2Struct",
"OVModelForQuestionAnswering",
"OVModelForSeq2SeqLM",
"OVModelForSequenceClassification",
Expand Down
96 changes: 96 additions & 0 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@
AutoModelForTokenClassification,
AutoTokenizer,
GenerationConfig,
Pix2StructForConditionalGeneration,
PretrainedConfig,
pipeline,
set_seed,
)
from transformers.onnx.utils import get_preprocessor
from utils_tests import MODEL_NAMES

from optimum.intel import (
Expand All @@ -58,6 +60,7 @@
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
OVModelForPix2Struct,
OVModelForQuestionAnswering,
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
Expand Down Expand Up @@ -1073,3 +1076,96 @@ def test_compare_to_transformers(self, model_arch):
del transformers_model
del ov_model
gc.collect()


class OVModelForPix2StructIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["pix2struct"]
TASK = "image-to-text" # is it fine as well with visual-question-answering?

GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.1

IMAGE = Image.open(
requests.get(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
stream=True,
).raw
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
ov_model = OVModelForPix2Struct.from_pretrained(model_id, export=True)

self.assertIsInstance(ov_model.encoder, OVEncoder)
self.assertIsInstance(ov_model.decoder, OVDecoder)
self.assertIsInstance(ov_model.decoder_with_past, OVDecoder)
self.assertIsInstance(ov_model.config, PretrainedConfig)

question = "Who am I?"
transformers_model = Pix2StructForConditionalGeneration.from_pretrained(model_id)
preprocessor = get_preprocessor(model_id)

inputs = preprocessor(images=self.IMAGE, text=question, padding=True, return_tensors="pt")
ov_outputs = ov_model(**inputs)

self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)

with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
# Compare tensor outputs
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
del transformers_model
del ov_model

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_generate_utils(self, model_arch):
model_id = MODEL_NAMES[model_arch]
model = OVModelForPix2Struct.from_pretrained(model_id, export=True)
preprocessor = get_preprocessor(model_id)
question = "Who am I?"
inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt")

# General case
outputs = model.generate(**inputs)
outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True)
self.assertIsInstance(outputs[0], str)
del model

gc.collect()

def test_compare_with_and_without_past_key_values(self):
model_id = MODEL_NAMES["pix2struct"]
preprocessor = get_preprocessor(model_id)
question = "Who am I?"
inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt")

model_with_pkv = OVModelForPix2Struct.from_pretrained(model_id, export=True, use_cache=True)
_ = model_with_pkv.generate(**inputs) # warmup
with Timer() as with_pkv_timer:
outputs_model_with_pkv = model_with_pkv.generate(
**inputs, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
)

model_without_pkv = OVModelForPix2Struct.from_pretrained(model_id, export=True, use_cache=False)
_ = model_without_pkv.generate(**inputs) # warmup
with Timer() as without_pkv_timer:
outputs_model_without_pkv = model_without_pkv.generate(
**inputs, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
)

self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
self.assertTrue(
without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
)
del model_with_pkv
del model_without_pkv
gc.collect()
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"mt5": "stas/mt5-tiny-random",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-pegasus",
"pix2struct": "fxmarty/pix2struct-tiny-random",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-roberta",
Expand Down

0 comments on commit 68e727c

Please sign in to comment.