Skip to content

Commit

Permalink
added prelu onnx operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun31415 committed May 3, 2024
1 parent ab50143 commit 73ebba5
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 1 deletion.
Binary file added crates/burn-import/onnx-tests/tests/prelu/prelu.onnx
Binary file not shown.
49 changes: 49 additions & 0 deletions crates/burn-import/onnx-tests/tests/prelu/prelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: prelu.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.relu1 = nn.PReLU()

def forward(self, x):
x = self.relu1(x)
return x


def main():

# Set seed for reproducibility
torch.manual_seed(42)

torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "prelu.onnx"
test_input = torch.randn(2, 3, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data of ones: {}".format(test_input))
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

print("Test output: {}".format(output))


if __name__ == '__main__':
main()

4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::prelu::PReluNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
Expand Down Expand Up @@ -85,6 +86,7 @@ pub enum Node<PS: PrecisionSettings> {
Conv1d(Conv1dNode<PS>),
Conv2d(Conv2dNode<PS>),
ConvTranspose2d(ConvTranspose2dNode<PS>),
PRelu(PReluNode<PS>),
Dropout(DropoutNode),
Gather(GatherNode),
GlobalAvgPool(GlobalAvgPoolNode),
Expand All @@ -111,6 +113,7 @@ macro_rules! match_all {
Node::Conv1d(node) => $func(node),
Node::Conv2d(node) => $func(node),
Node::ConvTranspose2d(node) => $func(node),
Node::PRelu(node) => $func(node),
Node::Dropout(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Expand Down Expand Up @@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Conv1d(_) => "conv1d",
Node::Conv2d(_) => "conv2d",
Node::ConvTranspose2d(_) => "conv_transpose2d",
Node::PRelu(_) => "prelu",
Node::Dropout(_) => "dropout",
Node::Gather(_) => "gather",
Node::GlobalAvgPool(_) => "global_avg_pool",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod reshape;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
Expand Down
100 changes: 100 additions & 0 deletions crates/burn-import/src/burn/node/prelu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use burn::{
module::{Param, ParamId},
nn::{PReluConfig, PReluRecord},
record::{PrecisionSettings, Record},
tensor::{DataSerialize, Tensor},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Clone, Debug)]
pub struct PReluNode<PS: PrecisionSettings> {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub alpha: DataSerialize<PS::FloatElem>,
pub config: PReluConfig,
}

impl<PS: PrecisionSettings> PReluNode<PS> {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
alpha: DataSerialize<PS::FloatElem>,
config: PReluConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
PRelu<B>
},
),
input,
output,
alpha,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;

let num_parameters = self.config.num_parameters.to_tokens();
let alpha = self.config.alpha.to_tokens();
let tokens = quote! {
let #name = PReluConfig::new(#num_parameters, #alpha)
.init(device);
};

Some(tokens)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let device = Default::default();
let record = PReluRecord::<SerializationBackend> {
alpha: Param::initialized(
ParamId::new(),
Tensor::from_data(self.alpha.clone().convert(), &device),
),
};

let item = Record::into_item::<PS>(record);
item.serialize(serializer)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PRelu");
imports.register("burn::nn::prelu::PRelu");
imports.register("burn::nn::prelu::PReluConfig");
}

fn into_node(self) -> Node<PS> {
Node::PRelu(self)
}
}
17 changes: 16 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use burn::nn::{
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
pool::{AvgPool2dConfig, MaxPool2dConfig},
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PReluConfig, PaddingConfig1d,
PaddingConfig2d,
};

Expand Down Expand Up @@ -120,6 +120,21 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
.with_padding(padding)
.with_dilation([dilations[0] as usize, dilations[1] as usize])
}
pub fn prelu_config(curr: &Node) -> PReluConfig {
let mut alpha = 0.01;
let mut num_parameters = 0;
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"alpha" => alpha = value.clone().into_f32(),
"num_parameters" => num_parameters = value.clone().into_i32(),
_ => {}
}
}

PReluConfig::new()
.with_num_parameters(num_parameters as usize)
.with_alpha(alpha as f64)
}

pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
let mut attrs = curr.attrs.clone();
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::{
mask_where::WhereNode,
matmul::MatmulNode,
max_pool2d::MaxPool2dNode,
prelu::PReluNode,
reshape::ReshapeNode,
unary::UnaryNode,
unsqueeze::UnsqueezeNode,
Expand Down Expand Up @@ -236,6 +237,7 @@ impl OnnxGraph {
NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
NodeType::Neg => graph.register(Self::neg_conversion(node)),
Expand Down Expand Up @@ -695,6 +697,14 @@ impl OnnxGraph {
MaxPool2dNode::new(name, input, output, config)
}

fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode<PS> {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let config = prelu_config(&node);
let name = &node.name;
PReluNode::<PS>::new(name, input, output, weight, config)
}
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode<PS> {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
Expand Down
Loading

0 comments on commit 73ebba5

Please sign in to comment.