Adding optimum option for PredictEngine #492
Conversation
…ves the same output as torch implementation
PR Summary
This PR adds ONNX/Optimum support for text classification models, enabling optimized inference through the PredictEngine with ONNX runtime integration.
- Added a new OptimumClassifier class in /libs/infinity_emb/infinity_emb/transformer/classifier/optimum.py with ONNX runtime support and model optimization capabilities
- Implemented device-aware quantization preferences in OptimumClassifier for CPU and OpenVINO providers
- Added a comprehensive test suite in /test_optimum_classifier.py comparing against HuggingFace pipeline outputs
- Extended the PredictEngine enum in utils.py to include Optimum as a new inference engine option
- Disabled IO binding in ONNX runtime for better compatibility via model.use_io_binding = False
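The device-aware quantization preference mentioned above can be sketched roughly as follows. This is a minimal, hypothetical helper, not the code from optimum.py; the actual selection logic may differ. The provider names are standard ONNX Runtime execution provider identifiers.

```python
def prefers_quantized(provider: str) -> bool:
    """Hypothetical sketch: prefer quantized ONNX weights on CPU-class
    execution providers (CPU, OpenVINO), and full precision elsewhere."""
    cpu_like = {"CPUExecutionProvider", "OpenVINOExecutionProvider"}
    return provider in cpu_like

# CPU-class providers typically benefit from int8 quantization; GPU providers do not.
print(prefers_quantized("CPUExecutionProvider"))   # True
print(prefers_quantized("CUDAExecutionProvider"))  # False
```

The point of making the preference provider-aware is that quantized weights speed up CPU and OpenVINO inference, while CUDA providers usually run faster with the full-precision ONNX graph.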
3 file(s) reviewed, 4 comment(s)
engine_args=EngineArgs(
    model_name_or_path=model_name,
    device="cuda" if torch.cuda.is_available() else "cpu",
    )  # type: ignore
)
syntax: Indentation is inconsistent in the EngineArgs constructor (the closing parenthesis is indented to the argument level rather than aligned with the opening call)
import copy
import os

import numpy as np
style: numpy is imported but never used in this file
def encode_post(self, classes) -> dict[str, float]:
    """runs post encoding such as normalization"""
    return classes
logic: incorrect type hint - based on the test file, this method returns list[list[dict]], not dict[str, float]
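A minimal sketch of the corrected signature, assuming the list[list[dict]] shape observed in the test file. The class wrapper here is a hypothetical stand-in for illustration, not the real OptimumClassifier:

```python
class ClassifierSketch:
    # Hypothetical stand-in illustrating only the corrected type hint.
    def encode_post(self, classes: list[list[dict]]) -> list[list[dict]]:
        """Runs post encoding such as normalization (a pass-through here)."""
        return classes

# One inner list of {label, score} dicts per input text, matching the
# HuggingFace pipeline output shape the tests compare against.
out = ClassifierSketch().encode_post([[{"label": "POSITIVE", "score": 0.98}]])
print(out)  # [[{'label': 'POSITIVE', 'score': 0.98}]]
```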
if CHECK_TRANSFORMERS.is_available:
    from transformers import AutoConfig, AutoTokenizer, pipeline  # type: ignore[import-untyped]
style: AutoConfig is imported but never used
Thanks for opening this, and for opening an issue beforehand! Looks good so far. There is a "make precommit" command in the Makefile.
Codecov Report

Attention: Patch coverage is
❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files:

@@           Coverage Diff           @@
##             main     #492   +/-  ##
=====================================
+ Coverage   79.53%   79.58%   +0.05%
=====================================
  Files          41       42       +1
  Lines        3430     3468      +38
=====================================
+ Hits         2728     2760      +32
- Misses        702      708       +6

☔ View full report in Codecov by Sentry.
Thanks for all the work!
Related Issue
Resolves #488
Checklist
Additional Notes
Add any other context about the PR here.