
PReLu ONNX import #1721

Merged
merged 5 commits into tracel-ai:main from Arjun31415:prelu-onnx on May 4, 2024
Conversation

Arjun31415
Contributor

@Arjun31415 Arjun31415 commented May 3, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#1714

Changes

Added PReLu ONNX -> burn operator

Testing

Tested as specified in the Contributor book

@Arjun31415
Contributor Author

Arjun31415 commented May 3, 2024

The Rust code produced:

// Generated from ONNX "./onnx-tests/tests/prelu/prelu.onnx" by burn-import
use burn::nn::prelu::PRelu;
use burn::nn::prelu::PReluConfig;
use burn::record::FullPrecisionSettings;
use burn::record::Recorder;
use burn::{
    module::Module,
    tensor::{backend::Backend, Tensor},
};


#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    prelu1: PRelu<B>,
    phantom: core::marker::PhantomData<B>,
    device: burn::module::Ignored<B::Device>,
}


impl<B: Backend> Default for Model<B> {
    fn default() -> Self {
        Self::from_file("./out/prelu", &Default::default())
    }
}

impl<B: Backend> Model<B> {
    pub fn from_file(file: &str, device: &B::Device) -> Self {
        let record = burn::record::PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
            .load(file.into(), device)
            .expect("Record file to exist.");
        Self::new(device).load_record(record)
    }
}

impl<B: Backend> Model<B> {
    #[allow(unused_variables)]
    pub fn new(device: &B::Device) -> Self {
        let prelu1 = PReluConfig::new(0, 0.009999999776482582).init(device);
        Self {
            prelu1,
            phantom: core::marker::PhantomData,
            device: burn::module::Ignored(device.clone()),
        }
    }

    #[allow(clippy::let_and_return, clippy::approx_constant)]
    pub fn forward(&self, input1: Tensor<B, 2>) -> Tensor<B, 0> {
        let prelu1_out1 = self.prelu1.forward(input1);
        prelu1_out1
    }
}

I think the PReluConfig initialization is wrong.

        let prelu1 = PReluConfig::new(0, 0.009999999776482582).init(device);

The number of parameters should not be 0, and the initial alpha should be exactly 0.01, not this approximate value.
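
As an aside, 0.009999999776482582 is exactly what 0.01 becomes when stored as an f32 and then printed at f64 precision, so the odd constant is most likely a float32 round-trip of the intended 0.01 rather than a different value. A quick standalone check (illustrative, not part of the generated code):

fn main() {
    // 0.01 is not exactly representable in binary floating point; the
    // nearest f32, widened to f64, prints as the constant seen above.
    let alpha = 0.01f32;
    println!("{}", alpha as f64); // 0.009999999776482582
}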

Collaborator

@antimora antimora left a comment


Great start, but we are missing two types of unit tests (under onnx-tests and in crates/burn-import/src/burn/node/prelu.rs). Also, the PRelu ONNX spec does not have alpha and num_parameters attributes, so there is no need to extract a configuration.
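
For context, ONNX PRelu is defined elementwise as y = x for x >= 0 and y = slope * x otherwise, with the slope supplied as a second input tensor (a learned initializer) rather than as node attributes. A minimal scalar illustration of that definition (not the burn-import code):

// Scalar illustration of the ONNX PRelu definition; the real slope is a
// learned tensor passed as the node's second input.
fn prelu(x: f32, slope: f32) -> f32 {
    if x >= 0.0 { x } else { slope * x }
}

fn main() {
    assert_eq!(prelu(2.0, 0.01), 2.0);
    assert_eq!(prelu(-2.0, 0.01), -0.02);
}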

crates/burn-import/src/burn/node/prelu.rs Outdated
crates/burn-import/src/onnx/op_configuration.rs Outdated
crates/burn-import/src/onnx/to_burn.rs Outdated
@antimora
Collaborator

antimora commented May 3, 2024

FYI, to run the ONNX tests:

[burn-import]% cd onnx-tests
[onnx-tests]% cargo test

To run the codegen tests:

[burn-import]% cargo test


codecov bot commented May 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.47%. Comparing base (ab50143) to head (9e851b4).
Report is 8 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1721      +/-   ##
==========================================
- Coverage   86.54%   86.47%   -0.07%     
==========================================
  Files         696      698       +2     
  Lines       81653    82874    +1221     
==========================================
+ Hits        70664    71667    +1003     
- Misses      10989    11207     +218     


@Arjun31415
Contributor Author

I also get this warning

DEBUG burn_import::onnx::from_onnx: Number of initializers: 1    
DEBUG burn_import::onnx::from_onnx: Number of outputs: 1    
DEBUG burn_import::onnx::proto_conversion: Converting ONNX node with type "PRelu"    
DEBUG burn_import::onnx::from_onnx: renaming node "/relu1/PRelu"    
 WARN burn_import::onnx::dim_inference: Must implement dimension inference for PRelu    
DEBUG burn_import::onnx::from_onnx: checking inputs for node "prelu1"    
DEBUG burn_import::onnx::from_onnx: 

@antimora
Collaborator

antimora commented May 3, 2024

I also get this warning

DEBUG burn_import::onnx::from_onnx: Number of initializers: 1    
DEBUG burn_import::onnx::from_onnx: Number of outputs: 1    
DEBUG burn_import::onnx::proto_conversion: Converting ONNX node with type "PRelu"    
DEBUG burn_import::onnx::from_onnx: renaming node "/relu1/PRelu"    
 WARN burn_import::onnx::dim_inference: Must implement dimension inference for PRelu    
DEBUG burn_import::onnx::from_onnx: checking inputs for node "prelu1"    
DEBUG burn_import::onnx::from_onnx: 

Yeah, you will need to add dim inference in the dim_inference.rs file. You can use the same_as_input function, since the output will have the same dimensions as the input.
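
For illustration, a same-as-input inference step amounts to copying the input's type onto the output; the types below are simplified stand-ins, not burn-import's actual IR:

// Simplified stand-ins for burn-import's IR types, for illustration only.
#[derive(Clone)]
struct ArgType {
    rank: usize,
}

struct Node {
    inputs: Vec<ArgType>,
    outputs: Vec<ArgType>,
}

// PRelu is elementwise, so the output keeps the input's rank; copying the
// input type onto the output is all the inference needs to do.
fn same_as_input(node: &mut Node) {
    node.outputs[0] = node.inputs[0].clone();
}

fn main() {
    let mut node = Node {
        inputs: vec![ArgType { rank: 2 }],
        outputs: vec![ArgType { rank: 0 }], // placeholder before inference
    };
    same_as_input(&mut node);
    assert_eq!(node.outputs[0].rank, 2);
}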

@antimora
Collaborator

antimora commented May 3, 2024

Please also update the supported ops list at https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md.

@Arjun31415
Contributor Author

OK, so I think what is pending is the dim inference and loading the weight from the ONNX file into the alpha of Burn's PReluConfig. I need help with the latter part: where exactly do I do that?

@Arjun31415
Contributor Author

If I do those two, then the tests should also be fixed.

@Arjun31415
Contributor Author

OK, I think dim inference is also fixed. Now I just need to load the PRelu weight into Burn.

@Arjun31415
Contributor Author

OK, I think it is done.

@Arjun31415 Arjun31415 requested a review from antimora May 4, 2024 15:43
Collaborator

@antimora antimora left a comment


Looks great!

Thank you for your contribution to the Burn project!

@antimora antimora merged commit 152509c into tracel-ai:main May 4, 2024
14 checks passed
@Arjun31415 Arjun31415 deleted the prelu-onnx branch May 5, 2024 03:22