From 8d226c7607e66822b588c65a934a5692737a9e30 Mon Sep 17 00:00:00 2001
From: Arjun31415
Date: Fri, 3 May 2024 21:12:12 +0530
Subject: [PATCH 1/5] added prelu onnx operator

---
 .../onnx-tests/tests/prelu/prelu.onnx         | Bin 0 -> 172 bytes
 .../onnx-tests/tests/prelu/prelu.py           |  49 +++++++++
 crates/burn-import/src/burn/node/base.rs      |   4 +
 crates/burn-import/src/burn/node/mod.rs       |   1 +
 crates/burn-import/src/burn/node/prelu.rs     | 100 ++++++++++++++++++
 .../burn-import/src/onnx/op_configuration.rs  |  17 ++-
 crates/burn-import/src/onnx/to_burn.rs        |  10 ++
 7 files changed, 180 insertions(+), 1 deletion(-)
 create mode 100644 crates/burn-import/onnx-tests/tests/prelu/prelu.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/prelu/prelu.py
 create mode 100644 crates/burn-import/src/burn/node/prelu.rs

diff --git a/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx b/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..d9644b84e5928e285e0e8db292545ab131abf537
GIT binary patch
literal 172
zcmdtc@... [remainder of the base85-encoded 172-byte ONNX file lost in extraction]
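The diff for the new tests/prelu/prelu.py (49 lines in the diffstat above) is not reproduced here. The onnx-tests generator scripts share a common PyTorch-export pattern, so a minimal sketch follows for orientation; the seed, input shape, opset version, and printed output are assumptions rather than the actual file contents:

#!/usr/bin/env python3
# used to generate model: prelu.onnx
# NOTE: hypothetical sketch; the real prelu.py is not shown in this patch excerpt.

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # a single PReLU activation with one learnable weight (the PyTorch default)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(x)


def main():
    torch.manual_seed(0)
    model = Model()
    model.eval()

    # export with a 2x3 input, matching the shape used by the Rust test added in patch 3
    test_input = torch.randn(2, 3)
    torch.onnx.export(model, test_input, "prelu.onnx", opset_version=16)

    print("Test input data: {}".format(test_input))
    print("Test output data: {}".format(model.forward(test_input)))


if __name__ == "__main__":
    main()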
diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ ... @@ pub enum Node<PS: PrecisionSettings> {
     Conv1d(Conv1dNode<PS>),
     Conv2d(Conv2dNode<PS>),
     ConvTranspose2d(ConvTranspose2dNode<PS>),
+    PRelu(PReluNode<PS>),
     Dropout(DropoutNode),
     Gather(GatherNode),
     GlobalAvgPool(GlobalAvgPoolNode),
@@ -111,6 +113,7 @@ macro_rules! match_all {
             Node::Conv1d(node) => $func(node),
             Node::Conv2d(node) => $func(node),
             Node::ConvTranspose2d(node) => $func(node),
+            Node::PRelu(node) => $func(node),
             Node::Dropout(node) => $func(node),
             Node::Gather(node) => $func(node),
             Node::GlobalAvgPool(node) => $func(node),
@@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::Conv1d(_) => "conv1d",
             Node::Conv2d(_) => "conv2d",
             Node::ConvTranspose2d(_) => "conv_transpose2d",
+            Node::PRelu(_) => "prelu",
             Node::Dropout(_) => "dropout",
             Node::Gather(_) => "gather",
             Node::GlobalAvgPool(_) => "global_avg_pool",
diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs
index 965652c7a2..ae936bbadb 100644
--- a/crates/burn-import/src/burn/node/mod.rs
+++ b/crates/burn-import/src/burn/node/mod.rs
@@ -17,6 +17,7 @@ pub(crate) mod linear;
 pub(crate) mod mask_where;
 pub(crate) mod matmul;
 pub(crate) mod max_pool2d;
+pub(crate) mod prelu;
 pub(crate) mod reshape;
 pub(crate) mod unary;
 pub(crate) mod unsqueeze;
diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs
new file mode 100644
index 0000000000..7eb8446589
--- /dev/null
+++ b/crates/burn-import/src/burn/node/prelu.rs
@@ -0,0 +1,100 @@
+use super::{Node, NodeCodegen, SerializationBackend};
+use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
+use burn::{
+    module::{Param, ParamId},
+    nn::{PReluConfig, PReluRecord},
+    record::{PrecisionSettings, Record},
+    tensor::{DataSerialize, Tensor},
+};
+use proc_macro2::TokenStream;
+use quote::quote;
+use serde::Serialize;
+
+#[derive(Clone, Debug)]
+pub struct PReluNode<PS: PrecisionSettings> {
+    pub field: OtherType,
+    pub input: TensorType,
+    pub output: TensorType,
+    pub alpha: DataSerialize<PS::FloatElem>,
+    pub config: PReluConfig,
+}
+
+impl<PS: PrecisionSettings> PReluNode<PS> {
+    pub fn new<S: AsRef<str>>(
+        name: S,
+        input: TensorType,
+        output: TensorType,
+        alpha: DataSerialize<PS::FloatElem>,
+        config: PReluConfig,
+    ) -> Self {
+        Self {
+            field: OtherType::new(
+                name,
+                quote! {
+                    PRelu<B>
+                },
+            ),
+            input,
+            output,
+            alpha,
+            config,
+        }
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
+    fn input_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.input.clone())]
+    }
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+    fn field_type(&self) -> Option<Type> {
+        Some(Type::Other(self.field.clone()))
+    }
+
+    fn field_init(&self) -> Option<TokenStream> {
+        let name = &self.field.name;
+
+        let num_parameters = self.config.num_parameters.to_tokens();
+        let alpha = self.config.alpha.to_tokens();
+        let tokens = quote! {
+            let #name = PReluConfig::new(#num_parameters, #alpha)
+                .init(device);
+        };
+
+        Some(tokens)
+    }
+
+    fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        let device = Default::default();
+        let record = PReluRecord::<SerializationBackend> {
+            alpha: Param::initialized(
+                ParamId::new(),
+                Tensor::from_data(self.alpha.clone().convert(), &device),
+            ),
+        };
+
+        let item = Record::into_item::<PS>(record);
+        item.serialize(serializer)
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let input = scope.tensor_use_owned(&self.input, node_position);
+        let output = &self.output.name;
+        let field = &self.field.name;
+
+        quote! {
+            let #output = self.#field.forward(#input);
+        }
+    }
+    fn register_imports(&self, imports: &mut BurnImports) {
+        imports.register("burn::nn::PRelu");
+        imports.register("burn::nn::prelu::PRelu");
+        imports.register("burn::nn::prelu::PReluConfig");
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::PRelu(self)
+    }
+}
diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs
index da213d55e7..e0c5cd16e2 100644
--- a/crates/burn-import/src/onnx/op_configuration.rs
+++ b/crates/burn-import/src/onnx/op_configuration.rs
@@ -1,7 +1,7 @@
 use burn::nn::{
     conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
     pool::{AvgPool2dConfig, MaxPool2dConfig},
-    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
+    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PReluConfig, PaddingConfig1d,
     PaddingConfig2d,
 };

@@ -120,6 +120,21 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
         .with_padding(padding)
         .with_dilation([dilations[0] as usize, dilations[1] as usize])
 }
+pub fn prelu_config(curr: &Node) -> PReluConfig {
+    let mut alpha = 0.01;
+    let mut num_parameters = 0;
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "alpha" => alpha = value.clone().into_f32(),
+            "num_parameters" => num_parameters = value.clone().into_i32(),
+            _ => {}
+        }
+    }
+
+    PReluConfig::new()
+        .with_num_parameters(num_parameters as usize)
+        .with_alpha(alpha as f64)
+}

 pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
     let mut attrs = curr.attrs.clone();
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index 51ebf8683b..ba7d8fbe15 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -30,6 +30,7 @@ use crate::{
         mask_where::WhereNode,
         matmul::MatmulNode,
         max_pool2d::MaxPool2dNode,
+        prelu::PReluNode,
         reshape::ReshapeNode,
         unary::UnaryNode,
         unsqueeze::UnsqueezeNode,
@@ -236,6 +237,7 @@ impl OnnxGraph {
             NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
             NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
             NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
+            NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
             NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
             NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
             NodeType::Neg => graph.register(Self::neg_conversion(node)),
@@ -695,6 +697,14 @@ impl OnnxGraph {
         MaxPool2dNode::new(name, input, output, config)
     }

+    fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode<PS> {
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
+        let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
+        let config = prelu_config(&node);
+        let name = &node.name;
+        PReluNode::<PS>::new(name, input, output, weight, config)
+    }
     fn conv_transpose2d_conversion(node: Node) -> ConvTranspose2dNode {
         let input = node.inputs.first().unwrap().to_tensor_type();
         let output = node.outputs.first().unwrap().to_tensor_type();

From e63fea4b95b74461e4a296b66b016ca48c3a58fa Mon Sep 17 00:00:00 2001
From: Arjun31415
Date: Fri, 3 May 2024 21:38:57 +0530
Subject: [PATCH 2/5] bug fix

---
 crates/burn-import/src/burn/node/prelu.rs     |  1 -
 .../burn-import/src/onnx/op_configuration.rs  | 18 +-----------------
 crates/burn-import/src/onnx/to_burn.rs        |  3 ++-
 3 files changed, 3 insertions(+), 19 deletions(-)

diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs
index 7eb8446589..32e3d3b52a 100644
--- a/crates/burn-import/src/burn/node/prelu.rs
+++ b/crates/burn-import/src/burn/node/prelu.rs
@@ -89,7 +89,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
         }
     }
     fn register_imports(&self, imports: &mut BurnImports) {
-        imports.register("burn::nn::PRelu");
         imports.register("burn::nn::prelu::PRelu");
         imports.register("burn::nn::prelu::PReluConfig");
     }
diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs
index e0c5cd16e2..1fb674e96b 100644
--- a/crates/burn-import/src/onnx/op_configuration.rs
+++ b/crates/burn-import/src/onnx/op_configuration.rs
@@ -1,7 +1,7 @@
 use burn::nn::{
     conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
     pool::{AvgPool2dConfig, MaxPool2dConfig},
-    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PReluConfig, PaddingConfig1d,
+    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
     PaddingConfig2d,
 };

@@ -120,22 +120,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
         .with_padding(padding)
         .with_dilation([dilations[0] as usize, dilations[1] as usize])
 }
-pub fn prelu_config(curr: &Node) -> PReluConfig {
-    let mut alpha = 0.01;
-    let mut num_parameters = 0;
-    for (key, value) in curr.attrs.iter() {
-        match key.as_str() {
-            "alpha" => alpha = value.clone().into_f32(),
-            "num_parameters" => num_parameters = value.clone().into_i32(),
-            _ => {}
-        }
-    }
-
-    PReluConfig::new()
-        .with_num_parameters(num_parameters as usize)
-        .with_alpha(alpha as f64)
-}
-
 pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
     let mut attrs = curr.attrs.clone();
     let kernel_shape = attrs
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index ba7d8fbe15..3a28c2ecf4 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -5,6 +5,7 @@ use std::{
 };

 use burn::{
+    nn::PReluConfig,
     record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings},
     tensor::{DataSerialize, Element},
 };
@@ -701,7 +702,7 @@ impl OnnxGraph {
         let input = node.inputs.first().unwrap().to_tensor_type();
         let output = node.outputs.first().unwrap().to_tensor_type();
         let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
-        let config = prelu_config(&node);
+        let config = PReluConfig::new();
         let name = &node.name;
         PReluNode::<PS>::new(name, input, output, weight, config)
     }

From 7e0595680d2afa0e3a80d48b0be9399de1827ff3 Mon Sep 17 00:00:00 2001
From: Arjun31415
Date: Fri, 3 May 2024 21:59:00 +0530
Subject: [PATCH 3/5] added onnx tests and burn codegen tests

---
 crates/burn-import/onnx-tests/build.rs        |  1 +
 .../onnx-tests/tests/onnx_tests.rs            | 24 +++++++
 crates/burn-import/src/burn/node/prelu.rs     | 65 ++++++++++++++++++-
 3 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index 94b32209e9..3a2c9b685d 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -39,6 +39,7 @@ fn main() {
         .input("tests/recip/recip.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/leaky_relu/leaky_relu.onnx")
+        .input("tests/prelu/prelu.onnx")
         .input("tests/reduce_max/reduce_max.onnx")
         .input("tests/reduce_mean/reduce_mean.onnx")
         .input("tests/reshape/reshape.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index 34ddfa5f87..d6f259e682 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -47,6 +47,7 @@ include_models!(
     mul,
     neg,
     not,
+    prelu,
     recip,
     reduce_max,
     reduce_mean,
@@ -658,6 +659,29 @@ mod tests {
         assert_eq!(output.to_data(), expected);
     }

+    #[test]
+    fn prelu() {
+        // Initialize the model without weights (because the exported file does not contain them)
+        let device = Default::default();
+        let model: prelu::Model<Backend> = prelu::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 2>::from_floats(
+            [
+                [0.33669037, 0.0, 0.23446237],
+                [0.23033303, -1.122_856, -0.18632829],
+            ],
+            &device,
+        );
+        let output = model.forward(input);
+        let expected = Data::from([
+            [0.33669037, 0.0, 0.23446237],
+            [0.23033303, -0.01122_856, -0.0018632829],
+        ]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
     #[test]
     fn relu() {
         // Initialize the model without weights (because the exported file does not contain them)
diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs
index 32e3d3b52a..2ac5f8b85e 100644
--- a/crates/burn-import/src/burn/node/prelu.rs
+++ b/crates/burn-import/src/burn/node/prelu.rs
@@ -89,11 +89,72 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
         }
     }
     fn register_imports(&self, imports: &mut BurnImports) {
-        imports.register("burn::nn::prelu::PRelu");
-        imports.register("burn::nn::prelu::PReluConfig");
+        imports.register("burn::nn::PRelu");
+        imports.register("burn::nn::PReluConfig");
     }

     fn into_node(self) -> Node<PS> {
         Node::PRelu(self)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{conv1d::Conv1dNode, test::assert_tokens},
+        TensorType,
+    };
+    use burn::{
+        nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data,
+    };
+
+    #[test]
+    fn test_codegen() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(PReluNode::new(
+            "prelu",
+            TensorType::new_float("input", 4),
+            TensorType::new_float("output", 4),
+            Data::from([2.]).serialize(),
+            PReluConfig::new(),
+        ));
+
+        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
+
+        let expected = quote! {
+            use burn::nn::prelu::PRelu;
+            use burn::nn::prelu::PReluConfig;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                prelu: PRelu<B>,
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    let prelu = PReluConfig::new(1, 0.25).init(device);
+                    Self {
+                        prelu,
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let output = self.prelu.forward(input);
+                    output
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}

From 1b073affd4da7e3fa421b30ab5c791ecb27c6bed Mon Sep 17 00:00:00 2001
From: Arjun31415
Date: Sat, 4 May 2024 18:12:56 +0530
Subject: [PATCH 4/5] fix tests

---
 .../onnx-tests/tests/onnx_tests.rs        |  2 +-
 crates/burn-import/src/burn/node/prelu.rs | 23 ++++++-------------
 2 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index d6f259e682..c237542d41 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -676,7 +676,7 @@ mod tests {
         let output = model.forward(input);
         let expected = Data::from([
             [0.33669037, 0.0, 0.23446237],
-            [0.23033303, -0.01122_856, -0.0018632829],
+            [0.23033303, -0.280714, -0.046582073],
         ]);

         assert_eq!(output.to_data(), expected);
diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs
index 2ac5f8b85e..a1e474d7d9 100644
--- a/crates/burn-import/src/burn/node/prelu.rs
+++ b/crates/burn-import/src/burn/node/prelu.rs
@@ -1,5 +1,5 @@
 use super::{Node, NodeCodegen, SerializationBackend};
-use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
+use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
 use burn::{
     module::{Param, ParamId},
     nn::{PReluConfig, PReluRecord},
@@ -55,11 +55,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {

     fn field_init(&self) -> Option<TokenStream> {
         let name = &self.field.name;
-
-        let num_parameters = self.config.num_parameters.to_tokens();
-        let alpha = self.config.alpha.to_tokens();
         let tokens = quote! {
-            let #name = PReluConfig::new(#num_parameters, #alpha)
+            let #name = PReluConfig::new()
                 .init(device);
         };
@@ -101,14 +98,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::burn::{
-        graph::BurnGraph,
-        node::{conv1d::Conv1dNode, test::assert_tokens},
-        TensorType,
-    };
-    use burn::{
-        nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data,
-    };
+    use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
+    use burn::{record::FullPrecisionSettings, tensor::Data};

     #[test]
     fn test_codegen() {
@@ -125,8 +116,8 @@ mod tests {
         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

         let expected = quote! {
-            use burn::nn::prelu::PRelu;
-            use burn::nn::prelu::PReluConfig;
+            use burn::nn::PRelu;
+            use burn::nn::PReluConfig;
             use burn::{
                 module::Module,
                 tensor::{backend::Backend, Tensor},
             };
             #[derive(Module, Debug)]
             pub struct Model<B: Backend> {
                 prelu: PRelu<B>,
                 phantom: core::marker::PhantomData<B>,
                 device: burn::module::Ignored<B::Device>,
             }
             impl<B: Backend> Model<B> {
                 #[allow(unused_variables)]
                 pub fn new(device: &B::Device) -> Self {
-                    let prelu = PReluConfig::new(1, 0.25).init(device);
+                    let prelu = PReluConfig::new().init(device);
                     Self {
                         prelu,
                         phantom: core::marker::PhantomData,
                         device: burn::module::Ignored(device.clone()),
                     }

From 9e851b4903cd04c9f24ad7356cbe826012029a47 Mon Sep 17 00:00:00 2001
From: Arjun31415
Date: Sat, 4 May 2024 21:14:35 +0530
Subject: [PATCH 5/5] added prelu to supported onnx ops and add prelu to
 dim_inference

---
 crates/burn-import/SUPPORTED-ONNX-OPS.md     | 2 +-
 crates/burn-import/src/onnx/dim_inference.rs | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md
index c03bbdad73..207fbd2812 100644
--- a/crates/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -126,7 +126,7 @@ represent the corresponding Burn Op.
 | [Or][119]             | ❌ | ❌ |
 | [Pad][120]            | ❌ | ✅ |
 | [Pow][121]            | ✅ | ✅ |
-| [PRelu][122]          | ❌ | ✅ |
+| [PRelu][122]          | ✅ | ✅ |
 | [QLinearConv][123]    | ❌ | ❌ |
 | [QLinearMatMul][124]  | ❌ | ❌ |
 | [QuantizeLinear][125] | ❌ | ❌ |
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 9066f3eef3..e171553f15 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -59,6 +59,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Unsqueeze => unsqueeze_update_output(node),
         NodeType::Pow => same_as_input(node),
         NodeType::LeakyRelu => same_as_input(node),
+        NodeType::PRelu => same_as_input(node),
         NodeType::Where => where_update_outputs(node),
         // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
         _ => temporary_pass_through_stub(node),
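A note on the expected values fixed in patch 4: since patch 2 the converter builds the module with PReluConfig::new(), and the Rust test constructs the model without loading the exported weight, so the slope in effect is the config default of 0.25 (the same default torch.nn.PReLU uses), not the 0.01 assumed by the earlier expectations. A quick PyTorch cross-check of the corrected numbers, as an illustrative sketch:

import torch

# torch.nn.PReLU() defaults to a single weight initialized to 0.25, which matches
# the default alpha of burn's PReluConfig::new() used by the generated test model
prelu = torch.nn.PReLU()

x = torch.tensor([[0.33669037, 0.0, 0.23446237],
                  [0.23033303, -1.122856, -0.18632829]])
with torch.no_grad():
    print(prelu(x))
# positive entries pass through unchanged; negatives are scaled by 0.25:
#   -1.122856   * 0.25 = -0.280714
#   -0.18632829 * 0.25 = -0.046582073   (the patch-4 expected data)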