From 0bbc1ed30f017a0050089e260bb89fbafbe7f121 Mon Sep 17 00:00:00 2001 From: Mathias Insley <55096933+agelas@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:13:07 -0700 Subject: [PATCH] Bug/Remove Squeeze Panic for Multiple Dimensions (#2035) * Remove panic for squeeze when more than one axis is specified * Remove extra Model() * Change script to squeeze all singleton dimensions * Revert change since burn requires axes to be specified * Fix input tensor * Try updating ONNX files again * Add script for testing multiple axes along with new ONNX file * Update squeeze.py comments * Add squeeze_multiple model to tests * Fix dim_inference --- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/onnx_tests.rs | 12 +++++++++ .../onnx-tests/tests/squeeze/squeeze.py | 24 ++++++++++++++---- .../tests/squeeze/squeeze_multiple.onnx | Bin 0 -> 154 bytes crates/onnx-ir/src/dim_inference.rs | 9 +++---- 5 files changed, 35 insertions(+), 11 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 497400369f..d221d40a71 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -83,6 +83,7 @@ fn main() { .input("tests/mask_where/mask_where.onnx") .input("tests/squeeze/squeeze_opset16.onnx") .input("tests/squeeze/squeeze_opset13.onnx") + .input("tests/squeeze/squeeze_multiple.onnx") .input("tests/random_uniform/random_uniform.onnx") .input("tests/random_normal/random_normal.onnx") .input("tests/constant_of_shape/constant_of_shape.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index d5cc470f06..577f12ac85 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -93,6 +93,7 @@ include_models!( unsqueeze_opset11, squeeze_opset16, squeeze_opset13, + squeeze_multiple, random_uniform, random_normal, constant_of_shape, @@ -1642,6 +1643,17 @@ mod tests { assert_eq!(expected_shape, output.shape()); } + #[test] + fn squeeze_multiple() { + let device = Default::default(); + let model = squeeze_multiple::Model::::new(&device); + let input_shape = Shape::from([3, 4, 1, 5, 1]); + let expected_shape = Shape::from([3, 4, 5]); + let input = Tensor::ones(input_shape, &device); + let output = model.forward(input); + assert_eq!(expected_shape, output.shape()); + } + #[test] fn random_uniform() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py b/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py index 6281587801..eb2d3723b6 100644 --- a/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py +++ b/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py @@ -1,18 +1,20 @@ #!/usr/bin/env python3 -# used to generate model: squeeze.onnx +# used to generate models: squeeze_opset13.onnx, +# squeeze_opset16.onnx, and squeeze_multiple.onnx import torch +import onnx import torch.nn as nn - +from onnx import helper, TensorProto class Model(nn.Module): def __init__(self): super(Model, self).__init__() - self.axis = 2 + self.dims = 2 def forward(self, x): - x = torch.squeeze(x, self.axis) + x = torch.squeeze(x, self.dims) return x @@ -28,7 +30,6 @@ def main(): device = torch.device("cpu") test_input = torch.randn(3, 4, 1, 5, device=device) - model = Model() # Export to ONNX torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16) @@ -43,6 +44,19 @@ def main(): print(f"Test output data shape: {output.shape}") print(f"Test output: {output}") + # Test for squeezing multiple dimensions + test_input_ms = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 4, 1, 5, 1]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5]) + squeeze = helper.make_node(op_type="Squeeze", inputs=["input", "axes"], outputs=["output"], name="SqueezeOp") + axes = helper.make_tensor("axes", TensorProto.INT64, dims=[2], vals=[2, 4]) + graph = helper.make_graph([squeeze], "SqueezeMultiple", [test_input_ms], [output], [axes]) + opset = helper.make_opsetid("", 13) + m = helper.make_model(graph, opset_imports=[opset]) + + onnx.checker.check_model(m, full_check=True) + onnx.save(m, "squeeze_multiple.onnx") + + print(f"Finished exporting model with multiple squeeze axes specified to 13") if __name__ == "__main__": main() diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx b/crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx new file mode 100644 index 0000000000000000000000000000000000000000..46760e4469a6180e5b44b1612860614df5e39eca GIT binary patch literal 154 zcmdz%qu7@;bKXwNG%p(%P%bf@}xL}3rkZ|t5W?7l-Qw6A$};sw=}0D zvmhr`i;sgzfZd9TiNy)5IZ7F7mXH*e1P7y#2p1CvGZ3=?F(VMOg6Jd}E~p(sLRUI5tIABq3~ literal 0 HcmV?d00001 diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ae11ff1ac3..d032569ed2 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -459,11 +459,6 @@ fn squeeze_update_output(node: &mut Node) { if axes.is_none() { panic!("Squeeze must specify an axis"); - } else if axes.as_ref().unwrap().len() > 1 { - panic!( - "Squeeze must specify only 1 axis, found {:?}", - axes.as_ref().unwrap().len() - ); } let input_dim = match &node.inputs[0].ty { @@ -471,13 +466,15 @@ fn squeeze_update_output(node: &mut Node) { _ => panic!("Squeeze: invalid input type"), }; + let new_dim = input_dim - axes.unwrap().len(); + let output_elem = match &node.outputs[0].ty { ArgType::Tensor(tensor) => tensor.elem_type.clone(), _ => panic!("Squeeze: invalid output type"), }; node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: input_dim - 1, + dim: new_dim, shape: None, // shape is tracked and calculated at runtime elem_type: output_elem, });