Skip to content

Commit

Permalink
Implement ONNX Pad Operator (#2007)
Browse files Browse the repository at this point in the history
* Implement ONNX pad

* ONNX pad arguments fix

pad now requires 2 or more arguments
if the third argument is not given, it will default to 0

* fixing bug in input len fix

* change panic comment

Change panic comment from needing two inputs. This comes from the fact that the ONNX spec requires two necessary inputs but could have more two more optional argument.

---------

Co-authored-by: JC <[email protected]>
Co-authored-by: mepatrick73 <[email protected]>
  • Loading branch information
3 people authored Jul 23, 2024
1 parent 53c77ae commit 4a3fc9d
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 4 deletions.
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ fn main() {
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/pad/pad.onnx")
.input("tests/expand/expand.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
Expand Down
21 changes: 21 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include_models!(
mul,
neg,
not,
pad,
greater,
greater_or_equal,
less,
Expand Down Expand Up @@ -1407,6 +1408,26 @@ mod tests {
output.assert_eq(&expected, true);
}

#[test]
fn pad() {
let device = Default::default();
let model: pad::Model<Backend> = pad::Model::new(&device);

let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device);
let output = model.forward(input).to_data();
let expected = TensorData::from([
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 1., 2., 0., 0., 0., 0.],
[0.0_f32, 0., 3., 4., 0., 0., 0., 0.],
[0.0_f32, 0., 5., 6., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
]);

output.assert_eq(&expected, true);
}

#[test]
fn greater() {
let device = Default::default();
Expand Down
Binary file added crates/burn-import/onnx-tests/tests/pad/pad.onnx
Binary file not shown.
158 changes: 158 additions & 0 deletions crates/burn-import/onnx-tests/tests/pad/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/pad/pad.onnx

### Helper Functions ###
from pathlib import Path
from typing import Any
import numpy
from numpy.core.multiarray import dtype
import onnx
from onnx import ModelProto, TensorProto, ValueInfoProto
from onnx.reference import ReferenceEvaluator
from onnx.checker import check_model
from onnx.helper import (
make_model,
make_node,
make_graph,
)


def build_test_save(
name: str,
inputs: list[ValueInfoProto],
outputs: list[ValueInfoProto],
initializers: list[TensorProto] = [],
attributes: dict[str, Any] = {},
) -> None:
node_inputs = [input.name for input in inputs + initializers]
node_outputs = [output.name for output in outputs]

node = make_node(
name.capitalize(),
inputs=node_inputs,
outputs=node_outputs,
**attributes,
)

graph = make_graph(
nodes=[node],
name=f"{name.capitalize()}Graph",
inputs=inputs,
outputs=outputs,
initializer=initializers,
)

onnx_model = make_model(graph)
check_model(onnx_model)

run_tests(onnx_model)

onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx"))


class TestCase:
def __init__(
self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray
):
self.name = name
self.feeds = feeds
self.expected = expected

def test_model(self, model: ModelProto):
sess = ReferenceEvaluator(model)

result = numpy.array(sess.run(None, self.feeds))

if not numpy.array_equal(result, self.expected):
print(
f"""{self.name}
Expected result: {self.expected}
Got: {result}"""
)
raise Exception("Test failed")


def test_positive_pads(model: ModelProto) -> None:
input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2)
pads = numpy.array([1, 2, 3, 4], dtype="int")
constant_value = 0.0
feeds = {
"input_tensor": input_tensor,
"pads": pads,
"constant_value": constant_value,
}
expected = numpy.array(
[
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]
]
)

TestCase("test_positive_constant_pads", feeds, expected).test_model(model)


def test_1d_input(model: ModelProto) -> None:
input_tensor = numpy.arange(1, 5, dtype="float32")
pads = numpy.array([1, 2], dtype="int")
constant_value = 0.0
feeds = {
"input_tensor": input_tensor,
"pads": pads,
"constant_value": constant_value,
}
expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]])

TestCase("test_1d_input", feeds, expected).test_model(model)


def run_tests(model: ModelProto) -> None:
test_positive_pads(model)
test_1d_input(model)
# TODO: test_negative_pads
# TODO: support other modes: reflect, edge, wrap


### Helper Functions End ###

import numpy
from onnx import TensorProto, numpy_helper
from onnx.helper import make_tensor_value_info


def get_initializers() -> list[TensorProto]:
pads = numpy_helper.from_array(
numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads"
)
constant_value = numpy_helper.from_array(
numpy.array([0.0]).astype(numpy.float32), name="constant_value"
)

return [pads, constant_value]


def main() -> None:
name = "pad"

inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])]
outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])]
initializers = get_initializers()

build_test_save(
name=name,
inputs=inputs,
outputs=outputs,
initializers=initializers,
attributes={"mode": "constant"},
)


if __name__ == "__main__":
main()
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
Expand Down Expand Up @@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Pad(PadNode),
Range(RangeNode),
Reshape(ReshapeNode),
Resize(ResizeNode),
Expand Down Expand Up @@ -150,6 +151,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Pad(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Resize(node) => $func(node),
Expand Down Expand Up @@ -203,6 +205,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Pad(_) => "pad",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Resize(_) => "resize",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;
Expand Down
104 changes: 104 additions & 0 deletions crates/burn-import/src/burn/node/pad.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use std::str::FromStr;

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Config, Debug)]
pub struct PadConfig {
pub pads: Vec<usize>,
pub constant_value: f32,
}

#[derive(Debug, Clone, new)]
pub struct PadNode {
pub input: TensorType,
pub output: TensorType,
pub config: PadConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for PadNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

let pads = self.config.pads.iter().map(|p| p.to_tokens());
let constant_value_string = format!("{}_f32.elem()", self.config.constant_value);
let constant_value = TokenStream::from_str(&constant_value_string).unwrap();

quote! {
let #output = #input.pad((#(#pads),*), #constant_value);
}
}
fn into_node(self) -> Node<PS> {
Node::Pad(self)
}

fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::ElementConversion");
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{pad::PadNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_pad() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = PadConfig::new(vec![1, 2, 3, 4], -1.0);
graph.register(PadNode::new(
TensorType::new_float("input", 2),
TensorType::new_float("output", 2),
config,
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::tensor::ElementConversion;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.pad((1, 2, 3, 4), -1_f32.elem());
output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
Loading

0 comments on commit 4a3fc9d

Please sign in to comment.