Skip to content

Commit

Permalink
Convert compatible prelu weights to rank 1 (#2054)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jul 23, 2024
1 parent 4c73532 commit 53c77ae
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,9 +974,19 @@ impl ParsedOnnxGraph {
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let mut weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let config = PReluConfig::new();
let name = &node.name;

if weight.shape.len() > 1 {
if weight.shape[1..].iter().product::<usize>() == 1 {
// Burn accepts rank 1 alpha weight
weight.shape = weight.shape[..1].to_vec();
} else {
panic!("Invalid PRelu weight with shape {:?}", weight.shape);
}
}

PReluNode::new(name, input, output, weight, config)
}
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {
Expand Down

0 comments on commit 53c77ae

Please sign in to comment.