Skip to content

Commit

Permalink
Convolution support
Browse files Browse the repository at this point in the history
  • Loading branch information
aryavohra committed Oct 4, 2024
1 parent 8605004 commit 9c8ecbf
Show file tree
Hide file tree
Showing 10 changed files with 686 additions and 62 deletions.
526 changes: 489 additions & 37 deletions src/enzyme_ad/jax/Passes/EqualitySaturation.cpp

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions src/enzyme_ad/jax/Passes/EqualitySaturation.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace tensat {
enum class Type : uint8_t;
enum class Ops : uint8_t;
struct Vector;
struct Matrix;
struct Tensor;

/**
Expand All @@ -17,12 +18,14 @@ struct Tensor;

uint64_t get_cost(Ops op, rust::Vec<tensat::Tensor> operands,
rust::Vec<tensat::Vector> other_vector_args,
rust::Vec<int64_t> int_args);
rust::Vec<int64_t> int_args,
rust::Vec<tensat::Matrix> matrix_args);

mlir::Type newTensorType(mlir::OpBuilder &builder, Tensor tensor);
mlir::Type tensatTypeToMlirType(mlir::OpBuilder &builder, Type type);

rust::Vec<Tensor> get_shape(Ops op, rust::Vec<tensat::Tensor> operands,
rust::Vec<tensat::Vector> other_vector_args,
rust::Vec<int64_t> int_args);
rust::Vec<int64_t> int_args,
rust::Vec<tensat::Matrix> matrix_args);
} // namespace tensat
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/deps/tensat/Cargo.Bazel.lock
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ dependencies = [
[[package]]
name = "egg"
version = "0.6.1-dev"
source = "git+https://github.com/yycdavid/egg?rev=12cc1ee7731d37fe91901c81f59678fa1d08a2bb#12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
source = "git+https://github.com/aryavohra/egg?rev=b30d14cff61bff97336323f6eb0978cc7769140d#b30d14cff61bff97336323f6eb0978cc7769140d"
dependencies = [
"indexmap",
"instant",
Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/deps/tensat/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/enzyme_ad/jax/deps/tensat/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }

[dependencies.egg]
git = "https://github.com/yycdavid/egg"
rev = "12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
git = "https://github.com/aryavohra/egg"
rev = "b30d14cff61bff97336323f6eb0978cc7769140d"

[package.metadata.cxx]
library = "c++"
Expand Down
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/deps/tensat/converted.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@
(ConcatenateOp (Vec (MulOp ?x ?y) (MulOp ?z ?w)) ?i)<=>(MulOp (ConcatenateOp (Vec ?x ?z) ?i) (ConcatenateOp (Vec ?y ?w) ?i))

(ConcatenateOp (Vec (ConcatenateOp (Vec ?x ?y) 1) (ConcatenateOp (Vec ?z ?w) 1)) 0)<=>(ConcatenateOp (Vec (ConcatenateOp (Vec ?x ?z) 0) (ConcatenateOp (Vec ?y ?w) 0)) 1)

(ConvolutionOp (MulOp ?x ?w) ?y ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(ConvolutionOp ?x (MulOp ?y ?w) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)

(ConvolutionOp ?lhs (ConcatenateOp (Vec ?x ?y) ?i) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(ConcatenateOp (Vec (ConvolutionOp ?lhs ?x ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig) (ConvolutionOp ?lhs ?y ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)) ?i)

(ConvolutionOp ?lhs (MulOp ?rhs ?w) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(MulOp (ConvolutionOp ?lhs ?rhs ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig) ?w)
21 changes: 15 additions & 6 deletions src/enzyme_ad/jax/deps/tensat/src/ffi_utils.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
use crate::{
input::ffi,
model::*,
rewrites::{get_num_option, get_vec_of_nums_option, get_vec_option},
rewrites::{get_matrix_option, get_num_option, get_vec_of_nums_option, get_vec_option},
};
use egg::*;

fn process_enode_args(
egraph: &EGraph<Mdl, TensorAnalysis>,
enode: &Mdl,
) -> (Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>) {
) -> (
Vec<ffi::Tensor>,
Vec<ffi::Vector>,
Vec<i64>,
Vec<ffi::Matrix>,
) {
let mut args: Vec<ffi::Tensor> = vec![];
let mut other_vecs: Vec<ffi::Vector> = vec![];
let mut int_args: Vec<i64> = vec![];
let mut matrix_args: Vec<ffi::Matrix> = vec![];

for child in enode.children().iter() {
if let Some(other_vec) = get_vec_of_nums_option(egraph, &egraph[*child]) {
other_vecs.push(other_vec)
} else if let Some(mat) = get_matrix_option(egraph, &egraph[*child]) {
println!("{:?}", mat);
matrix_args.push(mat)
} else if let Some(vec) = get_vec_option(&egraph[*child]) {
vec.iter()
.for_each(|&id| args.push(egraph[id].data.tensors[0].clone()))
Expand All @@ -27,7 +36,7 @@ fn process_enode_args(
}
}

(args, other_vecs, int_args)
(args, other_vecs, int_args, matrix_args)
}

pub fn create_stablehlo_op<F, R>(
Expand All @@ -36,10 +45,10 @@ pub fn create_stablehlo_op<F, R>(
process_output: F,
) -> R
where
F: Fn(ffi::Ops, Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>) -> R,
F: Fn(ffi::Ops, Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>, Vec<ffi::Matrix>) -> R,
{
let op = ffi::Ops::from_mdl(enode);
let (args, other_vecs, int_args) = process_enode_args(egraph, enode);
let res = process_output(op, args, other_vecs, int_args);
let (args, other_vecs, int_args, matrix_args) = process_enode_args(egraph, enode);
let res = process_output(op, args, other_vecs, int_args, matrix_args);
res
}
124 changes: 116 additions & 8 deletions src/enzyme_ad/jax/deps/tensat/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub mod ffi {
SelectOp,
ConcatenateOp,
DotGeneralOp,
ConvolutionOp,
PadOp,
SliceOp,
TransposeOp,
Expand Down Expand Up @@ -83,6 +84,12 @@ pub mod ffi {
pub vec: Vec<i64>,
}

// Similarly, we're creating a Matrix type for vecs of vecs (padding)
#[derive(Debug)]
struct Matrix {
pub mat: Vec<Vector>,
}

// take floats from c++ and wrap them into f32s below
extern "Rust" {
type Mdl;
Expand Down Expand Up @@ -159,6 +166,29 @@ pub mod ffi {
dimension: i64,
output: Tensor,
) -> Box<TensorInfo>;
fn new_convolution_op(
self: &mut CppGraphConverter,
lhs: &TensorInfo,
rhs: &TensorInfo,
windowStrides: Vec<i64>,
padding: Vec<Vector>,
lhsDilation: Vec<i64>,
rhsDilation: Vec<i64>,
windowReversal: Vec<bool>,
inputBatchDimension: i64,
inputFeatureDimension: i64,
inputSpatialDimension: Vec<i64>,
kernelInputFeatureDimension: i64,
kernelOutputFeatureDimension: i64,
kernelSpatialDimension: Vec<i64>,
outputBatchDimension: i64,
outputFeatureDimension: i64,
outputSpatialDimension: Vec<i64>,
featureGroupCount: i64,
batchGroupCount: i64,
precision_config: Vec<i64>,
output: Tensor,
) -> Box<TensorInfo>;
fn new_dot_general_op(
self: &mut CppGraphConverter,
lhs: &TensorInfo,
Expand Down Expand Up @@ -274,7 +304,7 @@ pub mod ffi {
fn new_blackbox_op(
self: &mut CppGraphConverter,
inpts: &[*mut TensorInfo],
captured: &[*mut TensorInfo], // values that appear in a block that was declared outside
captured: &[*mut TensorInfo], // values that appear in a block that was declared outside
cpp_num: i64,
outputs: &Vec<Tensor>,
) -> Box<TensorInfo>;
Expand All @@ -293,6 +323,7 @@ pub mod ffi {
operands: Vec<Tensor>,
other_vector_args: Vec<Vector>,
int_args: Vec<i64>,
matrix_args: Vec<Matrix>,
) -> u64;
}

Expand All @@ -304,6 +335,7 @@ pub mod ffi {
operands: Vec<Tensor>,
other_vector_args: Vec<Vector>,
int_args: Vec<i64>,
matrix_args: Vec<Matrix>,
) -> Vec<Tensor>;
}
}
Expand Down Expand Up @@ -356,6 +388,7 @@ impl ffi::Ops {
Mdl::PadOp(_) => Ops::PadOp,
Mdl::SliceOp(_) => Ops::SliceOp,
Mdl::TransposeOp(_) => Ops::TransposeOp,
Mdl::ConvolutionOp(_) => Ops::ConvolutionOp,
Mdl::MulOp(_) => Ops::MulOp,
Mdl::AddOp(_) => Ops::AddOp,
Mdl::DivOp(_) => Ops::DivOp,
Expand Down Expand Up @@ -601,10 +634,7 @@ impl CppGraphConverter {
Box::new(res)
}

fn new_tensorinfo_vec(
&mut self,
inputs: &[*mut TensorInfo]
) -> Id {
fn new_tensorinfo_vec(&mut self, inputs: &[*mut TensorInfo]) -> Id {
let tensor_infos: Vec<&TensorInfo> = inputs.iter().map(|&ptr| unsafe { &*ptr }).collect();
let inputs_node = Mdl::Vec(tensor_infos.iter().map(|i| i.id).collect());
self.rec_expr.add(inputs_node)
Expand All @@ -630,6 +660,79 @@ impl CppGraphConverter {
Box::new(res)
}

pub fn new_convolution_op(
&mut self,
lhs: &TensorInfo,
rhs: &TensorInfo,
window_strides: Vec<i64>,
padding: Vec<ffi::Vector>,
lhs_dilation: Vec<i64>,
rhs_dilation: Vec<i64>,
window_reversal: Vec<bool>,
input_batch_dimension: i64,
input_feature_dimension: i64,
input_spatial_dimensions: Vec<i64>,
kernel_input_feature_dimension: i64,
kernel_output_feature_dimension: i64,
kernel_spatial_dimensions: Vec<i64>,
output_batch_dimension: i64,
output_feature_dimension: i64,
output_spatial_dimensions: Vec<i64>,
feature_group_count: i64,
batch_group_count: i64,
precision_config: Vec<i64>,
output: ffi::Tensor,
) -> Box<TensorInfo> {
let window_strides_node_id = self.vec_node(window_strides);
let lhs_dilation_node_id = self.vec_node(lhs_dilation);
let rhs_dilation_node_id = self.vec_node(rhs_dilation);

// We could add a bool element type vec?
let window_reversal_node_id =
self.vec_node(window_reversal.iter().map(|x| *x as i64).collect());
let input_spatial_dimensions_node_id = self.vec_node(input_spatial_dimensions);
let kernel_spatial_dimensions_node_id = self.vec_node(kernel_spatial_dimensions);
let output_spatial_dimensions_node_id = self.vec_node(output_spatial_dimensions);
let precision_config_node_id = self.vec_node(precision_config);

let padding_node_ids: Vec<Id> = padding
.into_iter()
.map(|pad| self.vec_node(pad.vec))
.collect::<Vec<Id>>();
let padding_node_id = self.rec_expr.add(Mdl::Vec(padding_node_ids));

let new_node = Mdl::ConvolutionOp([
lhs.id,
rhs.id,
window_strides_node_id,
padding_node_id,
lhs_dilation_node_id,
rhs_dilation_node_id,
window_reversal_node_id,
self.add_or_get_val(input_batch_dimension),
self.add_or_get_val(input_feature_dimension),
input_spatial_dimensions_node_id,
self.add_or_get_val(kernel_input_feature_dimension),
self.add_or_get_val(kernel_output_feature_dimension),
kernel_spatial_dimensions_node_id,
self.add_or_get_val(output_batch_dimension),
self.add_or_get_val(output_feature_dimension),
output_spatial_dimensions_node_id,
self.add_or_get_val(feature_group_count),
self.add_or_get_val(batch_group_count),
precision_config_node_id,
]);

let res = TensorInfo {
id: self.rec_expr.add(new_node),
tensor_data: TensorData {
tensors: vec![output],
name: None,
},
};
Box::new(res)
}

pub fn new_dot_general_op(
&mut self,
lhs: &TensorInfo,
Expand Down Expand Up @@ -1043,6 +1146,7 @@ impl CppGraphConverter {
Mdl::DotGeneralOp(ops) => new_node(ops),
Mdl::SliceOp(ops) => new_node(ops),
Mdl::TransposeOp(ops) => new_node(ops),
Mdl::ConvolutionOp(ops) => new_node(ops),
Mdl::MulOp(ops) => new_node(ops),
Mdl::AddOp(ops) => new_node(ops),
Mdl::DivOp(ops) => new_node(ops),
Expand All @@ -1059,7 +1163,7 @@ impl CppGraphConverter {
Mdl::SSplit0(ops) => new_node(ops),
Mdl::SSplit1(ops) => new_node(ops),
Mdl::MatchRank(ops) => new_node(ops),
_ => unimplemented!()
_ => unimplemented!(),
};

res.push(node);
Expand Down Expand Up @@ -1088,7 +1192,8 @@ impl CppGraphConverter {
read_to_string(rule_file).expect("Something went wrong reading the rule file");
let time_limit_sec = Duration::new(n_sec, 0);
let pre_defined_rules = PRE_DEFINED_RULES.iter().map(|&x| x);
let split_rules: Vec<&str> = learned_rules.split("\n")
let split_rules: Vec<&str> = learned_rules
.split("\n")
.filter(|x| !x.is_empty())
.chain(pre_defined_rules)
.collect();
Expand Down Expand Up @@ -1234,7 +1339,10 @@ fn extract_by_ilp(
let class_constraint = true;
let no_order = true;
let initialise_with_greedy = false;
let fusion_costs: bool = std::env::var("FUSION_COSTS").unwrap_or(String::from("true")).parse().unwrap();
let fusion_costs: bool = std::env::var("FUSION_COSTS")
.unwrap_or(String::from("true"))
.parse()
.unwrap();
let mut arg_vec = vec!["src/enzyme_ad/jax/deps/tensat/extractor/extract.py"];
if order_var_int {
arg_vec.push("--order_var_int");
Expand Down
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/deps/tensat/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ define_language! {
"GatherOp" = GatherOp([Id; 10]),
"SelectOp" = SelectOp([Id; 3]), // pred, on_true, on_false
"ConcatenateOp" = ConcatenateOp([Id; 2]), // inputs, dimension
"ConvolutionOp" = ConvolutionOp([Id; 19]), // LOTS of inputs
"DotGeneralOp" = DotGeneralOp([Id; 7]), // lhs, rhs, ..., shape
"PadOp" = PadOp([Id; 5]), // input, padding_value, edge_padding_low,
// edge_padding_high, interior_padding
Expand All @@ -50,7 +51,7 @@ define_language! {
// Complete pain, has arity 12
"ScatterOp" = ScatterOp([Id; 4]), // input, scatter_indices, updates, dimension_numbers
"ReturnOp" = ReturnOp([Id; 1]),
"BlackBox" = BlackBox([Id; 3]), // id, args, captured values (last two should be vecs)
"BlackBox" = BlackBox([Id; 3]), // id, args, captured values (last two should be vecs)
"Vec" = Vec(Vec<Id>),
"Index" = Index([Id; 2]), // index, input. for indexing into ops with multiple result Values.
// SHORTHANDS (not 1:1 with stablehlo)
Expand Down
Loading

0 comments on commit 9c8ecbf

Please sign in to comment.