Skip to content

Commit

Permalink
Fix reshape bug (support for opset version 1) (#1667)
Browse files Browse the repository at this point in the history
* Make reshape op version 1

* Refactor per PR feedback
  • Loading branch information
antimora authored and syl20bnr committed Apr 26, 2024
1 parent 8070252 commit 0f49205
Showing 1 changed file with 26 additions and 45 deletions.
71 changes: 26 additions & 45 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,30 +194,27 @@ fn concat_update_outputs(node: &mut Node) {

node.outputs[0].ty = ArgType::Tensor(tensor.clone());
}

fn reshape_update_outputs(node: &mut Node) {
assert_eq!(node.inputs.len(), 2);

let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value {
shape
} else {
panic!("Reshape: int64s shape is expected per ONNX spec");
let shape = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(shape)) => Some(shape.clone()),
_ => panic!("Reshape: invalid input types"),
},
None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()),
};

// The output dimension is the same as the shape length
let dim = shape.len();
let elem_type = match node.inputs[0].ty.clone() {
ArgType::Tensor(tensor) => tensor.elem_type,
_ => panic!("Reshape: invalid input type"),
let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Reshape: invalid output types"),
};

let shape = shape.iter().map(|&dim| dim as usize).collect();

node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type,
dim,
shape: Some(shape),
});
if let Some(shape) = shape {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: shape.len(),
shape: None, // shape is calculated at runtime
..output
});
}
}

fn reduce_mean_update_outputs(node: &mut Node) {
Expand Down Expand Up @@ -254,40 +251,24 @@ fn reduce_mean_update_outputs(node: &mut Node) {

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
// get the values while making sure the types are correct
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(axes) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => None,
}
} else {
node.attrs
.iter()
.find_map(|(key, value)| match key.as_str() {
"axes" => Some(value.clone().into_i64s()),
_ => None,
})
let axes = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(axes)) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()),
};

// need output way up here to avoid borrowing issues
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
ty => panic!("Unsqueeze: invalid output type ({ty:?})"),
};

let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
};

if axes.is_some() {
if let Some(axes) = axes {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim + axes.unwrap().len(),
dim: input.dim + axes.len(),
shape: None, // shape is calculated at runtime
..output
..input
});
}
}
Expand Down

0 comments on commit 0f49205

Please sign in to comment.