Implement ONNX Pad Operator #2007

Merged (4 commits) on Jul 23, 2024
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -43,6 +43,7 @@ fn main() {
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/pad/pad.onnx")
.input("tests/expand/expand.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
21 changes: 21 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -55,6 +55,7 @@ include_models!(
    mul,
    neg,
    not,
    pad,
    greater,
    greater_or_equal,
    less,
@@ -1406,6 +1407,26 @@ mod tests {
        output.assert_eq(&expected, true);
    }

    #[test]
    fn pad() {
        let device = Default::default();
        let model: pad::Model<Backend> = pad::Model::new(&device);

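        // ONNX pads [1, 2, 3, 4] pad dim 0 by (1, 3) and dim 1 by (2, 4),
        // so the 3x2 input grows to the 7x8 tensor below.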
        let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device);
        let output = model.forward(input).to_data();
        let expected = TensorData::from([
            [0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
            [0.0_f32, 0., 1., 2., 0., 0., 0., 0.],
            [0.0_f32, 0., 3., 4., 0., 0., 0., 0.],
            [0.0_f32, 0., 5., 6., 0., 0., 0., 0.],
            [0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
            [0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
            [0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
        ]);

        output.assert_eq(&expected, true);
    }

    #[test]
    fn greater() {
        let device = Default::default();
Binary file added crates/burn-import/onnx-tests/tests/pad/pad.onnx
158 changes: 158 additions & 0 deletions crates/burn-import/onnx-tests/tests/pad/pad.py
@@ -0,0 +1,158 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/pad/pad.onnx

### Helper Functions ###
from pathlib import Path
from typing import Any
import numpy
import onnx
from onnx import ModelProto, TensorProto, ValueInfoProto
from onnx.reference import ReferenceEvaluator
from onnx.checker import check_model
from onnx.helper import (
    make_model,
    make_node,
    make_graph,
)


def build_test_save(
    name: str,
    inputs: list[ValueInfoProto],
    outputs: list[ValueInfoProto],
    initializers: list[TensorProto] = [],
    attributes: dict[str, Any] = {},
) -> None:
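    """Build a single-node ONNX graph for `name`, validate it, run the
    reference-evaluator tests, and save it next to this script."""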
    node_inputs = [input.name for input in inputs + initializers]
    node_outputs = [output.name for output in outputs]

    node = make_node(
        name.capitalize(),
        inputs=node_inputs,
        outputs=node_outputs,
        **attributes,
    )

    graph = make_graph(
        nodes=[node],
        name=f"{name.capitalize()}Graph",
        inputs=inputs,
        outputs=outputs,
        initializer=initializers,
    )

    onnx_model = make_model(graph)
    check_model(onnx_model)

    run_tests(onnx_model)

    onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx"))


class TestCase:
    def __init__(
        self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray
    ):
        self.name = name
        self.feeds = feeds
        self.expected = expected

    def test_model(self, model: ModelProto):
        sess = ReferenceEvaluator(model)

        result = numpy.array(sess.run(None, self.feeds))

        if not numpy.array_equal(result, self.expected):
            print(
                f"""{self.name}
Expected result: {self.expected}
Got: {result}"""
            )
            raise Exception("Test failed")


def test_positive_pads(model: ModelProto) -> None:
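    # ONNX pads are ordered [dim0_begin, dim1_begin, dim0_end, dim1_end],
    # so [1, 2, 3, 4] grows the 3x2 input to (1+3+3) x (2+2+4) = 7x8.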
    input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2)
    pads = numpy.array([1, 2, 3, 4], dtype="int")
    constant_value = 0.0
    feeds = {
        "input_tensor": input_tensor,
        "pads": pads,
        "constant_value": constant_value,
    }
    expected = numpy.array(
        [
            [
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ]
        ]
    )

    TestCase("test_positive_constant_pads", feeds, expected).test_model(model)


def test_1d_input(model: ModelProto) -> None:
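    # For a 1-D input, pads are simply [begin, end].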
    input_tensor = numpy.arange(1, 5, dtype="float32")
    pads = numpy.array([1, 2], dtype="int")
    constant_value = 0.0
    feeds = {
        "input_tensor": input_tensor,
        "pads": pads,
        "constant_value": constant_value,
    }
    expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]])

    TestCase("test_1d_input", feeds, expected).test_model(model)


def run_tests(model: ModelProto) -> None:
    test_positive_pads(model)
    test_1d_input(model)
    # TODO: test_negative_pads
    # TODO: support other modes: reflect, edge, wrap


### Helper Functions End ###

import numpy
from onnx import TensorProto, numpy_helper
from onnx.helper import make_tensor_value_info


def get_initializers() -> list[TensorProto]:
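    # The ONNX Pad spec requires `pads` as int64 and `constant_value`
    # in the same dtype as the padded input (float32 here).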
    pads = numpy_helper.from_array(
        numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads"
    )
    constant_value = numpy_helper.from_array(
        numpy.array([0.0]).astype(numpy.float32), name="constant_value"
    )

    return [pads, constant_value]


def main() -> None:
    name = "pad"

    inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])]
    outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])]
    initializers = get_initializers()

    build_test_save(
        name=name,
        inputs=inputs,
        outputs=outputs,
        initializers=initializers,
        attributes={"mode": "constant"},
    )


if __name__ == "__main__":
    main()
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -8,7 +8,7 @@ use super::{
    conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
    gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
    random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
    reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
    unary::UnaryNode, unsqueeze::UnsqueezeNode,
@@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
    Matmul(MatmulNode),
    MaxPool1d(MaxPool1dNode),
    MaxPool2d(MaxPool2dNode),
    Pad(PadNode),
    Range(RangeNode),
    Reshape(ReshapeNode),
    Resize(ResizeNode),
@@ -150,6 +151,7 @@ macro_rules! match_all {
            Node::Matmul(node) => $func(node),
            Node::MaxPool1d(node) => $func(node),
            Node::MaxPool2d(node) => $func(node),
            Node::Pad(node) => $func(node),
            Node::Range(node) => $func(node),
            Node::Reshape(node) => $func(node),
            Node::Resize(node) => $func(node),
@@ -203,6 +205,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Matmul(_) => "matmul",
            Node::MaxPool1d(_) => "max_pool1d",
            Node::MaxPool2d(_) => "max_pool2d",
            Node::Pad(_) => "pad",
            Node::Range(_) => "range",
            Node::Reshape(_) => "reshape",
            Node::Resize(_) => "resize",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;
104 changes: 104 additions & 0 deletions crates/burn-import/src/burn/node/pad.rs
@@ -0,0 +1,104 @@
use std::str::FromStr;

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

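/// Configuration for the Pad codegen: per-edge pad sizes and the constant fill value.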
#[derive(Config, Debug)]
pub struct PadConfig {
    pub pads: Vec<usize>,
    pub constant_value: f32,
}

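/// Codegen node for the ONNX `Pad` operator (constant mode only).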
#[derive(Debug, Clone, new)]
pub struct PadNode {
    pub input: TensorType,
    pub output: TensorType,
    pub config: PadConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for PadNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }
    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;

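        // Render the constant as e.g. `-1_f32.elem()` so the generated code
        // converts the f32 literal into the backend's element type.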
        let pads = self.config.pads.iter().map(|p| p.to_tokens());
        let constant_value_string = format!("{}_f32.elem()", self.config.constant_value);
        let constant_value = TokenStream::from_str(&constant_value_string).unwrap();

        quote! {
            let #output = #input.pad((#(#pads),*), #constant_value);
        }
    }
    fn into_node(self) -> Node<PS> {
        Node::Pad(self)
    }

    fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
        imports.register("burn::tensor::ElementConversion");
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{pad::PadNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_pad() {
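        // The config's pads (1, 2, 3, 4) and constant -1.0 should surface
        // verbatim in the generated `forward`.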
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = PadConfig::new(vec![1, 2, 3, 4], -1.0);
        graph.register(PadNode::new(
            TensorType::new_float("input", 2),
            TensorType::new_float("output", 2),
            config,
        ));
        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::tensor::ElementConversion;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
                    let output = input.pad((1, 2, 3, 4), -1_f32.elem());
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}