diff --git a/libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py b/libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py
new file mode 100644
index 00000000..51edef0f
--- /dev/null
+++ b/libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2023-now michaelfeil
+
+import copy
+import os
+
+from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_TRANSFORMERS
+from infinity_emb.args import EngineArgs
+from infinity_emb.transformer.abstract import BaseClassifer
+from infinity_emb.transformer.utils_optimum import (
+    device_to_onnx,
+    get_onnx_files,
+    optimize_model,
+)
+
+if CHECK_ONNXRUNTIME.is_available:
+    try:
+        from optimum.onnxruntime import (  # type: ignore[import-untyped]
+            ORTModelForSequenceClassification,
+        )
+
+    except (ImportError, RuntimeError, Exception) as ex:
+        CHECK_ONNXRUNTIME.mark_dirty(ex)
+
+if CHECK_TRANSFORMERS.is_available:
+    from transformers import AutoTokenizer, pipeline  # type: ignore[import-untyped]
+
+
+class OptimumClassifier(BaseClassifer):
+    def __init__(self, *, engine_args: EngineArgs):
+        CHECK_ONNXRUNTIME.mark_required()
+        CHECK_TRANSFORMERS.mark_required()
+        provider = device_to_onnx(engine_args.device)
+
+        onnx_file = get_onnx_files(
+            model_name_or_path=engine_args.model_name_or_path,
+            revision=engine_args.revision,
+            use_auth_token=True,
+            prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
+        )
+
+        self.model = optimize_model(
+            model_name_or_path=engine_args.model_name_or_path,
+            model_class=ORTModelForSequenceClassification,
+            revision=engine_args.revision,
+            trust_remote_code=engine_args.trust_remote_code,
+            execution_provider=provider,
+            file_name=onnx_file.as_posix(),
+            optimize_model=not os.environ.get("INFINITY_ONNX_DISABLE_OPTIMIZE", False),
+        )
+        self.model.use_io_binding = False
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            engine_args.model_name_or_path,
+            revision=engine_args.revision,
+            trust_remote_code=engine_args.trust_remote_code,
+        )
+
+        self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
+
+        self._pipe = pipeline(
+            task="text-classification",
+            model=self.model,
+            trust_remote_code=engine_args.trust_remote_code,
+            top_k=None,
+            revision=engine_args.revision,
+            tokenizer=self.tokenizer,
+            device=engine_args.device,
+        )
+
+    def encode_pre(self, sentences: list[str]):
+        return sentences
+
+    def encode_core(self, sentences: list[str]) -> list[list[dict]]:
+        outputs = self._pipe(sentences)
+        return outputs
+
+    def encode_post(self, classes) -> list[list[dict]]:
+        """runs post-encoding steps such as normalization"""
+        return classes
+
+    def tokenize_lengths(self, sentences: list[str]) -> list[int]:
+        """gets the token length of each sentence via the tokenizer"""
+        tks = self._infinity_tokenizer.batch_encode_plus(
+            sentences,
+            add_special_tokens=False,
+            return_token_type_ids=False,
+            return_attention_mask=False,
+            return_length=False,
+        ).encodings
+        return [len(t.tokens) for t in tks]
diff --git a/libs/infinity_emb/infinity_emb/transformer/utils.py b/libs/infinity_emb/infinity_emb/transformer/utils.py
index 5ed56786..e17de3a5 100644
--- a/libs/infinity_emb/infinity_emb/transformer/utils.py
+++ b/libs/infinity_emb/infinity_emb/transformer/utils.py
@@ -7,6 +7,7 @@
 from infinity_emb.primitives import InferenceEngine
 from infinity_emb.transformer.audio.torch import TorchAudioModel
 from infinity_emb.transformer.classifier.torch import SentenceClassifier
+from infinity_emb.transformer.classifier.optimum import OptimumClassifier
 from infinity_emb.transformer.crossencoder.optimum import OptimumCrossEncoder
 from infinity_emb.transformer.crossencoder.torch import (
     CrossEncoderPatched as CrossEncoderTorch,
@@ -87,11 +88,14 @@ def from_inference_engine(engine: InferenceEngine):
 
 class PredictEngine(Enum):
     torch = SentenceClassifier
+    optimum = OptimumClassifier
 
     @staticmethod
     def from_inference_engine(engine: InferenceEngine):
         if engine == InferenceEngine.torch:
             return PredictEngine.torch
+        elif engine == InferenceEngine.optimum:
+            return PredictEngine.optimum
         else:
             raise NotImplementedError(f"PredictEngine for {engine} not implemented")
 
diff --git a/libs/infinity_emb/tests/unit_test/transformer/classifier/test_optimum_classifier.py b/libs/infinity_emb/tests/unit_test/transformer/classifier/test_optimum_classifier.py
new file mode 100644
index 00000000..386c3061
--- /dev/null
+++ b/libs/infinity_emb/tests/unit_test/transformer/classifier/test_optimum_classifier.py
@@ -0,0 +1,49 @@
+import torch
+from optimum.pipelines import pipeline  # type: ignore
+from optimum.onnxruntime import ORTModelForSequenceClassification
+from infinity_emb.args import EngineArgs
+from infinity_emb.transformer.classifier.optimum import OptimumClassifier
+
+
+def test_classifier(model_name: str = "SamLowe/roberta-base-go_emotions-onnx"):
+    model = OptimumClassifier(
+        engine_args=EngineArgs(
+            model_name_or_path=model_name,
+            device="cuda" if torch.cuda.is_available() else "cpu",
+        )  # type: ignore
+    )
+
+    pipe = pipeline(
+        task="text-classification",
+        model=ORTModelForSequenceClassification.from_pretrained(
+            model_name, file_name="onnx/model_quantized.onnx"
+        ),
+        top_k=None,
+    )
+
+    sentences = ["This is awesome.", "I am depressed."]
+
+    encode_pre = model.encode_pre(sentences)
+    encode_core = model.encode_core(encode_pre)
+    preds = model.encode_post(encode_core)
+
+    assert len(preds) == len(sentences)
+    assert isinstance(preds, list)
+    assert isinstance(preds[0], list)
+    assert isinstance(preds[0][0], dict)
+    assert isinstance(preds[0][0]["label"], str)
+    assert isinstance(preds[0][0]["score"], float)
+    assert preds[0][0]["label"] == "admiration"
+    assert 0.98 > preds[0][0]["score"] > 0.93
+
+    preds_orig = pipe(sentences, top_k=None, truncation=True)
+
+    assert len(preds_orig) == len(preds)
+
+    for pred_orig, pred in zip(preds_orig, preds):
+        assert len(pred_orig) == len(pred)
+        for pred_orig_i, pred_i in zip(pred_orig[:5], pred[:5]):
+            assert abs(pred_orig_i["score"] - pred_i["score"]) < 0.05
+
+            if pred_orig_i["score"] > 0.005:
+                assert pred_orig_i["label"] == pred_i["label"]
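Taken together, the diff registers the ONNX classifier alongside the torch one, so selecting InferenceEngine.optimum now resolves to OptimumClassifier for classification models. A minimal end-to-end usage sketch, assuming the existing AsyncEmbeddingEngine.from_args() and classify() APIs and the same model as in the test above; the (predictions, usage) return shape of classify() is an assumption, mirroring embed():

    import asyncio

    from infinity_emb import AsyncEmbeddingEngine
    from infinity_emb.args import EngineArgs
    from infinity_emb.primitives import InferenceEngine


    async def classify_demo():
        # engine=InferenceEngine.optimum routes prediction through the new
        # OptimumClassifier via PredictEngine.from_inference_engine().
        engine = AsyncEmbeddingEngine.from_args(
            EngineArgs(
                model_name_or_path="SamLowe/roberta-base-go_emotions-onnx",
                engine=InferenceEngine.optimum,
            )
        )
        async with engine:  # starts and stops the batch handler
            # assumed to return (predictions, usage), analogous to embed()
            predictions, usage = await engine.classify(
                sentences=["This is awesome.", "I am depressed."]
            )
        # per the test above, each prediction is a list of
        # {"label": ..., "score": ...} dicts sorted by score
        print(predictions[0][0])


    asyncio.run(classify_demo())

The sketch mirrors the assertions in the new test: for the first sentence, predictions[0][0]["label"] is expected to come out as "admiration".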