-
Notifications
You must be signed in to change notification settings - Fork 446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/onnx argmax #1814
Feature/onnx argmax #1814
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1814 +/- ##
==========================================
- Coverage 86.43% 86.42% -0.02%
==========================================
Files 753 761 +8
Lines 87602 87987 +385
==========================================
+ Hits 75723 76041 +318
- Misses 11879 11946 +67 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for tackling the ONNX track!
Overall, the implementation looks good 🙂 Just some minor comments/changes to complete the PR.
Hi laggui, thanks for the feedback. I addressed those points (hopefully I didn't miss any) and set any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a minor request just to limit the warning based on the value for select_last_index
.
Otherwise, looks good! So I'll approve in advance and we can merge when addressed.
"select_last_index" => log::warn!( | ||
"select_last_index param for argmax is ignored in burn (got {:?})", | ||
value | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can still capture the select_last_index
value here, but only warn if it is 1 (because the default implementation pretty much everywhere including Burn is to return the first max value, not last).
Should be all good now :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!!
* pre-test * implementing argmax for burn-import from onnx * tidying * fixing return types and tests * addressing feedback * only warn when select_last_index!=0
ArgMax Onnx Op for Burn-Import
Checklist
run-checks all
script has been executed.Related Issues/PRs
Original Issue
Changes
crates/burn-import/onnx/tests/argmax/argmax.py
with a simple model that calls.argmax(dim)
on an input tensor, and onnx code to convert the model into a ONNX fileArgMaxNode
incrates/burn-import/src/burn/node/argmax.rs
to store ArgMax functionality. Also added necessary helper functionsonnx_tests.rs
Testing
A test called
argmax
inonnx_tests.rs
checks that the created model is able to generate the correct argmax outputs.One thing I did notice was that the only params implemented for the existing tensor argmax function are the dim/axis - keepdims and select_last_index, which are also params for the ONNX argmax node, don't seem to exist in burn. I set them to their defauls, however this may cause issues if trying to import an ONNX model where e.g. keepdims=false. Happy to hear people's input on this.