Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx op topk #2305

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ represent the corresponding Burn Op.
| [TfIdfVectorizer][183] | ❌ | ❌ |
| [ThresholdedRelu][184] | ❌ | ❌ |
| [Tile][185] | ✅ | ✅ |
| [TopK][186] | | ✅ |
| [TopK][186] | | ✅ |
| [Transpose][187] | ✅ | ✅ |
| [Trilu][188] | ❌ | ✅ |
| [Unique][189] | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ fn main() {
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/top_k/top_k.onnx")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks the CI caught something I missed! This file doesn't exist anymore with your changes :)

.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
Expand Down
20 changes: 19 additions & 1 deletion crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ include_models!(
sum_int,
tanh,
tile,
top_k,
transpose,
unsqueeze,
unsqueeze_opset11,
Expand All @@ -128,7 +129,7 @@ mod tests {

use super::*;

use burn::tensor::{Bool, Int, Shape, Tensor, TensorData};
use burn::tensor::{cast::ToElement, Bool, Int, Shape, Tensor, TensorData};

use float_cmp::ApproxEq;

Expand Down Expand Up @@ -2125,4 +2126,21 @@ mod tests {
assert!(i_output.equal(i_expected).all().into_scalar());
assert!(b_output.equal(b_expected).all().into_scalar());
}

#[test]
fn top_k() {
// Initialize the model
let device = Default::default();
let model = top_k::Model::<Backend>::new(&device);

// Run the model
let input = Tensor::<Backend, 2>::from_floats(
[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
&device,
);
let (values_tensor, _indices_tensor) = model.forward(input);
// data from pyTorch
let expected = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]);
values_tensor.to_data().assert_eq(&expected, true);
oojo12 marked this conversation as resolved.
Show resolved Hide resolved
}
}
Binary file added crates/burn-import/onnx-tests/tests/top_k/top_k.onnx
Binary file not shown.
57 changes: 57 additions & 0 deletions crates/burn-import/onnx-tests/tests/top_k/top_k.py
oojo12 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np
import onnx
from onnx import helper, TensorProto

# Define the input tensor
X = np.array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=np.float32)

# Define the value of K
k = 3
K = np.array([k], dtype=np.int64)
axis = 1
new_dims = [X.shape[0], k]

input_tensors = [
helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape),
#helper.make_tensor_value_info('K', TensorProto.INT32, K.shape)
]

output_tensors = [
helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims),
helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims)
]

# Create the TopK node
node = helper.make_node(
'TopK',
inputs=['X'],# 'K'],
outputs=['Values', 'Indices'],
axis=axis, # Axis along which to find the top K elements
#largest=-1,
k=k
)

# Create the graph
graph = helper.make_graph(
nodes = [node],
name = 'TopKGraph',
inputs = input_tensors,
outputs = output_tensors
)

# Create the model
model = helper.make_model(
graph,
ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", 1)]
)

# Check the model
onnx.checker.check_model(model)

# Save the model to a file
onnx.save(model, 'top_k.onnx')

print("Model saved to topk_model.onnx")
6 changes: 5 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use super::{
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -114,6 +115,7 @@ pub enum Node<PS: PrecisionSettings> {
Squeeze(SqueezeNode),
Sum(SumNode),
Tile(TileNode),
TopK(TopKNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
Expand Down Expand Up @@ -162,6 +164,7 @@ macro_rules! match_all {
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Node::Tile(node) => $func(node),
Node::TopK(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
Expand Down Expand Up @@ -218,6 +221,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Node::Tile(_) => "tile",
Node::TopK(_) => "top_k",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
pub(crate) mod top_k;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;
Expand Down
112 changes: 112 additions & 0 deletions crates/burn-import/src/burn/node/top_k.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};

#[derive(Config, Debug)]
pub struct TopKConfig {
pub axis: i64,
pub k: i64,
}

#[derive(Debug, Clone, new)]
pub struct TopKNode {
pub input: TensorType,
pub outputs: Vec<TensorType>,
pub config: TopKConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TopKNode {
fn output_types(&self) -> Vec<Type> {
self.outputs
.iter()
.map(|t| Type::Tensor(t.clone()))
.collect()
}

fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let axis = self.config.axis.to_token_stream();
let k = self.config.k.to_token_stream();

let input = scope.tensor_use_owned(&self.input, node_position);
let values_output = &self.outputs[0].name;
let indices_output = &self.outputs[1].name;

quote! {
let (#values_output, #indices_output) = #input.topk_with_indices(#k as usize, #axis as usize);
}
}

fn into_node(self) -> Node<PS> {
Node::TopK(self)
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{test::assert_tokens, top_k::TopKNode},
TensorType,
};

#[test]
fn test_codegen_nodes() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = TopKConfig::new(-1, 3);

graph.register(TopKNode::new(
TensorType::new_float("input_tensor", 4),
vec![
TensorType::new_float("values_tensor", 4),
TensorType::new_int("indices_tensor", 4),
],
config,
));

graph.register_input_output(
vec!["input_tensor".to_string()],
vec!["values_tensor".to_string(), "indices_tensor".to_string()],
);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input_tensor: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3i64 as usize, -1i64 as usize);
oojo12 marked this conversation as resolved.
Show resolved Hide resolved
(values_tensor, indices_tensor)
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
14 changes: 13 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};

use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig};
use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig};
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};

/// Create a Conv1dConfig from the attributes of the node
Expand Down Expand Up @@ -795,6 +795,18 @@ pub fn tile_config(node: &Node) -> TileConfig {
TileConfig::new(repeat)
}

fn extract_attr_value_i64(node: &Node, key: &str) -> i64 {
let value = node.attrs.get(key).unwrap().clone().into_i64();
value
}

/// Create a TopKConfig from the attributes of the node.
pub fn top_k_config(node: &Node) -> TopKConfig {
let axis: i64 = extract_attr_value_i64(node, "axis");
let k: i64 = extract_attr_value_i64(node, "k");
TopKConfig::new(axis, k)
}
oojo12 marked this conversation as resolved.
Show resolved Hide resolved

/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads_input(node: &Node) -> Vec<i64> {
Expand Down
17 changes: 15 additions & 2 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use crate::{
squeeze::SqueezeNode,
sum::SumNode,
tile::TileNode,
top_k::TopKNode,
unary::UnaryNode,
unsqueeze::UnsqueezeNode,
},
Expand All @@ -67,8 +68,8 @@ use super::op_configuration::{
hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config,
max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config,
reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config,
shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config,
unsqueeze_config,
shape_config, slice_config, softmax_config, squeeze_config, tile_config, top_k_config,
transpose_config, unsqueeze_config,
};
use onnx_ir::{
convert_constant_value,
Expand Down Expand Up @@ -338,6 +339,7 @@ impl ParsedOnnxGraph {
NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)),
NodeType::Tile => graph.register(Self::tile_conversion(node)),
NodeType::TopK => graph.register(Self::top_k_conversion(node)),
NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
NodeType::ConstantOfShape => {
graph.register(Self::constant_of_shape_conversion(node))
Expand Down Expand Up @@ -1184,6 +1186,17 @@ impl ParsedOnnxGraph {

TileNode::new(input, output, config)
}

fn top_k_conversion(node: Node) -> TopKNode {
// Inputs
let input = TensorType::from(node.inputs.first().unwrap());

// Outputs
let outputs = node.outputs.iter().map(TensorType::from).collect();
let config = top_k_config(&node);

TopKNode::new(input, outputs, config)
}
}

/// Extract data from node states and convert it to `TensorData`.
Expand Down
1 change: 1 addition & 0 deletions crates/burn-jit/src/tests/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod tests {
use burn_tensor::{Distribution, Tensor};

#[test]
#[cfg(target_os = "macos")]
oojo12 marked this conversation as resolved.
Show resolved Hide resolved
fn tanh_should_not_have_numerical_bugs_on_macos() {
fn tanh_one_value(input: f32) -> f32 {
let tensor = Tensor::<TestBackend, 1>::ones([1], &Default::default()) * input;
Expand Down
30 changes: 30 additions & 0 deletions crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Sub => same_as_input_broadcast(node),
NodeType::Sum => same_as_input_broadcast(node),
NodeType::Tanh => same_as_input(node),
NodeType::TopK => top_k_update_output(node),
NodeType::Transpose => same_as_input(node),
NodeType::Unsqueeze => unsqueeze_update_output(node),
NodeType::Where => where_update_outputs(node),
Expand Down Expand Up @@ -477,6 +478,35 @@ fn same_as_input(node: &mut Node) {
node.outputs[0].ty = node.inputs[0].ty.clone();
}

fn top_k_update_output(node: &mut Node) {
let dim = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("TopK: invalid input type"),
};

let output_values_elem = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.elem_type.clone(),
_ => panic!("TopK: invalid output type"),
};

let output_indices_elem = match &node.outputs[1].ty {
ArgType::Tensor(_) => ElementType::Int64,
_ => panic!("TopK: invalid output type"),
};

node.outputs[0].ty = ArgType::Tensor(TensorType {
dim,
shape: None, // shape is tracked and calculated at runtime
elem_type: output_values_elem,
});

node.outputs[1].ty = ArgType::Tensor(TensorType {
dim,
shape: None, // shape is tracked and calculated at runtime
elem_type: output_indices_elem,
});
}

/// Temporary pass-through stub for dimension inference so that we can export the IR model.
fn temporary_pass_through_stub(node: &mut Node) {
log::warn!("Must implement dimension inference for {:?}", node);
Expand Down
Loading