Add MaxPool1d ONNX Op (#1725)
Arjun31415 authored May 6, 2024
1 parent fb13503 commit 7f94f4c
Showing 11 changed files with 282 additions and 9 deletions.
2 changes: 1 addition & 1 deletion crates/burn-core/src/nn/pool/max_pool1d.rs
@@ -8,7 +8,7 @@ use crate::tensor::Tensor;
use burn_tensor::module::max_pool1d;

/// Configuration to create a [1D max pooling](MaxPool1d) layer.
#[derive(Config)]
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
/// The size of the kernel.
pub kernel_size: usize,
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -102,7 +102,7 @@ represent the corresponding Burn Op.
| [MatMul][94] | ✅ | ✅ |
| [MatMulInteger][95] | ❌ | ❌ |
| [Max][96] | ✅ | ✅ |
| [MaxPool1d][97] | ❌ | ✅ |
| [MaxPool1d][97] | ✅ | ✅ |
| [MaxPool2d][98] | ✅ | ✅ |
| [MaxRoiPool][99] | ❌ | ❌ |
| [MaxUnpool][100] | ❌ | ❌ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -32,6 +32,7 @@ fn main() {
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")
.input("tests/matmul/matmul.onnx")
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
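For context, these .input(...) calls come from burn-import's ModelGen builder, which the onnx-tests build script uses to turn each listed ONNX file into generated Rust that include_models! later pulls in. A minimal sketch of such a build script, assuming the ModelGen API as documented for burn-import:

use burn_import::onnx::ModelGen;

fn main() {
    // Convert the test model to Burn source code under OUT_DIR;
    // the generated module is then included via include_models!.
    ModelGen::new()
        .input("tests/maxpool1d/maxpool1d.onnx")
        .out_dir("model/")
        .run_from_script();
}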
Binary file added: crates/burn-import/onnx-tests/tests/maxpool1d/maxpool1d.onnx (contents not shown).
49 changes: 49 additions & 0 deletions crates/burn-import/onnx-tests/tests/maxpool1d/maxpool1d.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: maxpool1d.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.maxpool = nn.MaxPool1d(5, stride=2, padding=2, dilation=1)

def forward(self, x):
x = self.maxpool(x)
return x


def main():
# Set seed for reproducibility
torch.manual_seed(42)

# Print options
torch.set_printoptions(precision=3)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "maxpool1d.onnx"
test_input = torch.randn(1, 5, 5, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
print("Test input data of ones: {}".format(test_input))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))


if __name__ == '__main__':
main()

26 changes: 26 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -43,6 +43,7 @@ include_models!(
log,
mask_where,
matmul,
maxpool1d,
maxpool2d,
mul,
neg,
@@ -444,6 +445,31 @@ mod tests {
assert_eq!(output1.to_data(), expected1);
assert_eq!(output2, expected2);
}
#[test]
fn maxpool1d() {
let device = Default::default();

let model: maxpool1d::Model<Backend> = maxpool1d::Model::new(&device);
let input = Tensor::<Backend, 3>::from_floats(
[[
[1.927, 1.487, 0.901, -2.106, 0.678],
[-1.235, -0.043, -1.605, -0.752, -0.687],
[-0.493, 0.241, -1.111, 0.092, -2.317],
[-0.217, -1.385, -0.396, 0.803, -0.622],
[-0.592, -0.063, -0.829, 0.331, -1.558],
]],
&device,
);
let output = model.forward(input);
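// With kernel_size=5, stride=2, padding=2, dilation=1, each length-5 channel
// pools to floor((5 + 2*2 - 1*(5 - 1) - 1) / 2) + 1 = 3 output values.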
let expected = Data::from([[
[1.927, 1.927, 0.901],
[-0.043, -0.043, -0.687],
[0.241, 0.241, 0.092],
[-0.217, 0.803, 0.803],
[-0.063, 0.331, 0.331],
]]);
assert_eq!(output.to_data(), expected);
}

#[test]
fn maxpool2d() {
13 changes: 7 additions & 6 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,13 +1,11 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::prelu::PReluNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, linear::LinearNode, matmul::MatmulNode,
max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -93,6 +91,7 @@ pub enum Node<PS: PrecisionSettings> {
LayerNorm(LayerNormNode<PS>),
Linear(LinearNode<PS>),
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Reshape(ReshapeNode),
Unary(UnaryNode),
@@ -120,6 +119,7 @@ macro_rules! match_all {
Node::LayerNorm(node) => $func(node),
Node::Linear(node) => $func(node),
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Unary(node) => $func(node),
@@ -157,6 +157,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::LayerNorm(_) => "layer_norm",
Node::Linear(_) => "linear",
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Reshape(_) => "reshape",
Node::Unary(unary) => unary.kind.as_str(),
158 changes: 158 additions & 0 deletions crates/burn-import/src/burn/node/max_pool1d.rs
@@ -0,0 +1,158 @@
use proc_macro2::TokenStream;
use quote::quote;

use burn::{nn::pool::MaxPool1dConfig, record::PrecisionSettings};

use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};

#[derive(Debug, Clone)]
pub struct MaxPool1dNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub config: MaxPool1dConfig,
}

impl MaxPool1dNode {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
config: MaxPool1dConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
MaxPool1d
},
),
input,
output,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool1dNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.stride.to_tokens();
let padding = self.config.padding.to_tokens();
let dilation = self.config.dilation.to_tokens();
let tokens = quote! {
let #name = MaxPool1dConfig::new(#kernel_size)
.with_stride(#strides)
.with_padding(#padding)
.with_dilation(#dilation)
.init();
};

Some(tokens)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}

fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PaddingConfig1d");
imports.register("burn::nn::pool::MaxPool1d");
imports.register("burn::nn::pool::MaxPool1dConfig");
}

fn into_node(self) -> Node<PS> {
Node::MaxPool1d(self)
}

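// Max pooling has no learnable parameters, so nothing is serialized for this field.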
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{
nn::{pool::MaxPool1dConfig, PaddingConfig1d},
record::FullPrecisionSettings,
};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(MaxPool1dNode::new(
"max_pool1d",
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
MaxPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid)
.with_dilation(1),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::PaddingConfig1d;
use burn::nn::pool::MaxPool1d;
use burn::nn::pool::MaxPool1dConfig;

#[derive(Module, Debug)]
pub struct Model <B: Backend> {
max_pool1d: MaxPool1d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let max_pool1d = MaxPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid)
.with_dilation(1)
.init();

Self {
max_pool1d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.max_pool1d.forward(input);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -16,6 +16,7 @@ pub(crate) mod layer_norm;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod reshape;
29 changes: 28 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
@@ -1,6 +1,6 @@
use burn::nn::{
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
pool::{AvgPool2dConfig, MaxPool2dConfig},
pool::{AvgPool2dConfig, MaxPool1dConfig, MaxPool2dConfig},
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
PaddingConfig2d,
};
@@ -96,6 +96,33 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
.with_padding(padding)
}

/// Create a MaxPool1dConfig from the attributes of the node
pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig {
let mut kernel_shape = Vec::new();
let mut stride = vec![1];
let mut pads = vec![0, 0];
let mut dilation = vec![1];

for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
"strides" => stride = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"dilations" => dilation = value.clone().into_i64s(),
_ => {}
}
}
assert_eq!(kernel_shape.len(), 1);
assert_eq!(dilation.len(), 1);
assert_eq!(stride.len(), 1);
let padding = padding_config_1d(&pads);

MaxPool1dConfig::new(kernel_shape[0] as usize)
.with_stride(stride[0] as usize)
.with_padding(padding)
.with_dilation(dilation[0] as usize)
}

/// Create a MaxPool2dConfig from the attributes of the node
pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
let mut kernel_shape = Vec::new();
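As a sanity check on max_pool1d_config, the following hand-written sketch shows the module it would build for a hypothetical node carrying kernel_shape=[5], strides=[2], pads=[2, 2], dilations=[1], the same values the Python export above produces; symmetric pads collapse to PaddingConfig1d::Explicit(2) via padding_config_1d:

use burn::nn::pool::{MaxPool1d, MaxPool1dConfig};
use burn::nn::PaddingConfig1d;

fn main() {
    // Equivalent of max_pool1d_config for the hypothetical attributes above.
    let pool: MaxPool1d = MaxPool1dConfig::new(5)
        .with_stride(2)
        .with_padding(PaddingConfig1d::Explicit(2))
        .with_dilation(1)
        .init();
    let _ = pool;
}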
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -30,6 +30,7 @@ use crate::{
linear::LinearNode,
mask_where::WhereNode,
matmul::MatmulNode,
max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode,
prelu::PReluNode,
reshape::ReshapeNode,
@@ -239,6 +240,7 @@ impl OnnxGraph {
NodeType::Cos => graph.register(Self::cos_conversion(node)),
NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
@@ -703,6 +705,14 @@ impl OnnxGraph {
let name = &node.name;
Conv2dNode::<PS>::new(name, input, output, weight, bias, config)
}
fn max_pool1d_conversion(node: Node) -> MaxPool1dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let config = max_pool1d_config(&node);

let name = &node.name;
MaxPool1dNode::new(name, input, output, config)
}

fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
