diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index a8b38348da..d5e5402408 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -194,30 +194,27 @@ fn concat_update_outputs(node: &mut Node) { node.outputs[0].ty = ArgType::Tensor(tensor.clone()); } - fn reshape_update_outputs(node: &mut Node) { - assert_eq!(node.inputs.len(), 2); - - let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value { - shape - } else { - panic!("Reshape: int64s shape is expected per ONNX spec"); + let shape = match node.inputs.get(1) { + Some(input) => match &input.value { + Some(Data::Int64s(shape)) => Some(shape.clone()), + _ => panic!("Reshape: invalid input types"), + }, + None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()), }; - // The output dimension is the same as the shape length - let dim = shape.len(); - let elem_type = match node.inputs[0].ty.clone() { - ArgType::Tensor(tensor) => tensor.elem_type, - _ => panic!("Reshape: invalid input type"), + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Reshape: invalid output types"), }; - let shape = shape.iter().map(|&dim| dim as usize).collect(); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - dim, - shape: Some(shape), - }); + if let Some(shape) = shape { + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: shape.len(), + shape: None, // shape is calculated at runtime + ..output + }); + } } fn reduce_mean_update_outputs(node: &mut Node) { @@ -254,40 +251,24 @@ fn reduce_mean_update_outputs(node: &mut Node) { /// Update the output tensor dimension based on the "axes" attribute or the second input fn unsqueeze_update_output(node: &mut Node) { - let axes = if node.inputs.len() == 2 { - // get the values while making sure the types are correct - match &node.inputs[1].value { - Some(value) => match value { - Data::Int64s(axes) => Some(axes.clone()), - _ => panic!("Unsqueeze: invalid input types"), - }, - None => None, - } - } else { - node.attrs - .iter() - .find_map(|(key, value)| match key.as_str() { - "axes" => Some(value.clone().into_i64s()), - _ => None, - }) + let axes = match node.inputs.get(1) { + Some(input) => match &input.value { + Some(Data::Int64s(axes)) => Some(axes.clone()), + _ => panic!("Unsqueeze: invalid input types"), + }, + None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()), }; - // need output way up here to avoid borrowing issues let input = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Unsqueeze: invalid output types"), + ty => panic!("Unsqueeze: invalid output type ({ty:?})"), }; - let output = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Unsqueeze: invalid output types"), - }; - - if axes.is_some() { + if let Some(axes) = axes { node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: input.dim + axes.unwrap().len(), + dim: input.dim + axes.len(), shape: None, // shape is calculated at runtime - ..output + ..input }); } }