diff --git a/crates/burn-core/src/nn/conv/checks.rs b/crates/burn-core/src/nn/conv/checks.rs index 390eb1feda..cd346163ad 100644 --- a/crates/burn-core/src/nn/conv/checks.rs +++ b/crates/burn-core/src/nn/conv/checks.rs @@ -2,7 +2,7 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize let channels_in_div_by_group = channels_in % groups == 0; let channels_out_div_by_group = channels_out % groups == 0; - if !channels_in_div_by_group && !channels_out_div_by_group { + if !channels_in_div_by_group || !channels_out_div_by_group { panic!( "Both channels must be divisible by the number of groups. Got \ channels_in={channels_in}, channels_out={channels_out}, groups={groups}" diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index bf31fd9661..73be36d357 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -220,6 +220,14 @@ mod tests { assert_eq!(config.initializer, init); } + #[test] + #[should_panic = "Both channels must be divisible by the number of groups."] + fn channels_with_groups_is_invalid() { + let device = Default::default(); + let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4); + let _ = config.init::<TestBackend>(&device); + } + #[test] fn display() { let config = Conv2dConfig::new([5, 1], [5, 5]); diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 1c85ab6e6a..f69dfe4c26 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -16,7 +16,7 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig { let mut strides = vec![1]; let mut pads = vec![0, 0]; let mut dilations = vec![1]; - let mut group: i64 = 1; + let mut group: usize = 1; // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { @@ -28,28 +28,28 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig { // check if the bias is present let bias = curr.inputs.len() == 3; - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels_in = shape[1]; - let channels_out = shape[0]; - for (key, value) in curr.attrs.iter() { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), "pads" => pads = value.clone().into_i64s(), "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), + "group" => group = value.clone().into_i64() as usize, _ => {} } } + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels_in = shape[1] * group; + let channels_out = shape[0]; + let padding = padding_config_1d(&pads); Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) .with_stride(strides[0] as usize) .with_dilation(dilations[0] as usize) - .with_groups(group as usize) + .with_groups(group) .with_bias(bias) .with_padding(padding) } @@ -60,7 +60,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { let mut strides = vec![1, 1]; let mut pads = vec![0, 0, 0, 0]; let mut dilations = vec![1, 1]; - let mut group: i64 = 1; + let mut group: usize = 1; // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { @@ -71,21 +71,21 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { // check if the bias is present let bias = curr.inputs.len() == 3; - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels: [usize; 2] = [shape[1], shape[0]]; - for (key, value) in curr.attrs.iter() { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), "pads" => pads = value.clone().into_i64s(), "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), + "group" => group = value.clone().into_i64() as usize, _ => {} } } + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels: [usize; 2] = [shape[1] * group, shape[0]]; + let padding = padding_config_2d(&pads); Conv2dConfig::new( @@ -94,7 +94,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { ) .with_stride([strides[0] as usize, strides[1] as usize]) .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group as usize) + .with_groups(group) .with_bias(bias) .with_padding(padding) } @@ -105,7 +105,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { let mut strides = vec![1, 1, 1]; let mut pads = vec![0, 0, 0, 0, 0, 0]; let mut dilations = vec![1, 1, 1]; - let mut group: i64 = 1; + let mut group: usize = 1; // extract the channels from the weight tensor's shape [out_channels, in_channels, ...] 
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty { @@ -116,21 +116,21 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { // check if the bias is present let bias = curr.inputs.len() == 3; - // the channels are inverted in the weight tensor - let shape = weight.shape.clone().unwrap(); - let channels: [usize; 2] = [shape[1], shape[0]]; - for (key, value) in curr.attrs.iter() { match key.as_str() { "kernel_shape" => kernel_shape = value.clone().into_i64s(), "strides" => strides = value.clone().into_i64s(), "pads" => pads = value.clone().into_i64s(), "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64(), + "group" => group = value.clone().into_i64() as usize, _ => {} } } + // the channels are inverted in the weight tensor + let shape = weight.shape.clone().unwrap(); + let channels: [usize; 2] = [shape[1] * group, shape[0]]; + let padding = padding_config_3d(&pads); Conv3dConfig::new( @@ -151,7 +151,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { dilations[1] as usize, dilations[2] as usize, ]) - .with_groups(group as usize) + .with_groups(group) .with_bias(bias) .with_padding(padding) } @@ -228,7 +228,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let group = attrs .remove("group") .map(AttributeValue::into_i64) - .unwrap_or(1); + .unwrap_or(1) as usize; // Trick with remove + empty check is simplest way to not forget some attribute for runtime: if !attrs.is_empty() { @@ -247,7 +247,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { // the channels are inverted in the weight tensor let shape = weight.shape.clone().unwrap(); - let channels: [usize; 2] = [shape[1], shape[0]]; + let channels: [usize; 2] = [shape[1] * group, shape[0]]; ConvTranspose2dConfig::new( channels, @@ -256,7 +256,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { .with_stride([stride[0] as usize, stride[1] as usize]) .with_padding([pads[0] as 
usize, pads[1] as usize]) .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group as usize) + .with_groups(group) .with_bias(bias) } pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { @@ -280,7 +280,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { let group = attrs .remove("group") .map(AttributeValue::into_i64) - .unwrap_or(1); + .unwrap_or(1) as usize; // Trick with remove + empty check is simplest way to not forget some attribute for runtime: if !attrs.is_empty() { @@ -299,7 +299,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { // the channels are inverted in the weight tensor let shape = weight.shape.clone().unwrap(); - let channels: [usize; 2] = [shape[1], shape[0]]; + let channels: [usize; 2] = [shape[1] * group, shape[0]]; ConvTranspose3dConfig::new( channels, @@ -316,7 +316,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { dilations[1] as usize, dilations[2] as usize, ]) - .with_groups(group as usize) + .with_groups(group) .with_bias(bias) }