Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun31415 committed May 3, 2024
1 parent 8d226c7 commit e360bdf
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 18 deletions.
1 change: 0 additions & 1 deletion crates/burn-import/src/burn/node/prelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PRelu");
imports.register("burn::nn::prelu::PRelu");
imports.register("burn::nn::prelu::PReluConfig");
}
Expand Down
16 changes: 0 additions & 16 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
.with_padding(padding)
.with_dilation([dilations[0] as usize, dilations[1] as usize])
}
pub fn prelu_config(curr: &Node) -> PReluConfig {
let mut alpha = 0.01;
let mut num_parameters = 0;
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"alpha" => alpha = value.clone().into_f32(),
"num_parameters" => num_parameters = value.clone().into_i32(),
_ => {}
}
}

PReluConfig::new()
.with_num_parameters(num_parameters as usize)
.with_alpha(alpha as f64)
}

pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
let mut attrs = curr.attrs.clone();
let kernel_shape = attrs
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
};

use burn::{
nn::PReluConfig,
record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings},
tensor::{DataSerialize, Element},
};
Expand Down Expand Up @@ -701,7 +702,7 @@ impl OnnxGraph {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let config = prelu_config(&node);
let config = PReluConfig::new();
let name = &node.name;
PReluNode::<PS>::new(name, input, output, weight, config)
}
Expand Down

0 comments on commit e360bdf

Please sign in to comment.