Adding optimum option for PredictEngine (#492)
* adding optimum onnx option for classification

* removing a few unneeded bits from optimum.py

* adding missing required args from the optimum pipeline to make sure it gives the same output as the torch implementation

* adding unit test

* remove a few stray things that are not needed

* fixing code styles

* minor linting fix

* fix failing test

---------

Co-authored-by: wendy mak <[email protected]>
wwymak and wendy mak authored Dec 10, 2024
1 parent efe6096 commit c335df8
Showing 3 changed files with 144 additions and 0 deletions.
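
This commit wires an optimum/ONNX backend into PredictEngine so that classification models can run on onnxruntime instead of torch. A minimal usage sketch, not part of the diff: it assumes the existing AsyncEmbeddingEngine.from_args / classify API and the engine field on EngineArgs route to the new backend the same way they do for the torch classifier.

import asyncio

from infinity_emb import AsyncEmbeddingEngine
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import InferenceEngine


async def main():
    # engine=InferenceEngine.optimum routes classification through the new
    # OptimumClassifier (see PredictEngine.from_inference_engine below).
    engine = AsyncEmbeddingEngine.from_args(
        EngineArgs(
            model_name_or_path="SamLowe/roberta-base-go_emotions-onnx",
            engine=InferenceEngine.optimum,
        )
    )
    async with engine:
        predictions, usage = await engine.classify(sentences=["This is awesome."])
    print(predictions[0][0])  # e.g. {"label": "admiration", "score": 0.9x}


asyncio.run(main())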
91 changes: 91 additions & 0 deletions libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeil

import copy
import os

from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_TRANSFORMERS
from infinity_emb.args import EngineArgs
from infinity_emb.transformer.abstract import BaseClassifer
from infinity_emb.transformer.utils_optimum import (
    device_to_onnx,
    get_onnx_files,
    optimize_model,
)

if CHECK_ONNXRUNTIME.is_available:
    try:
        from optimum.onnxruntime import (  # type: ignore[import-untyped]
            ORTModelForSequenceClassification,
        )
    except (ImportError, RuntimeError, Exception) as ex:
        CHECK_ONNXRUNTIME.mark_dirty(ex)

if CHECK_TRANSFORMERS.is_available:
    from transformers import AutoTokenizer, pipeline  # type: ignore[import-untyped]


class OptimumClassifier(BaseClassifer):
    def __init__(self, *, engine_args: EngineArgs):
        CHECK_ONNXRUNTIME.mark_required()
        CHECK_TRANSFORMERS.mark_required()
        provider = device_to_onnx(engine_args.device)

        onnx_file = get_onnx_files(
            model_name_or_path=engine_args.model_name_or_path,
            revision=engine_args.revision,
            use_auth_token=True,
            prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
        )

        self.model = optimize_model(
            model_name_or_path=engine_args.model_name_or_path,
            model_class=ORTModelForSequenceClassification,
            revision=engine_args.revision,
            trust_remote_code=engine_args.trust_remote_code,
            execution_provider=provider,
            file_name=onnx_file.as_posix(),
            optimize_model=not os.environ.get("INFINITY_ONNX_DISABLE_OPTIMIZE", False),
        )
        self.model.use_io_binding = False

        self.tokenizer = AutoTokenizer.from_pretrained(
            engine_args.model_name_or_path,
            revision=engine_args.revision,
            trust_remote_code=engine_args.trust_remote_code,
        )

        self._infinity_tokenizer = copy.deepcopy(self.tokenizer)

        self._pipe = pipeline(
            task="text-classification",
            model=self.model,
            trust_remote_code=engine_args.trust_remote_code,
            top_k=None,
            revision=engine_args.revision,
            tokenizer=self.tokenizer,
            device=engine_args.device,
        )

    def encode_pre(self, sentences: list[str]):
        return sentences

    def encode_core(self, sentences: list[str]) -> dict:
        outputs = self._pipe(sentences)
        return outputs

    def encode_post(self, classes) -> dict[str, float]:
        """Runs post-encoding steps such as normalization (currently a pass-through)."""
        return classes

    def tokenize_lengths(self, sentences: list[str]) -> list[int]:
        """Gets the token length of each sentence according to the tokenizer."""
        tks = self._infinity_tokenizer.batch_encode_plus(
            sentences,
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
            return_length=False,
        ).encodings
        return [len(t.tokens) for t in tks]
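
A note on the INFINITY_ONNX_DISABLE_OPTIMIZE toggle above (a hedged reading of the code, not a documented contract): os.environ.get returns a string whenever the variable is set, so any non-empty value, including "0", disables graph optimization.

import os

# Any non-empty string is truthy, so even "0" switches optimization off.
os.environ["INFINITY_ONNX_DISABLE_OPTIMIZE"] = "0"
assert bool(os.environ.get("INFINITY_ONNX_DISABLE_OPTIMIZE", False))
# optimize_model(..., optimize_model=not <truthy>) then receives False.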
4 changes: 4 additions & 0 deletions libs/infinity_emb/infinity_emb/transformer/utils.py
@@ -7,6 +7,7 @@
from infinity_emb.primitives import InferenceEngine
from infinity_emb.transformer.audio.torch import TorchAudioModel
from infinity_emb.transformer.classifier.torch import SentenceClassifier
from infinity_emb.transformer.classifier.optimum import OptimumClassifier
from infinity_emb.transformer.crossencoder.optimum import OptimumCrossEncoder
from infinity_emb.transformer.crossencoder.torch import (
CrossEncoderPatched as CrossEncoderTorch,
@@ -87,11 +88,14 @@ def from_inference_engine(engine: InferenceEngine):

class PredictEngine(Enum):
    torch = SentenceClassifier
    optimum = OptimumClassifier

    @staticmethod
    def from_inference_engine(engine: InferenceEngine):
        if engine == InferenceEngine.torch:
            return PredictEngine.torch
        elif engine == InferenceEngine.optimum:
            return PredictEngine.optimum
        else:
            raise NotImplementedError(f"PredictEngine for {engine} not implemented")
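
Because each enum member's value is the classifier class itself, resolving an InferenceEngine yields something directly instantiable. A minimal sketch against the code above:

from infinity_emb.primitives import InferenceEngine
from infinity_emb.transformer.utils import PredictEngine

# .value unwraps the class stored on the enum member.
cls = PredictEngine.from_inference_engine(InferenceEngine.optimum).value
assert cls.__name__ == "OptimumClassifier"
# cls(engine_args=...) would then construct the ONNX-backed classifier.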

49 changes: 49 additions & 0 deletions (new unit-test file for OptimumClassifier)
@@ -0,0 +1,49 @@
import torch
from optimum.pipelines import pipeline  # type: ignore
from optimum.onnxruntime import ORTModelForSequenceClassification

from infinity_emb.args import EngineArgs
from infinity_emb.transformer.classifier.optimum import OptimumClassifier


def test_classifier(model_name: str = "SamLowe/roberta-base-go_emotions-onnx"):
    model = OptimumClassifier(
        engine_args=EngineArgs(
            model_name_or_path=model_name,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )  # type: ignore
    )

    pipe = pipeline(
        task="text-classification",
        model=ORTModelForSequenceClassification.from_pretrained(
            model_name, file_name="onnx/model_quantized.onnx"
        ),
        top_k=None,
    )

    sentences = ["This is awesome.", "I am depressed."]

    encode_pre = model.encode_pre(sentences)
    encode_core = model.encode_core(encode_pre)
    preds = model.encode_post(encode_core)

    assert len(preds) == len(sentences)
    assert isinstance(preds, list)
    assert isinstance(preds[0], list)
    assert isinstance(preds[0][0], dict)
    assert isinstance(preds[0][0]["label"], str)
    assert isinstance(preds[0][0]["score"], float)
    assert preds[0][0]["label"] == "admiration"
    assert 0.98 > preds[0][0]["score"] > 0.93

    preds_orig = pipe(sentences, top_k=None, truncation=True)

    assert len(preds_orig) == len(preds)

    for pred_orig, pred in zip(preds_orig, preds):
        assert len(pred_orig) == len(pred)
        for pred_orig_i, pred_i in zip(pred_orig[:5], pred[:5]):
            # scores should closely match the reference optimum pipeline
            assert abs(pred_orig_i["score"] - pred_i["score"]) < 0.05
            # label ordering may shuffle for near-zero scores, so only
            # compare labels where the reference score is confident
            if pred_orig_i["score"] > 0.005:
                assert pred_orig_i["label"] == pred_i["label"]
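
Note that the reference pipe pins onnx/model_quantized.onnx explicitly, while OptimumClassifier resolves the same file via get_onnx_files(prefer_quantized=True) on CPU providers, so both sides load comparable weights there. A direct-run sketch for local debugging (pytest collects test_classifier on its own; this assumes the optimum and onnxruntime extras are installed):

if __name__ == "__main__":
    test_classifier()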
