Skip to content

Commit

Permalink
Fix checks_channels_div_groups condition and ONNX conv import with groups (#2051)
Browse files Browse the repository at this point in the history

* Fix checks_channels_div_groups condition

* Fix conv channels config w/ groups
  • Loading branch information
laggui authored Jul 22, 2024
1 parent 0bbc1ed commit 4c73532
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 29 deletions.
2 changes: 1 addition & 1 deletion crates/burn-core/src/nn/conv/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
let channels_in_div_by_group = channels_in % groups == 0;
let channels_out_div_by_group = channels_out % groups == 0;

if !channels_in_div_by_group && !channels_out_div_by_group {
if !channels_in_div_by_group || !channels_out_div_by_group {
panic!(
"Both channels must be divisible by the number of groups. Got \
channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
Expand Down
8 changes: 8 additions & 0 deletions crates/burn-core/src/nn/conv/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ mod tests {
assert_eq!(config.initializer, init);
}

#[test]
#[should_panic = "Both channels must be divisible by the number of groups."]
fn channels_with_groups_is_invalid() {
let device = Default::default();
let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);
let _ = config.init::<TestBackend>(&device);
}

#[test]
fn display() {
let config = Conv2dConfig::new([5, 1], [5, 5]);
Expand Down
56 changes: 28 additions & 28 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
let mut strides = vec![1];
let mut pads = vec![0, 0];
let mut dilations = vec![1];
let mut group: i64 = 1;
let mut group: usize = 1;

// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
Expand All @@ -28,28 +28,28 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
// check if the bias is present
let bias = curr.inputs.len() == 3;

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels_in = shape[1];
let channels_out = shape[0];

for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
"strides" => strides = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"dilations" => dilations = value.clone().into_i64s(),
"group" => group = value.clone().into_i64(),
"group" => group = value.clone().into_i64() as usize,
_ => {}
}
}

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels_in = shape[1] * group;
let channels_out = shape[0];

let padding = padding_config_1d(&pads);

Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize)
.with_stride(strides[0] as usize)
.with_dilation(dilations[0] as usize)
.with_groups(group as usize)
.with_groups(group)
.with_bias(bias)
.with_padding(padding)
}
Expand All @@ -60,7 +60,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
let mut strides = vec![1, 1];
let mut pads = vec![0, 0, 0, 0];
let mut dilations = vec![1, 1];
let mut group: i64 = 1;
let mut group: usize = 1;

// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
Expand All @@ -71,21 +71,21 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
// check if the bias is present
let bias = curr.inputs.len() == 3;

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels: [usize; 2] = [shape[1], shape[0]];

for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
"strides" => strides = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"dilations" => dilations = value.clone().into_i64s(),
"group" => group = value.clone().into_i64(),
"group" => group = value.clone().into_i64() as usize,
_ => {}
}
}

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels: [usize; 2] = [shape[1] * group, shape[0]];

let padding = padding_config_2d(&pads);

Conv2dConfig::new(
Expand All @@ -94,7 +94,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
)
.with_stride([strides[0] as usize, strides[1] as usize])
.with_dilation([dilations[0] as usize, dilations[1] as usize])
.with_groups(group as usize)
.with_groups(group)
.with_bias(bias)
.with_padding(padding)
}
Expand All @@ -105,7 +105,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
let mut strides = vec![1, 1, 1];
let mut pads = vec![0, 0, 0, 0, 0, 0];
let mut dilations = vec![1, 1, 1];
let mut group: i64 = 1;
let mut group: usize = 1;

// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
Expand All @@ -116,21 +116,21 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
// check if the bias is present
let bias = curr.inputs.len() == 3;

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels: [usize; 2] = [shape[1], shape[0]];

for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
"strides" => strides = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"dilations" => dilations = value.clone().into_i64s(),
"group" => group = value.clone().into_i64(),
"group" => group = value.clone().into_i64() as usize,
_ => {}
}
}

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels: [usize; 2] = [shape[1] * group, shape[0]];

let padding = padding_config_3d(&pads);

Conv3dConfig::new(
Expand All @@ -151,7 +151,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
dilations[1] as usize,
dilations[2] as usize,
])
.with_groups(group as usize)
.with_groups(group)
.with_bias(bias)
.with_padding(padding)
}
Expand Down Expand Up @@ -228,7 +228,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
let group = attrs
.remove("group")
.map(AttributeValue::into_i64)
.unwrap_or(1);
.unwrap_or(1) as usize;

// Trick with remove + empty check is simplest way to not forget some attribute for runtime:
if !attrs.is_empty() {
Expand All @@ -247,7 +247,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels: [usize; 2] = [shape[1], shape[0]];
let channels: [usize; 2] = [shape[1] * group, shape[0]];

ConvTranspose2dConfig::new(
channels,
Expand All @@ -256,7 +256,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
.with_stride([stride[0] as usize, stride[1] as usize])
.with_padding([pads[0] as usize, pads[1] as usize])
.with_dilation([dilations[0] as usize, dilations[1] as usize])
.with_groups(group as usize)
.with_groups(group)
.with_bias(bias)
}
pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
Expand All @@ -280,7 +280,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
let group = attrs
.remove("group")
.map(AttributeValue::into_i64)
.unwrap_or(1);
.unwrap_or(1) as usize;

// Trick with remove + empty check is simplest way to not forget some attribute for runtime:
if !attrs.is_empty() {
Expand All @@ -299,7 +299,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {

// the channels are inverted in the weight tensor
let shape = weight.shape.clone().unwrap();
let channels: [usize; 2] = [shape[1], shape[0]];
let channels: [usize; 2] = [shape[1] * group, shape[0]];

ConvTranspose3dConfig::new(
channels,
Expand All @@ -316,7 +316,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
dilations[1] as usize,
dilations[2] as usize,
])
.with_groups(group as usize)
.with_groups(group)
.with_bias(bias)
}

Expand Down

0 comments on commit 4c73532

Please sign in to comment.