Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/onnx argmax #1814

Merged
merged 8 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ represent the corresponding Burn Op.
| [Acosh][3] | ❌ | ❌ |
| [Add][4] | ✅ | ✅ |
| [And][5] | ❌ | ❌ |
| [ArgMax][6] | | ✅ |
| [ArgMax][6] | | ✅ |
| [ArgMin][7] | ❌ | ❌ |
| [Asin][8] | ❌ | ❌ |
| [Asinh][9] | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn main() {
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/argmax/argmax.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
Expand Down
Binary file not shown.
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/argmax/argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/argmax/argmax.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self, argmax_dim: int = 0):
super(Model, self).__init__()
self._argmax_dim = argmax_dim

def forward(self, x):
# Note: only keepdim=True is supported in burn
y = torch.argmax(input=x, dim=self._argmax_dim, keepdim=True)
return y

def main():

# Export to onnx
model = Model(1)
model.eval()
device = torch.device("cpu")
onnx_name = "argmax.onnx"
dummy_input = torch.randn((3, 4), device=device)
torch.onnx.export(model, dummy_input, onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.randn((2, 3), device=device)
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)

print("Test output data shape: {}".format(output.shape))



if __name__ == '__main__':
main()
15 changes: 15 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ macro_rules! include_models {
include_models!(
add_int,
add,
argmax,
avg_pool2d,
avg_pool1d,
batch_norm,
Expand Down Expand Up @@ -368,6 +369,20 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn argmax() {
// Initialize the model with weights (loaded from the exported file)
let model: argmax::Model<Backend> = argmax::Model::default();

let device = Default::default();
// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let output = model.forward(input);
let expected = Data::from([[2], [2]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn globalavrpool_1d_2d() {
// The model contains 1d and 2d global average pooling nodes
Expand Down
102 changes: 102 additions & 0 deletions crates/burn-import/src/burn/node/argmax.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorKind, TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct ArgMaxNode {
pub input: TensorType,
pub output: TensorType,
pub axis: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMaxNode {
fn output_types(&self) -> Vec<Type> {
let mut output = self.output.clone();
output.kind = TensorKind::Int;
vec![Type::Tensor(output)]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
//NOTE: select_last_index and keep_dims are not supported
let axis = self.axis.to_tokens();

let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

quote! {
let #output = #input.argmax(#axis);
}
}

fn into_node(self) -> super::Node<PS> {
Node::ArgMax(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};

#[test]
fn test_codegen_argmax() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ArgMaxNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
1,
));

graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>
) -> Tensor<B, 2, Int> {
let tensor2 = tensor1.argmax(1);

tensor2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
20 changes: 12 additions & 8 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::{
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_uniform::RandomUniformNode, reshape::ReshapeNode, squeeze::SqueezeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -76,6 +77,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {

#[derive(Debug, Clone)]
pub enum Node<PS: PrecisionSettings> {
ArgMax(ArgMaxNode),
AvgPool1d(AvgPool1dNode),
AvgPool2d(AvgPool2dNode),
BatchNorm(BatchNormNode<PS>),
Expand Down Expand Up @@ -108,6 +110,7 @@ macro_rules! match_all {
($self:expr, $func:expr) => {{
#[allow(clippy::redundant_closure_call)]
match $self {
Node::ArgMax(node) => $func(node),
Node::AvgPool1d(node) => $func(node),
Node::AvgPool2d(node) => $func(node),
Node::BatchNorm(node) => $func(node),
Expand Down Expand Up @@ -150,6 +153,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
impl<PS: PrecisionSettings> Node<PS> {
pub fn name(&self) -> &str {
match self {
Node::ArgMax(_) => "argmax",
Node::AvgPool1d(_) => "avg_pool1d",
Node::AvgPool2d(_) => "avg_pool2d",
Node::BatchNorm(_) => "batch_norm",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod base;

pub(crate) mod argmax;
pub(crate) mod avg_pool1d;
pub(crate) mod avg_pool2d;
pub(crate) mod batch_norm;
Expand Down
20 changes: 20 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::ArgMax => argmax_update_outputs(node),
NodeType::AveragePool1d => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
Expand Down Expand Up @@ -362,6 +363,25 @@ fn reduce_mean_update_outputs(node: &mut Node) {
}
}

fn argmax_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Mean: multiple inputs are not supported");
}

let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Note: argmax in burn does not support keepdims=false
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: tensor.dim,
shape: tensor.shape.clone(),
elem_type: ElementType::Int64,
});
}

/// Update the output tensor dimension
fn squeeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down
48 changes: 48 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,54 @@ pub fn softmax_config(node: &Node) -> usize {
axis as usize
}

/// Create argmax config from the attributes of the node
pub fn argmax_config(node: &Node) -> usize {
let mut axis: i64 = 0;
let mut keepdims: i64 = 1;

// check if the node has only one input
if node.inputs.len() != 1 {
panic!(
"Argmax: multiple inputs are not supported (got {:?})",
node.inputs.len()
);
}

// extract the shape of the input tensor
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axis" => axis = value.clone().into_i64(),
"select_last_index" => log::warn!(
"select_last_index param for argmax is ignored in burn (got {:?})",
value
),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can still capture the select_last_index value here, but only warn if it is 1 (because the default implementation pretty much everywhere including Burn is to return the first max value, not last).

"keepdims" => keepdims = value.clone().into_i64(),
_ => {}
}
}

// Not all params for argmax are supported in burn.
if keepdims != 1 {
panic!(
"Only keepdims=1 is supported for argmax in burn (got {:?})",
keepdims
);
}

// if axis is negative, it is counted from the end
if axis < 0 {
axis += tensor.dim as i64;
}

axis as usize
}

/// Create concat config from the attributes of the node
pub fn concat_config(node: &Node) -> usize {
// the axis is the last dimension (Default: 1 per ONNX spec)
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
burn::{
graph::BurnGraph,
node::{
argmax::ArgMaxNode,
avg_pool1d::AvgPool1dNode,
avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode,
Expand Down Expand Up @@ -235,6 +236,7 @@ impl OnnxGraph {
for node in self.nodes {
match node.node_type {
NodeType::Add => graph.register(Self::add_conversion(node)),
NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
NodeType::Sub => graph.register(Self::sub_conversion(node)),
NodeType::Mul => graph.register(Self::mul_conversion(node)),
NodeType::Div => graph.register(Self::div_conversion(node)),
Expand Down Expand Up @@ -681,6 +683,14 @@ impl OnnxGraph {
UnaryNode::tanh(input, output)
}

fn argmax_conversion(node: Node) -> ArgMaxNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let axis = argmax_config(&node);

ArgMaxNode::new(input, output, axis)
}

fn concat_conversion(node: Node) -> ConcatNode {
let inputs = node
.inputs
Expand Down
Loading