
Adding optimum option for PredictEngine #492

Merged (9 commits) · Dec 10, 2024
Changes from 6 commits
95 changes: 95 additions & 0 deletions libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeil

import copy
import os

import numpy as np
Contributor comment: style: numpy is imported but never used in this file

from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_TRANSFORMERS
from infinity_emb.args import EngineArgs
from infinity_emb.transformer.abstract import BaseClassifer
from infinity_emb.transformer.utils_optimum import (
    device_to_onnx,
    get_onnx_files,
    optimize_model,
)

if CHECK_ONNXRUNTIME.is_available:
    try:
        from optimum.onnxruntime import (  # type: ignore[import-untyped]
            ORTModelForSequenceClassification,
        )

    except (ImportError, RuntimeError, Exception) as ex:
        CHECK_ONNXRUNTIME.mark_dirty(ex)

if CHECK_TRANSFORMERS.is_available:
    from transformers import AutoConfig, AutoTokenizer, pipeline  # type: ignore[import-untyped]
Contributor comment: style: AutoConfig is imported but never used


class OptimumClassifier(BaseClassifer):
    def __init__(self, *, engine_args: EngineArgs):
        CHECK_ONNXRUNTIME.mark_required()
        CHECK_TRANSFORMERS.mark_required()
        provider = device_to_onnx(engine_args.device)

        onnx_file = get_onnx_files(
            model_name_or_path=engine_args.model_name_or_path,
            revision=engine_args.revision,
            use_auth_token=True,
            prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
        )

        self.model = optimize_model(
            model_name_or_path=engine_args.model_name_or_path,
            model_class=ORTModelForSequenceClassification,
            revision=engine_args.revision,
            trust_remote_code=engine_args.trust_remote_code,
            execution_provider=provider,
            file_name=onnx_file.as_posix(),
            optimize_model=not os.environ.get(
                "INFINITY_ONNX_DISABLE_OPTIMIZE", False
            ),
        )
        self.model.use_io_binding = False

        self.tokenizer = AutoTokenizer.from_pretrained(
            engine_args.model_name_or_path,
            revision=engine_args.revision,
            trust_remote_code=engine_args.trust_remote_code,
        )

        self._infinity_tokenizer = copy.deepcopy(self.tokenizer)

        self._pipe = pipeline(
            task="text-classification",
            model=self.model,
            trust_remote_code=engine_args.trust_remote_code,
            top_k=None,
            revision=engine_args.revision,
            tokenizer=self.tokenizer,
            device=engine_args.device,
        )

    def encode_pre(self, sentences: list[str]):
        return sentences

    def encode_core(self, sentences: list[str]) -> dict:
        outputs = self._pipe(sentences)
        return outputs

    def encode_post(self, classes) -> dict[str, float]:
        """runs post encoding such as normalization"""
        return classes
Contributor comment on lines +78 to +80: logic: incorrect type hint - method returns list[list[dict]] based on test file, not dict[str, float]


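A minimal sketch of the annotation the reviewer is asking for (assuming, as the test below asserts, that the pipeline returns one list of {"label": ..., "score": ...} dicts per input sentence):

    def encode_post(self, classes) -> list[list[dict]]:
        """runs post encoding such as normalization"""
        return classes
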
    def tokenize_lengths(self, sentences: list[str]) -> list[int]:
        """gets the lengths of each sentence according to tokenize/len etc."""
        tks = self._infinity_tokenizer.batch_encode_plus(
            sentences,
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
            return_length=False,
        ).encodings
        return [len(t.tokens) for t in tks]
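
One detail of the constructor above worth calling out: graph optimization is gated by the INFINITY_ONNX_DISABLE_OPTIMIZE environment variable, and because os.environ.get returns a string, any non-empty value (even "0") is truthy and disables optimization. A minimal sketch of opting out:

import os

# Any non-empty string disables ONNX graph optimization here,
# since the constructor only checks truthiness of the raw value.
os.environ["INFINITY_ONNX_DISABLE_OPTIMIZE"] = "1"
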
4 changes: 4 additions & 0 deletions libs/infinity_emb/infinity_emb/transformer/utils.py
@@ -7,6 +7,7 @@
from infinity_emb.primitives import InferenceEngine
from infinity_emb.transformer.audio.torch import TorchAudioModel
from infinity_emb.transformer.classifier.torch import SentenceClassifier
from infinity_emb.transformer.classifier.optimum import OptimumClassifier
from infinity_emb.transformer.crossencoder.optimum import OptimumCrossEncoder
from infinity_emb.transformer.crossencoder.torch import (
    CrossEncoderPatched as CrossEncoderTorch,
@@ -87,11 +88,14 @@ def from_inference_engine(engine: InferenceEngine):

class PredictEngine(Enum):
    torch = SentenceClassifier
    optimum = OptimumClassifier

    @staticmethod
    def from_inference_engine(engine: InferenceEngine):
        if engine == InferenceEngine.torch:
            return PredictEngine.torch
        elif engine == InferenceEngine.optimum:
            return PredictEngine.optimum
        else:
            raise NotImplementedError(f"PredictEngine for {engine} not implemented")

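With the enum change above, the ONNX-backed classifier rides the same dispatch path as the existing torch backend. A minimal usage sketch (assuming InferenceEngine.optimum exists, as the new elif branch implies):

from infinity_emb.primitives import InferenceEngine
from infinity_emb.transformer.utils import PredictEngine

# The new branch maps the optimum inference engine to OptimumClassifier.
engine = PredictEngine.from_inference_engine(InferenceEngine.optimum)
assert engine is PredictEngine.optimum
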
42 changes: 42 additions & 0 deletions (new test file; path not shown)
@@ -0,0 +1,42 @@
import torch
from transformers import pipeline # type: ignore

from infinity_emb.args import EngineArgs
from infinity_emb.transformer.classifier.optimum import OptimumClassifier


def test_classifier(model_name: str = "SamLowe/roberta-base-go_emotions"):
    model = OptimumClassifier(
        engine_args=EngineArgs(
            model_name_or_path=model_name,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )  # type: ignore
    )
Contributor comment: syntax: Indentation is inconsistent in EngineArgs constructor

    pipe = pipeline(model=model_name, task="text-classification")

    sentences = ["This is awesome.", "I am depressed."]

    encode_pre = model.encode_pre(sentences)
    encode_core = model.encode_core(encode_pre)
    preds = model.encode_post(encode_core)

    assert len(preds) == len(sentences)
    assert isinstance(preds, list)
    assert isinstance(preds[0], list)
    assert isinstance(preds[0][0], dict)
    assert isinstance(preds[0][0]["label"], str)
    assert isinstance(preds[0][0]["score"], float)
    assert preds[0][0]["label"] == "admiration"
    assert 0.98 > preds[0][0]["score"] > 0.93

    preds_orig = pipe(sentences, top_k=None, truncation=True)

    assert len(preds_orig) == len(preds)

    for pred_orig, pred in zip(preds_orig, preds):
        assert len(pred_orig) == len(pred)
        for pred_orig_i, pred_i in zip(pred_orig[:5], pred[:5]):
            assert abs(pred_orig_i["score"] - pred_i["score"]) < 0.05

            if pred_orig_i["score"] > 0.005:
                assert pred_orig_i["label"] == pred_i["label"]