Feature/onnx argmax #1814

Merged · 8 commits · May 31, 2024
Changes from 5 commits
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -11,7 +11,7 @@ represent the corresponding Burn Op.
| [Acosh][3] | ❌ | ❌ |
| [Add][4] | ✅ | ✅ |
| [And][5] | ❌ | ❌ |
| [ArgMax][6] | ❌ | ✅ |
| [ArgMax][6] | ✅ | ✅ |
| [ArgMin][7] | ❌ | ❌ |
| [Asin][8] | ❌ | ❌ |
| [Asinh][9] | ❌ | ❌ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -8,6 +8,7 @@ fn main() {
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/argmax/argmax.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
Binary file added: crates/burn-import/onnx-tests/tests/argmax/argmax.onnx (not shown)
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/argmax/argmax.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/argmax/argmax.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self, argmax_dim: int = 0):
super(Model, self).__init__()
self._argmax_dim = argmax_dim

def forward(self, x):
# Select the index of the maximum value along the configured dimension
y = torch.argmax(input=x, dim=self._argmax_dim)
return y

def main():

# Export to onnx
model = Model(1)
model.eval()
device = torch.device("cpu")
onnx_name = "argmax.onnx"
dummy_input = torch.randn((3, 4), device=device)
torch.onnx.export(model, dummy_input, onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.randn((2, 3), device=device)
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)

print("Test output data shape: {}".format(output.shape))



if __name__ == '__main__':
main()
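
A quick way to sanity-check the exported argmax.onnx is sketched below; it is not part of the PR's test suite and assumes onnxruntime and numpy are installed. The indices should match NumPy's argmax along dim 1 (the exact output shape depends on how the exporter encoded keepdims).

# Hypothetical verification sketch, not part of the PR.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("argmax.onnx")
input_name = session.get_inputs()[0].name

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
(indices,) = session.run(None, {input_name: x})

print(indices)                # indices of the max along dim 1
print(np.argmax(x, axis=1))   # [2 2]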
15 changes: 15 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -17,6 +17,7 @@ macro_rules! include_models {
include_models!(
add_int,
add,
argmax,
avg_pool2d,
avg_pool1d,
batch_norm,
@@ -366,6 +367,20 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn argmax() {
// Initialize the model with weights (loaded from the exported file)
let model: argmax::Model<Backend> = argmax::Model::default();

let device = Default::default();
// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let output = model.forward(input);
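// burn's argmax keeps the reduced dimension, so the output has shape [2, 1]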
let expected = Data::from([[2], [2]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn globalavrpool_1d_2d() {
// The model contains 1d and 2d global average pooling nodes
108 changes: 108 additions & 0 deletions crates/burn-import/src/burn/node/argmax.rs
@@ -0,0 +1,108 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorKind, TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct ArgMaxNode {
pub input: TensorType,
pub output: TensorType,
pub axis: usize,
pub select_last_index: usize,
pub keepdims: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMaxNode {
fn output_types(&self) -> Vec<Type> {
let mut output = self.output.clone();
output.kind = TensorKind::Int;
vec![Type::Tensor(output)]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
let axis = self.axis.to_tokens();

// NOTE: select_last_index and keepdims are not mapped to burn's argmax,
// which always keeps the reduced dimension.
let _select_last_index = self.select_last_index.to_tokens();
let _keepdims = self.keepdims.to_tokens();
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

quote! {
let #output = #input.argmax(#axis);
}
}

fn into_node(self) -> super::Node<PS> {
Node::ArgMax(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};

#[test]
fn test_codegen_argmax() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ArgMaxNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_float("tensor2", 2),
1,
0,
0,
));

graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>
) -> Tensor<B, 2, Int> {
let tensor2 = tensor1.argmax(1);

tensor2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
18 changes: 11 additions & 7 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,11 +1,12 @@
use super::{
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, squeeze::SqueezeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -75,6 +76,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {

#[derive(Debug, Clone)]
pub enum Node<PS: PrecisionSettings> {
ArgMax(ArgMaxNode),
AvgPool1d(AvgPool1dNode),
AvgPool2d(AvgPool2dNode),
BatchNorm(BatchNormNode<PS>),
@@ -105,6 +107,7 @@ macro_rules! match_all {
($self:expr, $func:expr) => {{
#[allow(clippy::redundant_closure_call)]
match $self {
Node::ArgMax(node) => $func(node),
Node::AvgPool1d(node) => $func(node),
Node::AvgPool2d(node) => $func(node),
Node::BatchNorm(node) => $func(node),
@@ -145,6 +148,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
impl<PS: PrecisionSettings> Node<PS> {
pub fn name(&self) -> &str {
match self {
Node::ArgMax(_) => "argmax",
Node::AvgPool1d(_) => "avg_pool1d",
Node::AvgPool2d(_) => "avg_pool2d",
Node::BatchNorm(_) => "batch_norm",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -1,5 +1,6 @@
mod base;

pub(crate) mod argmax;
pub(crate) mod avg_pool1d;
pub(crate) mod avg_pool2d;
pub(crate) mod batch_norm;
74 changes: 74 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -14,6 +14,8 @@ use super::{
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::ArgMax => argmax_update_outputs(node),
NodeType::AveragePool1d => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
@@ -321,6 +323,78 @@ fn reduce_mean_update_outputs(node: &mut Node) {
}
}

fn argmax_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Mean: multiple inputs are not supported");
}

let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Note: burn's argmax always keeps the reduced dimension, so the output
// tensor keeps the input rank and shape; only the element type changes
// to Int64. The axis, select_last_index, and keepdims attributes are
// read separately in argmax_config.
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: tensor.dim,
shape: tensor.shape.clone(),
elem_type: ElementType::Int64,
});
}

/// Update the output tensor dimension
fn squeeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down
38 changes: 38 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -468,6 +468,44 @@ pub fn softmax_config(node: &Node) -> usize {
axis as usize
}

/// Create argmax config from the attributes of the node
pub fn argmax_config(node: &Node) -> (usize, usize, usize) {
let mut axis: i64 = 0;
let mut select_last_index: i64 = 0;
let mut keepdims: i64 = 0;

// check if the node has only one input
if node.inputs.len() != 1 {
panic!(
"Argmax: multiple inputs are not supported (got {:?})",
node.inputs.len()
);
}

// extract the input tensor type; its rank is needed to resolve a negative axis
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axis" => axis = value.clone().into_i64(),
"select_last_index" => select_last_index = value.clone().into_i64(),
"keepdims" => keepdims = value.clone().into_i64(),
_ => {}
}
}

// if axis is negative, it is counted from the end (e.g. -1 on a rank-2 input becomes 1)
if axis < 0 {
axis += tensor.dim as i64;
}

(axis as usize, select_last_index as usize, keepdims as usize)
}

/// Create concat config from the attributes of the node
pub fn concat_config(node: &Node) -> usize {
// the axis is the last dimension (Default: 1 per ONNX spec)