Adding optimum option for PredictEngine (#492)
* adding optimum onnx option for classification
* removing a few unneeded bits from optimum.py
* adding missing required args from optimum pipeline to make sure it gives the same output as torch implementation
* adding unit test
* remove a few stray things that are not needed
* fixing code styles
* minor linting fix
* fix failing test

---------

Co-authored-by: wendy mak <[email protected]>
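For context, a minimal sketch of what this option enables from the caller's side. This is not part of the diff: the `engine="optimum"` selector and the `classify()` call follow infinity_emb's public engine API, but treat the exact names and signatures as assumptions.

# Hypothetical sketch (not part of this commit): routing classification
# through the new optimum/ONNX backend. The engine selector value and the
# classify() signature are assumptions about infinity_emb's public API.
import asyncio

from infinity_emb import AsyncEmbeddingEngine
from infinity_emb.args import EngineArgs


async def main() -> None:
    engine = AsyncEmbeddingEngine.from_args(
        EngineArgs(
            model_name_or_path="SamLowe/roberta-base-go_emotions-onnx",
            engine="optimum",  # assumption: selects the ONNX/optimum backend
        )
    )
    async with engine:  # starts the batch handler
        predictions, usage = await engine.classify(
            sentences=["This is awesome.", "I am depressed."]
        )
        print(predictions[0][0])  # top-scoring label dict for the first sentence


asyncio.run(main())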
Showing 3 changed files with 144 additions and 0 deletions.
libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py (91 additions & 0 deletions)
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeil

import copy
import os

from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_TRANSFORMERS
from infinity_emb.args import EngineArgs
from infinity_emb.transformer.abstract import BaseClassifer
from infinity_emb.transformer.utils_optimum import (
    device_to_onnx,
    get_onnx_files,
    optimize_model,
)

if CHECK_ONNXRUNTIME.is_available:
    try:
        from optimum.onnxruntime import (  # type: ignore[import-untyped]
            ORTModelForSequenceClassification,
        )
    except (ImportError, RuntimeError, Exception) as ex:
        CHECK_ONNXRUNTIME.mark_dirty(ex)

if CHECK_TRANSFORMERS.is_available:
    from transformers import AutoTokenizer, pipeline  # type: ignore[import-untyped]


class OptimumClassifier(BaseClassifer):
    def __init__(self, *, engine_args: EngineArgs):
        CHECK_ONNXRUNTIME.mark_required()
        CHECK_TRANSFORMERS.mark_required()
        # map the requested device to an ONNX execution provider
        provider = device_to_onnx(engine_args.device)

        # prefer a quantized ONNX file when running on CPU or OpenVINO
        onnx_file = get_onnx_files(
            model_name_or_path=engine_args.model_name_or_path,
            revision=engine_args.revision,
            use_auth_token=True,
            prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
        )

        self.model = optimize_model(
            model_name_or_path=engine_args.model_name_or_path,
            model_class=ORTModelForSequenceClassification,
            revision=engine_args.revision,
            trust_remote_code=engine_args.trust_remote_code,
            execution_provider=provider,
            file_name=onnx_file.as_posix(),
            optimize_model=not os.environ.get("INFINITY_ONNX_DISABLE_OPTIMIZE", False),
        )
        self.model.use_io_binding = False

        self.tokenizer = AutoTokenizer.from_pretrained(
            engine_args.model_name_or_path,
            revision=engine_args.revision,
            trust_remote_code=engine_args.trust_remote_code,
        )

        # independent tokenizer copy, used only for token-length accounting
        self._infinity_tokenizer = copy.deepcopy(self.tokenizer)

        self._pipe = pipeline(
            task="text-classification",
            model=self.model,
            trust_remote_code=engine_args.trust_remote_code,
            top_k=None,  # return scores for all labels
            revision=engine_args.revision,
            tokenizer=self.tokenizer,
            device=engine_args.device,
        )

    def encode_pre(self, sentences: list[str]):
        return sentences

    def encode_core(self, sentences: list[str]) -> dict:
        outputs = self._pipe(sentences)
        return outputs

    def encode_post(self, classes) -> dict[str, float]:
        """runs post encoding such as normalization"""
        return classes

    def tokenize_lengths(self, sentences: list[str]) -> list[int]:
        """gets the length of each sentence according to the tokenizer"""
        tks = self._infinity_tokenizer.batch_encode_plus(
            sentences,
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
            return_length=False,
        ).encodings
        return [len(t.tokens) for t in tks]
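For orientation, a minimal sketch of driving the class directly on CPU; the model name and the three-step encode flow mirror the unit test added in this commit, so only the combination shown here is an illustration, not new API.

# Sketch mirroring the unit test below: instantiate the classifier and run
# the three-stage encode interface (pre-process, ONNX inference, post-process).
from infinity_emb.args import EngineArgs
from infinity_emb.transformer.classifier.optimum import OptimumClassifier

classifier = OptimumClassifier(
    engine_args=EngineArgs(
        model_name_or_path="SamLowe/roberta-base-go_emotions-onnx",
        device="cpu",
    )  # type: ignore
)

sentences = ["This is awesome."]
preds = classifier.encode_post(classifier.encode_core(classifier.encode_pre(sentences)))
print(preds[0][0])  # e.g. {"label": "admiration", "score": ...}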
libs/infinity_emb/tests/unit_test/transformer/classifier/test_optimum_classifier.py (49 additions & 0 deletions)
@@ -0,0 +1,49 @@
import torch
from optimum.pipelines import pipeline  # type: ignore
from optimum.onnxruntime import ORTModelForSequenceClassification

from infinity_emb.args import EngineArgs
from infinity_emb.transformer.classifier.optimum import OptimumClassifier


def test_classifier(model_name: str = "SamLowe/roberta-base-go_emotions-onnx"):
    model = OptimumClassifier(
        engine_args=EngineArgs(
            model_name_or_path=model_name,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )  # type: ignore
    )

    # reference pipeline built directly on optimum, used to check output parity
    pipe = pipeline(
        task="text-classification",
        model=ORTModelForSequenceClassification.from_pretrained(
            model_name, file_name="onnx/model_quantized.onnx"
        ),
        top_k=None,
    )

    sentences = ["This is awesome.", "I am depressed."]

    encode_pre = model.encode_pre(sentences)
    encode_core = model.encode_core(encode_pre)
    preds = model.encode_post(encode_core)

    assert len(preds) == len(sentences)
    assert isinstance(preds, list)
    assert isinstance(preds[0], list)
    assert isinstance(preds[0][0], dict)
    assert isinstance(preds[0][0]["label"], str)
    assert isinstance(preds[0][0]["score"], float)
    assert preds[0][0]["label"] == "admiration"
    assert 0.98 > preds[0][0]["score"] > 0.93

    preds_orig = pipe(sentences, top_k=None, truncation=True)

    assert len(preds_orig) == len(preds)

    # top-5 scores must be within 0.05 of the reference pipeline; labels must
    # also match for any prediction that rises above the noise floor
    for pred_orig, pred in zip(preds_orig, preds):
        assert len(pred_orig) == len(pred)
        for pred_orig_i, pred_i in zip(pred_orig[:5], pred[:5]):
            assert abs(pred_orig_i["score"] - pred_i["score"]) < 0.05

            if pred_orig_i["score"] > 0.005:
                assert pred_orig_i["label"] == pred_i["label"]