Feature/onnx argmax #1814

Merged · 8 commits · May 31, 2024 · Changes from 1 commit
Binary file modified (not shown): crates/burn-import/onnx-tests/tests/argmax/argmax.onnx
6 changes: 3 additions & 3 deletions crates/burn-import/onnx-tests/tests/argmax/argmax.py

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# used to generate model: onnx-tests/tests/concat/concat.onnx
+# used to generate model: onnx-tests/tests/argmax/argmax.onnx

 import torch
 import torch.nn as nn
@@ -11,8 +11,8 @@ def __init__(self, argmax_dim: int = 0):
         self._argmax_dim = argmax_dim

     def forward(self, x):
-        # Concatenate along the channel dimension
-        y = torch.argmax(input=x, dim=self._argmax_dim)
+        # Note: only keepdim=True is supported in burn
+        y = torch.argmax(input=x, dim=self._argmax_dim, keepdim=True)
         return y

 def main():
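The keepdim change matters because Burn's argmax always keeps the reduced dimension with size 1. As a rough illustration (not part of this PR), a minimal sketch assuming burn's ndarray backend feature and the 0.13-era tensor API:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // A 2x3 input reduced along dim 1: the rank is kept, so the output
    // shape is [2, 1], matching ONNX argmax with keepdims=1.
    let x = Tensor::<NdArray, 2>::from_floats([[1.0, 3.0, 2.0], [9.0, 4.0, 5.0]], &device);
    let y = x.argmax(1); // integer tensor of indices
    println!("{:?}", y.dims()); // [2, 1]
}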
12 changes: 3 additions & 9 deletions crates/burn-import/src/burn/node/argmax.rs

@@ -9,8 +9,6 @@ pub struct ArgMaxNode {
     pub input: TensorType,
     pub output: TensorType,
     pub axis: usize,
-    pub select_last_index: usize,
-    pub keepdims: usize,
 }

 impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMaxNode {
@@ -29,11 +27,9 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMaxNode {
         scope: &mut crate::burn::Scope,
         node_position: usize,
     ) -> proc_macro2::TokenStream {
+        //NOTE: select_last_index and keep_dims are not supported
         let axis = self.axis.to_tokens();

-        //NOTE: are select_last_index and keep_dims supported?
-        let _select_last_index = self.select_last_index.to_tokens();
-        let _keepdims = self.keepdims.to_tokens();
         let input = scope.tensor_use_owned(&self.input, node_position);
         let output = &self.output.name;

@@ -56,15 +52,13 @@ mod tests {
     use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};

     #[test]
-    fn test_codegen_gather() {
+    fn test_codegen_argmax() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();

         graph.register(ArgMaxNode::new(
             TensorType::new_float("tensor1", 2),
-            TensorType::new_float("tensor2", 2),
+            TensorType::new_int("tensor2", 2),
             1,
-            0,
-            0,
         ));

         graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
2 changes: 1 addition & 1 deletion crates/burn-import/src/burn/node/base.rs

@@ -153,7 +153,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
 impl<PS: PrecisionSettings> Node<PS> {
     pub fn name(&self) -> &str {
         match self {
-            Node::ArgMax(_) => "argmax1",
+            Node::ArgMax(_) => "argmax",
             Node::AvgPool1d(_) => "avg_pool1d",
             Node::AvgPool2d(_) => "avg_pool2d",
             Node::BatchNorm(_) => "batch_norm",
56 changes: 1 addition & 55 deletions crates/burn-import/src/onnx/dim_inference.rs

@@ -14,7 +14,6 @@ use super::{
 pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
     match node.node_type {
         NodeType::Add => same_as_input(node),
-        // NodeType::ArgMax => same_as_input(node),
         NodeType::ArgMax => argmax_update_outputs(node),
         NodeType::AveragePool1d => same_as_input(node),
         NodeType::AveragePool2d => same_as_input(node),
@@ -375,65 +374,12 @@ fn argmax_update_outputs(node: &mut Node) {
         _ => panic!("Only tensor input is valid"),
     };

-    // let mut axis = match node.attrs.get("axes") {
-    //     Some(value) => match &value {
-    //         AttributeValue::Int64(v) => *v,
-    //         _ => 0,
-    //     },
-    //     None => 0,
-    // };
-
+    // Note: argmax in burn does not support keepdims=false
     node.outputs[0].ty = ArgType::Tensor(TensorType {
         dim: tensor.dim,
         shape: tensor.shape.clone(),
         elem_type: ElementType::Int64,
     });
-
-    // // burn will always assume we keep dims
-    // let mut keepdims = match node.attrs.get("keepdims") {
-    //     Some(value) => match value {
-    //         AttributeValue::Int64(v) => *v == 1,
-    //         _ => panic!("Only int64 keepdims is valid"),
-    //     },
-    //     None => false,
-    // };
-
-    // if keepdims {
-    //     node.outputs[0].ty = ArgType::Tensor(TensorType {
-    //         dim: tensor.dim.clone(),
-    //         shape: tensor.shape.clone(),
-    //         elem_type: ElementType::Int64,
-    //     });
-    // }else {
-
-    // if axis < 0 {
-    //     axis = tensor.dim as i64 + axis;
-    // }
-
-    // if (axis < 0) | (axis >= tensor.dim as i64) {
-    //     panic!("axis {:?} is outside of legal range [0,{:?}]", axis, tensor.dim);
-    // }
-
-    // // Note -> seems like keepdim is always on??
-
-    // let output_shape: Option<Vec<usize>>;
-    // match tensor.shape {
-    //     Some(shape) => {
-    //         let mut s = shape.clone();
-    //         s.remove(axis as usize);
-    //         output_shape = Some(s);
-    //     }
-    //     None => {
-    //         output_shape = None;
-    //     }
-    // }
-
-    // node.outputs[0].ty = ArgType::Tensor(TensorType {
-    //     dim: tensor.dim.clone() - 1,
-    //     shape: output_shape.clone(),
-    //     elem_type: ElementType::Int64,
-    // });
-    // }
 }

 /// Update the output tensor dimension
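With the keepdims=false branch deleted, the inference rule kept here is simple: the output preserves the input rank and shape, and only the element type changes to Int64. A self-contained sketch of that rule, using hypothetical stand-ins for burn-import's actual types:

#[derive(Debug, Clone, PartialEq)]
enum Elem { Float32, Int64 }

#[derive(Debug, Clone)]
struct Ty { rank: usize, shape: Option<Vec<usize>>, elem: Elem }

// Mirrors argmax_update_outputs: same rank and shape, Int64 element type.
fn argmax_output(input: &Ty) -> Ty {
    Ty { rank: input.rank, shape: input.shape.clone(), elem: Elem::Int64 }
}

fn main() {
    let input = Ty { rank: 2, shape: Some(vec![2, 3]), elem: Elem::Float32 };
    let out = argmax_output(&input);
    assert_eq!(out.rank, 2);
    assert_eq!(out.elem, Elem::Int64);
}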
20 changes: 15 additions & 5 deletions crates/burn-import/src/onnx/op_configuration.rs

@@ -469,10 +469,9 @@ pub fn softmax_config(node: &Node) -> usize {
 }

 /// Create argmax config from the attributes of the node
-pub fn argmax_config(node: &Node) -> (usize, usize, usize) {
+pub fn argmax_config(node: &Node) -> usize {
     let mut axis: i64 = 0;
-    let mut select_last_index: i64 = 0;
-    let mut keepdims: i64 = 0;
+    let mut keepdims: i64 = 1;

     // check if the node has only one input
     if node.inputs.len() != 1 {
@@ -492,18 +491,29 @@ pub fn argmax_config(node: &Node) -> (usize, usize, usize) {
     for (key, value) in node.attrs.iter() {
         match key.as_str() {
             "axis" => axis = value.clone().into_i64(),
-            "select_last_index" => select_last_index = value.clone().into_i64(),
+            "select_last_index" => log::warn!(
+                "select_last_index param for argmax is ignored in burn (got {:?})",
+                value
+            ),

[Review comment, Member] I think we can still capture the select_last_index value here, but only warn if it is 1 (because the default implementation pretty much everywhere including Burn is to return the first max value, not last).

             "keepdims" => keepdims = value.clone().into_i64(),
             _ => {}
         }
     }

+    // Not all params for argmax are supported in burn.
+    if keepdims != 1 {
+        panic!(
+            "Only keepdims=1 is supported for argmax in burn (got {:?})",
+            keepdims
+        );
+    }
+
     // if axis is negative, it is counted from the end
     if axis < 0 {
         axis += tensor.dim as i64;
     }

-    (axis as usize, select_last_index as usize, keepdims as usize)
+    axis as usize
 }
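The reviewer's suggestion boils down to: still capture select_last_index, but warn only when it is 1, since returning the first max index is already the default behavior. A hypothetical stand-alone sketch combining that with the keepdims check and negative-axis handling from this commit (names are illustrative; eprintln! stands in for log::warn!):

fn argmax_axis(axis: i64, select_last_index: i64, keepdims: i64, rank: usize) -> usize {
    // Per the review suggestion: only warn when select_last_index would
    // actually change results (burn returns the first max index).
    if select_last_index == 1 {
        eprintln!("warning: select_last_index=1 is not supported for argmax in burn");
    }
    // Mirrors the panic added in this commit: burn only supports keepdims=1.
    if keepdims != 1 {
        panic!("Only keepdims=1 is supported for argmax in burn (got {keepdims})");
    }
    // Negative axes count from the end, as in the ONNX spec.
    let axis = if axis < 0 { axis + rank as i64 } else { axis };
    axis as usize
}

fn main() {
    assert_eq!(argmax_axis(-1, 0, 1, 3), 2); // axis -1 on rank 3 normalizes to 2
    assert_eq!(argmax_axis(0, 1, 1, 2), 0);  // warns, still returns axis 0
}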
4 changes: 2 additions & 2 deletions crates/burn-import/src/onnx/to_burn.rs

@@ -686,9 +686,9 @@ impl OnnxGraph {
     fn argmax_conversion(node: Node) -> ArgMaxNode {
         let input = node.inputs.first().unwrap().to_tensor_type();
         let output = node.outputs.first().unwrap().to_tensor_type();
-        let (axis, select_last_index, keepdims) = argmax_config(&node);
+        let axis = argmax_config(&node);

-        ArgMaxNode::new(input, output, axis, select_last_index, keepdims)
+        ArgMaxNode::new(input, output, axis)
     }

     fn concat_conversion(node: Node) -> ConcatNode {