-
Notifications
You must be signed in to change notification settings - Fork 449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PReLu ONNX import #1721
PReLu ONNX import #1721
Conversation
The rust code produced - // Generated from ONNX "./onnx-tests/tests/prelu/prelu.onnx" by burn-import
use burn::nn::prelu::PRelu;
use burn::nn::prelu::PReluConfig;
use burn::nn::PRelu;
use burn::record::FullPrecisionSettings;
use burn::record::Recorder;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
prelu1: PRelu<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
Self::from_file("./out/prelu", &Default::default())
}
}
impl<B: Backend> Model<B> {
pub fn from_file(file: &str, device: &B::Device) -> Self {
let record = burn::record::PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
.load(file.into(), device)
.expect("Record file to exist.");
Self::new(device).load_record(record)
}
}
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let prelu1 = PReluConfig::new(0, 0.009999999776482582).init(device);
Self {
prelu1,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input1: Tensor<B, 2>) -> Tensor<B, 0> {
let prelu1_out1 = self.prelu1.forward(input1);
prelu1_out1
}
} I think the let prelu1 = PReluConfig::new(0, 0.009999999776482582).init(device); the number of parameters should not be 0 and initial alpha should be 0.01 exactly and not this approximate value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great start but we are missing two types of unit tests (under onnx-tests and in crates/burn-import/src/burn/node/prelu.rs). Also Prelu ONNX spec does not have alpha and numb_parameters, so there is no need to extract configuration.
FYI, to run ONNX test: [burn-import]% cd onnx-tests
[onnx-tests]% cargo test to run code gen tests: [burn-import]% cargo test |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1721 +/- ##
==========================================
- Coverage 86.54% 86.47% -0.07%
==========================================
Files 696 698 +2
Lines 81653 82874 +1221
==========================================
+ Hits 70664 71667 +1003
- Misses 10989 11207 +218 ☔ View full report in Codecov by Sentry. |
I also get this warning
|
Yeah, you will need to add dim inference in dim_inference.rs file. You can use same as input function since the output will have the same dimensions. |
Also please update the supported https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md as well |
Ok so I think what is pending is the |
if I do those two then tests would also be fixed |
Ok I think dim inference is also fixed. Now I just need to load the prelu weight into burn |
Ok I think it is done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Thank you for your contribution to the Burn project!
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
#1714
Changes
Added PReLu ONNX -> burn operator
Testing
Tested as specified in the Contributor book