diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 497400369f..f10ee38619 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -43,6 +43,7 @@ fn main() { .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") + .input("tests/pad/pad.onnx") .input("tests/expand/expand.onnx") .input("tests/greater/greater.onnx") .input("tests/greater_or_equal/greater_or_equal.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index d5cc470f06..5cafd3e4cb 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -55,6 +55,7 @@ include_models!( mul, neg, not, + pad, greater, greater_or_equal, less, @@ -1406,6 +1407,26 @@ mod tests { output.assert_eq(&expected, true); } + #[test] + fn pad() { + let device = Default::default(); + let model: pad::Model = pad::Model::new(&device); + + let input = Tensor::::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device); + let output = model.forward(input).to_data(); + let expected = TensorData::from([ + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 1., 2., 0., 0., 0., 0.], + [0.0_f32, 0., 3., 4., 0., 0., 0., 0.], + [0.0_f32, 0., 5., 6., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + [0.0_f32, 0., 0., 0., 0., 0., 0., 0.], + ]); + + output.assert_eq(&expected, true); + } + #[test] fn greater() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/pad/pad.onnx b/crates/burn-import/onnx-tests/tests/pad/pad.onnx new file mode 100644 index 0000000000..4c8c265c42 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/pad/pad.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/pad/pad.py b/crates/burn-import/onnx-tests/tests/pad/pad.py new file mode 100755 index 0000000000..0600b89ce7 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/pad/pad.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/pad/pad.onnx + +### Helper Functions ### +from pathlib import Path +from typing import Any +import numpy +from numpy.core.multiarray import dtype +import onnx +from onnx import ModelProto, TensorProto, ValueInfoProto +from onnx.reference import ReferenceEvaluator +from onnx.checker import check_model +from onnx.helper import ( + make_model, + make_node, + make_graph, +) + + +def build_test_save( + name: str, + inputs: list[ValueInfoProto], + outputs: list[ValueInfoProto], + initializers: list[TensorProto] = [], + attributes: dict[str, Any] = {}, +) -> None: + node_inputs = [input.name for input in inputs + initializers] + node_outputs = [output.name for output in outputs] + + node = make_node( + name.capitalize(), + inputs=node_inputs, + outputs=node_outputs, + **attributes, + ) + + graph = make_graph( + nodes=[node], + name=f"{name.capitalize()}Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + + onnx_model = make_model(graph) + check_model(onnx_model) + + run_tests(onnx_model) + + onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx")) + + +class TestCase: + def __init__( + self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray + ): + self.name = name + self.feeds = feeds + self.expected = expected + + def test_model(self, model: ModelProto): + sess = ReferenceEvaluator(model) + + result = numpy.array(sess.run(None, self.feeds)) + + if not numpy.array_equal(result, self.expected): + print( + f"""{self.name} +Expected result: {self.expected} +Got: {result}""" + ) + raise Exception("Test failed") + + +def test_positive_pads(model: ModelProto) -> None: + input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2) + pads = numpy.array([1, 2, 3, 4], dtype="int") + constant_value = 0.0 + feeds = { + "input_tensor": input_tensor, + "pads": pads, + "constant_value": constant_value, + } + expected = numpy.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + + TestCase("test_positive_constant_pads", feeds, expected).test_model(model) + + +def test_1d_input(model: ModelProto) -> None: + input_tensor = numpy.arange(1, 5, dtype="float32") + pads = numpy.array([1, 2], dtype="int") + constant_value = 0.0 + feeds = { + "input_tensor": input_tensor, + "pads": pads, + "constant_value": constant_value, + } + expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]]) + + TestCase("test_1d_input", feeds, expected).test_model(model) + + +def run_tests(model: ModelProto) -> None: + test_positive_pads(model) + test_1d_input(model) + # TODO: test_negative_pads + # TODO: support other modes: reflect, edge, wrap + + +### Helper Functions End ### + +import numpy +from onnx import TensorProto, numpy_helper +from onnx.helper import make_tensor_value_info + + +def get_initializers() -> list[TensorProto]: + pads = numpy_helper.from_array( + numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads" + ) + constant_value = numpy_helper.from_array( + numpy.array([0.0]).astype(numpy.float32), name="constant_value" + ) + + return [pads, constant_value] + + +def main() -> None: + name = "pad" + + inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])] + outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])] + initializers = get_initializers() + + build_test_save( + name=name, + inputs=inputs, + outputs=outputs, + initializers=initializers, + attributes={"mode": "constant"}, + ) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index ffb4d28d47..751cbcb471 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -8,7 +8,7 @@ use super::{ conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, + max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, @@ -105,6 +105,7 @@ pub enum Node { Matmul(MatmulNode), MaxPool1d(MaxPool1dNode), MaxPool2d(MaxPool2dNode), + Pad(PadNode), Range(RangeNode), Reshape(ReshapeNode), Resize(ResizeNode), @@ -150,6 +151,7 @@ macro_rules! match_all { Node::Matmul(node) => $func(node), Node::MaxPool1d(node) => $func(node), Node::MaxPool2d(node) => $func(node), + Node::Pad(node) => $func(node), Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), Node::Resize(node) => $func(node), @@ -203,6 +205,7 @@ impl Node { Node::Matmul(_) => "matmul", Node::MaxPool1d(_) => "max_pool1d", Node::MaxPool2d(_) => "max_pool2d", + Node::Pad(_) => "pad", Node::Range(_) => "range", Node::Reshape(_) => "reshape", Node::Resize(_) => "resize", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 73c6a0201a..9d1fdce591 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -25,6 +25,7 @@ pub(crate) mod mask_where; pub(crate) mod matmul; pub(crate) mod max_pool1d; pub(crate) mod max_pool2d; +pub(crate) mod pad; pub(crate) mod prelu; pub(crate) mod random_normal; pub(crate) mod random_uniform; diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs new file mode 100644 index 0000000000..eabe77d7f1 --- /dev/null +++ b/crates/burn-import/src/burn/node/pad.rs @@ -0,0 +1,104 @@ +use std::str::FromStr; + +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Config, Debug)] +pub struct PadConfig { + pub pads: Vec, + pub constant_value: f32, +} + +#[derive(Debug, Clone, new)] +pub struct PadNode { + pub input: TensorType, + pub output: TensorType, + pub config: PadConfig, +} + +impl NodeCodegen for PadNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + let pads = self.config.pads.iter().map(|p| p.to_tokens()); + let constant_value_string = format!("{}_f32.elem()", self.config.constant_value); + let constant_value = TokenStream::from_str(&constant_value_string).unwrap(); + + quote! { + let #output = #input.pad((#(#pads),*), #constant_value); + } + } + fn into_node(self) -> Node { + Node::Pad(self) + } + + fn register_imports(&self, imports: &mut crate::burn::BurnImports) { + imports.register("burn::tensor::ElementConversion"); + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{pad::PadNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_pad() { + let mut graph = BurnGraph::::default(); + let config = PadConfig::new(vec![1, 2, 3, 4], -1.0); + graph.register(PadNode::new( + TensorType::new_float("input", 2), + TensorType::new_float("output", 2), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::tensor::ElementConversion; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.pad((1, 2, 3, 4), -1_f32.elem()); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 1c85ab6e6a..def018f4ae 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::resize::ResizeMode; +use crate::burn::node::{pad::PadConfig, resize::ResizeMode}; use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -745,6 +745,72 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { ) } +/// Create a PadConfig from the attributes of the node +pub fn pad_config(node: &Node) -> PadConfig { + fn get_pads(node: &Node) -> Vec { + if node.inputs.len() < 2 { + panic!("Pad: must provide at least two inputs") + } + + let input_dim = match &node.inputs.first().unwrap().ty { + ArgType::Tensor(tensor) => tensor.dim, + _ => panic!("Pad: Only tensor input is valid"), + }; + + let pads: Vec = match &node.inputs[1].value { + Some(Data::Int64s(shape)) => shape + .iter() + .map(|&x| { + if x < 0 { + // TODO: support negative pads + panic!("Pad: Negative pad is not supported"); + } + x as usize + }) + .collect(), + _ => panic!("Pad: pads data type must be int64"), + }; + + if pads.len() != input_dim * 2 { + panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); + } + // TODO: Burn's pad should support 1D tensor + if input_dim < 2 { + panic!("Pad: input tensor should be rank 2 or higher"); + } + + let left_index = input_dim - 1; + let top_index = input_dim - 2; + let right_index = pads.len() - 1; + let bottom_index = pads.len() - 2; + let index_list = [left_index, top_index, right_index, bottom_index]; + + for (index, &item) in pads.iter().enumerate() { + if !index_list.contains(&index) && item != 0 { + panic!("Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions"); + } + } + + let left = pads[left_index]; + let top = pads[top_index]; + let right = pads[right_index]; + let bottom = pads[bottom_index]; + vec![left, right, top, bottom] + } + fn get_constant_value(node: &Node) -> f32 { + // TODO: support int, boolean + match &node.inputs[2].value { + Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(), + _ => 0.0, + } + } + + let pads = get_pads(node); + let constant_value = get_constant_value(node); + + PadConfig::new(pads, constant_value) +} + /// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. /// /// # Arguments diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 3b32ba9c4f..292d515e81 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -40,6 +40,7 @@ use crate::{ matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, + pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, @@ -63,7 +64,7 @@ use super::op_configuration::{ concat_config, conv1d_config, conv2d_config, conv3d_config, conv_transpose2d_config, conv_transpose3d_config, dropout_config, expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, - max_pool2d_config, reduce_max_config, reduce_mean_config, reduce_min_config, + max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config, }; @@ -324,6 +325,7 @@ impl ParsedOnnxGraph { NodeType::ConvTranspose3d => { graph.register(Self::conv_transpose3d_conversion::(node)) } + NodeType::Pad => graph.register(Self::pad_conversion(node)), NodeType::Pow => graph.register(Self::pow_conversion(node)), NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)), NodeType::Where => graph.register(Self::where_conversion(node)), @@ -1098,6 +1100,14 @@ impl ParsedOnnxGraph { BinaryNode::lower_equal(lhs, rhs, output) } + fn pad_conversion(node: Node) -> PadNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let config = pad_config(&node); + + PadNode::new(input, output, config) + } + fn pow_conversion(node: Node) -> BinaryNode { let lhs = Type::from(node.inputs.first().unwrap()); let rhs = Type::from(node.inputs.get(1).unwrap()); diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index ed05b16422..0f9432cc0a 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -736,7 +736,7 @@ where ) } - /// Pad the tensor with the given value on the last two dimensions. + /// Pad the tensor of rank two or higher with the given value on the last two dimensions. /// /// # Arguments /// diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ae11ff1ac3..0fb45fd3c4 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -48,6 +48,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Mul => same_as_input(node), NodeType::Neg => same_as_input(node), NodeType::Not => same_as_input(node), + NodeType::Pad => same_as_input(node), NodeType::Greater => greater_update_outputs(node), NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node), NodeType::Less => less_update_outputs(node),