Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/ops step1 : Introduce OpsRegistry to simplify operators development and enable incremental refactoring (#2638) #2644

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use crate::ops::{registry, ComputeNode};
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::{HashMap, HashSet};

Expand Down Expand Up @@ -317,6 +318,8 @@ fn simple_eval_(
)
}
}

let registry = registry()?;
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Expand Down Expand Up @@ -1950,7 +1953,12 @@ fn simple_eval_(
let output = input.sign()?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
op_type => {
let onnx_op = registry.get(op_type)?;
let cn = ComputeNode::new(&node, values);
let (name, value) = onnx_op.eval(&cn)?;
values.insert(name, value);
}
}
}
graph
Expand Down
2 changes: 2 additions & 0 deletions candle-onnx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub mod onnx {
}

pub mod eval;
mod ops;

pub use eval::{dtype, simple_eval};

pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
Expand Down
30 changes: 30 additions & 0 deletions candle-onnx/src/ops/compute_node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::onnx::NodeProto;
use candle::Tensor;
use std::collections::HashMap;

//This struct is used to represent a node in the computation graph
//The idea is not to use the NodeProto directly in the computation graph
//On a longer term, this can lead to a more optimized representation of the computation graph.
//For now, it is just a wrapper around the NodeProto and the context
pub struct ComputeNode<'a> {
node_proto: &'a NodeProto,
context: &'a HashMap<String, Tensor>,
}

impl<'a> ComputeNode<'a> {
pub fn new(node_proto: &'a NodeProto, context: &'a HashMap<String, Tensor>) -> Self {
ComputeNode {
node_proto,
context,
}
}

pub fn get_input(&self, index: usize) -> Option<&Tensor> {
let input_name = self.node_proto.input.get(index)?;
self.context.get(input_name)
}

pub fn get_output(&self, index: usize) -> Option<&String> {
self.node_proto.output.get(index)
}
}
1 change: 1 addition & 0 deletions candle-onnx/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod sign;
22 changes: 22 additions & 0 deletions candle-onnx/src/ops/math/sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::ops::compute_node::ComputeNode;
use crate::ops::OnnxOpError::ComputationFailed;
use crate::ops::{OnnxOp, OnnxOpError, OpOutput};

pub(crate) struct Sign;
impl OnnxOp for Sign {
fn eval(&self, node: &ComputeNode) -> Result<OpOutput, OnnxOpError> {
let input = node
.get_input(0)
.ok_or_else(|| ComputationFailed("input 0 not found".to_string()))?;

let output = input
.sign()
.map_err(|err| ComputationFailed(format!("{:?}", err)))?;

let output_name = node
.get_output(0)
.ok_or_else(|| ComputationFailed("output 0 not found".to_string()))?;

Ok((output_name.clone(), output))
}
}
14 changes: 14 additions & 0 deletions candle-onnx/src/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
pub mod onnxop;
pub use onnxop::{OnnxOp, OnnxOpError, OnnxOpRegistry, OpOutput};

pub mod compute_node;
pub use compute_node::ComputeNode;

mod math;
use math::sign;

pub fn registry() -> Result<OnnxOpRegistry, OnnxOpError> {
let mut registry = OnnxOpRegistry::new();
registry.insert("Sign", Box::new(sign::Sign))?;
Ok(registry)
}
126 changes: 126 additions & 0 deletions candle-onnx/src/ops/onnxop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use crate::ops::ComputeNode;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};

pub type OpOutput = (String, candle::Tensor);

#[derive(Debug, PartialEq, Eq)]
pub enum OnnxOpError {
InvalidInput(String),
ComputationFailed(String),
UnsupportedOp(String),
DuplicateOp(String),
}

impl From<OnnxOpError> for candle::Error {
fn from(e: OnnxOpError) -> Self {
candle::Error::Msg(format!("{:?}", e))
}
}

impl Display for OnnxOpError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
OnnxOpError::InvalidInput(s) => write!(f, "Invalid input: {}", s),
OnnxOpError::ComputationFailed(s) => write!(f, "Computation failed: {}", s),
OnnxOpError::UnsupportedOp(s) => write!(f, "Unsupported op: {}", s),
OnnxOpError::DuplicateOp(s) => write!(f, "Duplicate op: {}", s),
}
}
}

pub trait OnnxOp {
fn eval(&self, node: &ComputeNode) -> Result<OpOutput, OnnxOpError>;
}

#[derive(Default)]
pub struct OnnxOpRegistry {
ops: HashMap<String, Box<dyn OnnxOp>>,
}

impl OnnxOpRegistry {
pub fn new() -> Self {
Self {
ops: HashMap::new(),
}
}
pub fn insert(&mut self, name: &str, op: Box<dyn OnnxOp>) -> Result<(), OnnxOpError> {
match self.ops.entry(name.to_string()) {
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(op);
Ok(())
}
Entry::Occupied(_) => Err(OnnxOpError::DuplicateOp(name.to_string())),
}
}

pub fn get(&self, name: &str) -> Result<&dyn OnnxOp, OnnxOpError> {
match self.ops.get(name) {
Some(op) => Ok(op.as_ref()),
None => Err(OnnxOpError::UnsupportedOp(name.to_string())),
}
}
}

#[cfg(test)]
mod onnxop_registry_tests {
use super::*;
use candle::Device;
#[test]
fn nominal_case() {
//Given
let dummy_op = Box::new(DummyOp);
let mut registry = OnnxOpRegistry::new();

//When
registry.insert("DummyOp", dummy_op).unwrap();
let op = registry.get("DummyOp");

//Then
assert!(op.is_ok());
}

#[test]
fn unsupported_op() {
//Given
let registry = OnnxOpRegistry::new();

//When
let op = registry.get("Foo");

//Then
match op {
Err(OnnxOpError::UnsupportedOp(_)) => {}
_ => panic!("Expected unsupported op error"),
}
}

#[test]
fn duplicate_op() {
//Given
let dummy_op = Box::new(DummyOp);
let mut registry = OnnxOpRegistry::new();
registry.insert("DummyOp", dummy_op).unwrap();

//When
let dummy_op = Box::new(DummyOp);
let result = registry.insert("DummyOp", dummy_op);

//Then
match result {
Err(OnnxOpError::DuplicateOp(_)) => {}
_ => panic!("Expected duplicate op error"),
}
}

struct DummyOp;
impl OnnxOp for DummyOp {
fn eval(&self, _node: &ComputeNode) -> Result<OpOutput, OnnxOpError> {
Ok((
"dummy".to_string(),
candle::Tensor::new(vec![1u8, 1], &Device::Cpu).unwrap(),
))
}
}
}
Loading