Skip to content

Commit

Permalink
Bug/Remove Squeeze Panic for Multiple Dimensions (tracel-ai#2035)
Browse files Browse the repository at this point in the history
* Remove panic for squeeze when more than one axis is specified

* Remove extra Model()

* Change script to squeeze all singleton dimensions

* Revert change since burn requires axes to be specified

* Fix input tensor

* Try updating ONNX files again

* Add script for testing multiple axes along with new ONNX file

* Update squeeze.py comments

* Add squeeze_multiple model to tests

* Fix dim_inference
  • Loading branch information
agelas authored Jul 22, 2024
1 parent 19cd67a commit 0bbc1ed
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 11 deletions.
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ fn main() {
.input("tests/mask_where/mask_where.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/squeeze/squeeze_multiple.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/constant_of_shape/constant_of_shape.onnx")
Expand Down
12 changes: 12 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ include_models!(
unsqueeze_opset11,
squeeze_opset16,
squeeze_opset13,
squeeze_multiple,
random_uniform,
random_normal,
constant_of_shape,
Expand Down Expand Up @@ -1642,6 +1643,17 @@ mod tests {
assert_eq!(expected_shape, output.shape());
}

#[test]
fn squeeze_multiple() {
let device = Default::default();
let model = squeeze_multiple::Model::<Backend>::new(&device);
let input_shape = Shape::from([3, 4, 1, 5, 1]);
let expected_shape = Shape::from([3, 4, 5]);
let input = Tensor::ones(input_shape, &device);
let output = model.forward(input);
assert_eq!(expected_shape, output.shape());
}

#[test]
fn random_uniform() {
let device = Default::default();
Expand Down
24 changes: 19 additions & 5 deletions crates/burn-import/onnx-tests/tests/squeeze/squeeze.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#!/usr/bin/env python3

# used to generate model: squeeze.onnx
# used to generate models: squeeze_opset13.onnx,
# squeeze_opset16.onnx, and squeeze_multiple.onnx

import torch
import onnx
import torch.nn as nn

from onnx import helper, TensorProto

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.axis = 2
self.dims = 2

def forward(self, x):
x = torch.squeeze(x, self.axis)
x = torch.squeeze(x, self.dims)
return x


Expand All @@ -28,7 +30,6 @@ def main():
device = torch.device("cpu")

test_input = torch.randn(3, 4, 1, 5, device=device)
model = Model()

# Export to ONNX
torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16)
Expand All @@ -43,6 +44,19 @@ def main():
print(f"Test output data shape: {output.shape}")
print(f"Test output: {output}")

# Test for squeezing multiple dimensions
test_input_ms = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 4, 1, 5, 1])
output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5])
squeeze = helper.make_node(op_type="Squeeze", inputs=["input", "axes"], outputs=["output"], name="SqueezeOp")
axes = helper.make_tensor("axes", TensorProto.INT64, dims=[2], vals=[2, 4])
graph = helper.make_graph([squeeze], "SqueezeMultiple", [test_input_ms], [output], [axes])
opset = helper.make_opsetid("", 13)
m = helper.make_model(graph, opset_imports=[opset])

onnx.checker.check_model(m, full_check=True)
onnx.save(m, "squeeze_multiple.onnx")

print(f"Finished exporting model with multiple squeeze axes specified to 13")

if __name__ == "__main__":
main()
Binary file not shown.
9 changes: 3 additions & 6 deletions crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,25 +459,22 @@ fn squeeze_update_output(node: &mut Node) {

if axes.is_none() {
panic!("Squeeze must specify an axis");
} else if axes.as_ref().unwrap().len() > 1 {
panic!(
"Squeeze must specify only 1 axis, found {:?}",
axes.as_ref().unwrap().len()
);
}

let input_dim = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("Squeeze: invalid input type"),
};

let new_dim = input_dim - axes.unwrap().len();

let output_elem = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.elem_type.clone(),
_ => panic!("Squeeze: invalid output type"),
};

node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input_dim - 1,
dim: new_dim,
shape: None, // shape is tracked and calculated at runtime
elem_type: output_elem,
});
Expand Down

0 comments on commit 0bbc1ed

Please sign in to comment.