From 9c8ecbf48cf72d2a4aadda88bdb1fa0c64fe51f0 Mon Sep 17 00:00:00 2001
From: Arya Vohra
Date: Thu, 3 Oct 2024 19:52:37 -0500
Subject: [PATCH] Convolution support

---
 .../jax/Passes/EqualitySaturation.cpp         | 526 ++++++++++++++++--
 src/enzyme_ad/jax/Passes/EqualitySaturation.h |   7 +-
 .../jax/deps/tensat/Cargo.Bazel.lock          |   2 +-
 src/enzyme_ad/jax/deps/tensat/Cargo.lock      |   2 +-
 src/enzyme_ad/jax/deps/tensat/Cargo.toml      |   4 +-
 src/enzyme_ad/jax/deps/tensat/converted.txt   |   6 +
 .../jax/deps/tensat/src/ffi_utils.rs          |  21 +-
 src/enzyme_ad/jax/deps/tensat/src/input.rs    | 124 ++++-
 src/enzyme_ad/jax/deps/tensat/src/model.rs    |   3 +-
 src/enzyme_ad/jax/deps/tensat/src/rewrites.rs |  53 +-
 10 files changed, 686 insertions(+), 62 deletions(-)

diff --git a/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp b/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp
index 4850506f..1644644d 100644
--- a/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp
+++ b/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp
@@ -430,6 +430,127 @@ std::vector<int64_t> dotGeneralShapeComputation(
   return shape;
 }
 
+/**
+ * Convert a vector of vectors into a DenseIntElementsAttr.
+ */
+mlir::DenseIntElementsAttr
+matrixToDenseAttr(const std::vector<std::vector<int64_t>> &matrix,
+                  mlir::Builder &builder) {
+  std::vector<int64_t> flattenedResult;
+  std::vector<int64_t> shape;
+
+  if (matrix.empty()) {
+    std::cerr << "Error: Input matrix is empty." << std::endl;
+    return {};
+  }
+
+  for (const auto &row : matrix) {
+    if (row.empty()) {
+      std::cerr << "Error: Row in matrix is empty." << std::endl;
+      return {};
+    }
+
+    flattenedResult.insert(flattenedResult.end(), row.begin(), row.end());
+    shape.push_back(row.size());
+  }
+
+  if (shape.size() == 1) {
+    shape = {static_cast<int64_t>(flattenedResult.size())};
+  } else {
+    int64_t row_size = shape[0];
+    for (const auto &row : matrix) {
+      if (static_cast<int64_t>(row.size()) != row_size) {
+        std::cerr << "Error: Non-rectangular matrix detected." << std::endl;
+        return {};
+      }
+    }
+    // A rectangular matrix lowers to a rank-2 {rows, cols} tensor; building
+    // one dimension per row would only be correct for a 2x2 matrix.
+    shape = {static_cast<int64_t>(matrix.size()), row_size};
+  }
+
+  int64_t expected_elements = 1;
+  for (auto dim : shape) {
+    expected_elements *= dim;
+  }
+
+  if (static_cast<int64_t>(flattenedResult.size()) != expected_elements) {
+    std::cerr << "Error: Mismatch between flattened result size and shape."
+              << " Expected " << expected_elements << " elements but got "
+              << flattenedResult.size() << "." << std::endl;
+    return {};
+  }
+
+  auto type = mlir::RankedTensorType::get(shape, builder.getIntegerType(64));
+  return mlir::DenseIntElementsAttr::get(type, flattenedResult);
+}
+
+// https://github.com/jax-ml/jax/blob/b8a066a90790a900d370812263dea51e4b262b43/jax/_src/lax/convolution.py#L351
+// Define ConvDimensionNumbers similar to JAX's implementation.
+struct ConvDimensionNumbers {
+  std::vector<int64_t> lhs_spec; // (batch_dim, feature_dim, spatial_dims...)
+  std::vector<int64_t> rhs_spec; // (out_feature_dim, in_feature_dim, spatial_dims...)
+  std::vector<int64_t> out_spec; // (batch_dim, feature_dim, spatial_dims...)
+};
+
+std::vector<int64_t> convolutionShapeComputation(
+    llvm::ArrayRef<int64_t> lhs_shape,    // Input shape
+    llvm::ArrayRef<int64_t> rhs_shape,    // Kernel shape
+    std::vector<int64_t> &window_strides, // Strides for each spatial dimension
+    mlir::DenseElementsAttr padding_attr, // Padding as DenseElementsAttr
+    std::vector<int64_t> &lhs_dilation,   // Dilation for input
+    std::vector<int64_t> &rhs_dilation,   // Dilation for kernel
+    ConvDimensionNumbers dimension_numbers,
+    int64_t feature_group_count, // Feature group count
+    int64_t batch_group_count    // Batch group count
+) {
+  // Ensure input and kernel shapes have at least 2 dimensions.
+  assert(lhs_shape.size() >= 2 &&
+         "Input shape must have at least 2 dimensions.");
+  assert(rhs_shape.size() >= 2 &&
+         "Kernel shape must have at least 2 dimensions.");
+
+  auto lhs_spec = dimension_numbers.lhs_spec;
+  auto rhs_spec = dimension_numbers.rhs_spec;
+  auto out_spec = dimension_numbers.out_spec;
+
+  std::vector<int64_t> lhs_spatial_dims(lhs_spec.begin() + 2, lhs_spec.end());
+  std::vector<int64_t> rhs_spatial_dims(rhs_spec.begin() + 2, rhs_spec.end());
+  std::vector<int64_t> out_spatial_dims(out_spec.begin() + 2, out_spec.end());
+
+  std::vector<std::pair<int64_t, int64_t>> padding;
+  auto padding_values = padding_attr.getValues<int64_t>();
+  auto it = padding_values.begin();
+  while (it != padding_values.end()) {
+    int64_t low = *it++;
+    int64_t high = *it++;
+    padding.emplace_back(low, high);
+  }
+
+  assert(padding.size() == lhs_spatial_dims.size() &&
+         "Padding size does not match number of spatial dimensions.");
+
+  int64_t output_batch_dim = lhs_shape[lhs_spec[0]] / batch_group_count;
+  int64_t output_feature_dim = rhs_shape[rhs_spec[0]] / feature_group_count;
+
+  std::vector<int64_t> output_shape(lhs_shape.size(), -1);
+  output_shape[out_spec[0]] = output_batch_dim;
+
+  // Compute spatial dimensions
+  for (size_t i = 0; i < lhs_spatial_dims.size(); ++i) {
+    int64_t lhs_dim = lhs_shape[lhs_spatial_dims[i]];
+    int64_t rhs_dim = rhs_shape[rhs_spatial_dims[i]];
+
+    // Apply dilation to the input size
+    int64_t dilated_lhs = (lhs_dim - 1) * lhs_dilation[i] + 1;
+
+    // Apply padding
+    int64_t padded_lhs = dilated_lhs + padding[i].first + padding[i].second;
+
+    // Effective kernel size after dilation
+    int64_t dilated_rhs = (rhs_dim - 1) * rhs_dilation[i] + 1;
+
+    // Calculate the output dimension
+    int64_t out_dim = (padded_lhs - dilated_rhs) / window_strides[i] + 1;
+    output_shape[out_spatial_dims[i]] = out_dim;
+  }
+  output_shape[out_spec[1]] = output_feature_dim;
+  return output_shape;
+}
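+
+// Illustrative worked example (not part of the original patch; the concrete
+// layouts below are an assumption, since this function is layout-agnostic):
+// with an NCHW input lhs_shape = [8, 3, 32, 32], an OIHW kernel
+// rhs_shape = [16, 3, 5, 5], window_strides = [1, 1],
+// padding = [[2, 2], [2, 2]], unit lhs/rhs dilation, and
+// feature_group_count = batch_group_count = 1, each spatial dimension is
+// ((32 - 1) * 1 + 1 + 2 + 2 - 5) / 1 + 1 = 32, so the computed output shape
+// is [8, 16, 32, 32].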
+
 /**
  * Get the correct start and limiting indices of a SliceOp from a SSplit0 or
@@ -490,11 +611,13 @@ Type getReshapeTypeForMatchRank(Value input, Value ref) {
   return deriveOutputType(input, shape);
 }
 
-Operation *createStableHloOp(OpBuilder &builder, tensat::Ops op,
-                             SmallVector<Value> &operands,
-                             std::vector<std::vector<int64_t>> &other_vecs,
-                             std::vector<int64_t> &int_args,
-                             MLIRContext *context) {
+Operation *
+createStableHloOp(OpBuilder &builder, tensat::Ops op,
+                  SmallVector<Value> &operands,
+                  std::vector<std::vector<int64_t>> &other_vecs,
+                  std::vector<int64_t> &int_args,
+                  std::vector<std::vector<std::vector<int64_t>>> &matrix_args,
+                  MLIRContext *context) {
   Operation *mlirOp = nullptr;
 
   switch (op) {
@@ -551,6 +674,88 @@ Operation *createStableHloOp(OpBuilder &builder, tensat::Ops op,
                                             newType, input);
     break;
   }
+  case tensat::Ops::ConvolutionOp: {
+    auto lhs = operands[0];
+    auto rhs = operands[1];
+
+    auto windowStrides = other_vecs[0];
+    auto padding = matrixToDenseAttr(matrix_args[0], builder);
+    auto lhsDilation = other_vecs[1];
+    auto rhsDilation = other_vecs[2];
+    auto windowReversal = other_vecs[3];
+    // std::vector<bool> has no contiguous data(), so store the flags as
+    // uint8_t and reinterpret them as bool when building the attribute.
+    std::vector<uint8_t> windowReversalBool;
+    windowReversalBool.reserve(windowReversal.size());
+    for (int64_t val : windowReversal) {
+      windowReversalBool.push_back(static_cast<uint8_t>(val != 0));
+    }
+    // dimension numbers
+    auto inputBatchDimension = int_args[0];
+    auto inputFeatureDimension = int_args[1];
+    auto inputSpatialDimensions = other_vecs[4];
+    auto kernelInputFeatureDimension = int_args[2];
+    auto kernelOutputFeatureDimension = int_args[3];
+    auto kernelSpatialDimensions = other_vecs[5];
+    auto outputBatchDimension = int_args[4];
+    auto outputFeatureDimension = int_args[5];
+    auto outputSpatialDimensions = other_vecs[6];
+    auto convolutionDimensionNumbersAttr =
+        stablehlo::ConvDimensionNumbersAttr::get(
+            context, inputBatchDimension, inputFeatureDimension,
+            inputSpatialDimensions, kernelInputFeatureDimension,
+            kernelOutputFeatureDimension, kernelSpatialDimensions,
+            outputBatchDimension, outputFeatureDimension,
+            outputSpatialDimensions);
+
+    ConvDimensionNumbers convDims;
+    convDims.lhs_spec = {inputBatchDimension, inputFeatureDimension};
+    convDims.lhs_spec.insert(convDims.lhs_spec.end(),
+                             inputSpatialDimensions.begin(),
+                             inputSpatialDimensions.end());
+
+    convDims.rhs_spec = {kernelOutputFeatureDimension,
+                         kernelInputFeatureDimension};
+    convDims.rhs_spec.insert(convDims.rhs_spec.end(),
+                             kernelSpatialDimensions.begin(),
+                             kernelSpatialDimensions.end());
+
+    convDims.out_spec = {outputBatchDimension, outputFeatureDimension};
+    convDims.out_spec.insert(convDims.out_spec.end(),
+                             outputSpatialDimensions.begin(),
+                             outputSpatialDimensions.end());
+    auto featureGroupCount = int_args[6];
+    auto batchGroupCount = int_args[7];
+    auto precisionConfig = other_vecs[7];
+
+    std::vector<mlir::Attribute> precisionVec;
+    for (auto &precision : precisionConfig) {
+      switch (precision) {
+      case 0:
+        precisionVec.push_back(stablehlo::PrecisionAttr::get(
+            context, stablehlo::Precision::DEFAULT));
+        break;
+      case 1:
+        precisionVec.push_back(
+            stablehlo::PrecisionAttr::get(context, stablehlo::Precision::HIGH));
+        break;
+      case 2:
+        precisionVec.push_back(stablehlo::PrecisionAttr::get(
+            context, stablehlo::Precision::HIGHEST));
+        break;
+      }
+    }
+
+    auto shape = convolutionShapeComputation(
+        getShape(lhs), getShape(rhs), windowStrides, padding, lhsDilation,
+        rhsDilation, convDims, featureGroupCount, batchGroupCount);
+
+    auto newType = deriveOutputType(lhs, shape);
+    mlirOp = builder.create<stablehlo::ConvolutionOp>(
+        builder.getUnknownLoc(), newType, lhs, rhs,
+        mlir::DenseI64ArrayAttr::get(context, llvm::ArrayRef(windowStrides)),
+        padding,
+        mlir::DenseI64ArrayAttr::get(context, llvm::ArrayRef(lhsDilation)),
+        mlir::DenseI64ArrayAttr::get(context, llvm::ArrayRef(rhsDilation)),
+        mlir::DenseBoolArrayAttr::get(
+            context, llvm::ArrayRef(reinterpret_cast<const bool *>(
+                                        windowReversalBool.data()),
+                                    windowReversalBool.size())),
+        convolutionDimensionNumbersAttr, featureGroupCount, batchGroupCount,
+        mlir::ArrayAttr::get(context, llvm::ArrayRef(precisionVec)));
+    break;
+  }
   case tensat::Ops::DotGeneralOp: {
     std::vector<int64_t> lhs_batch_dim = other_vecs[0];
     std::vector<int64_t> rhs_batch_dim = other_vecs[1];
@@ -635,7 +840,8 @@ Operation *createStableHloOp(OpBuilder &builder, tensat::Ops op,
 // so duplicated)
 uint64_t tensat::get_cost(tensat::Ops op, rust::Vec<tensat::Tensor> enode_args,
                           rust::Vec<tensat::Vector> other_vector_args,
-                          rust::Vec<int64_t> int_args) {
+                          rust::Vec<int64_t> int_args,
+                          rust::Vec<tensat::Matrix> matrix_args) {
   auto context = OperationTimer::getContext();
   OpBuilder builder(context);
 
@@ -655,9 +861,23 @@ uint64_t tensat::get_cost(tensat::Ops op, rust::Vec<tensat::Tensor> enode_args,
   for (const auto &num : int_args)
     int_args_as_vec.push_back(num);
 
+  std::vector<std::vector<std::vector<int64_t>>> matrix_args_as_vec;
+
+  for (const auto &mat : matrix_args) {
+    std::vector<std::vector<int64_t>> current_matrix;
+
+    for (const auto &vec : mat.mat) {
+      std::vector<int64_t> converted_vec(vec.vec.begin(), vec.vec.end());
+      current_matrix.push_back(converted_vec);
+    }
+
+    matrix_args_as_vec.push_back(current_matrix);
+  }
+
   // Create the MLIR operation
-  Operation *mlirOp = createStableHloOp(builder, op, operands, other_vecs,
-                                        int_args_as_vec, context);
+  Operation *mlirOp =
+      createStableHloOp(builder, op, operands, other_vecs, int_args_as_vec,
+                        matrix_args_as_vec, context);
 
   int repeats = 0;
   switch (getPlatform()) {
@@ -727,8 +947,9 @@ std::vector<int32_t> castArrayRefToInt32(llvm::ArrayRef<int64_t> shape) {
   return dims;
 }
 
-rust::Vec<int64_t> castArrayRefToRustVec(llvm::ArrayRef<int64_t> vec) {
-  rust::Vec<int64_t> res;
+template <typename T>
+rust::Vec<T> castArrayRefToRustVec(llvm::ArrayRef<T> vec) {
+  rust::Vec<T> res;
   res.reserve(vec.size());
   for (const auto &elem : vec) {
     res.push_back(elem);
@@ -736,11 +957,55 @@ rust::Vec<int64_t> castArrayRefToRustVec(llvm::ArrayRef<int64_t> vec) {
   return res;
 }
 
+rust::Vec<tensat::Vector>
+castDenseIntElementsAttrToRustMatrix(mlir::DenseIntElementsAttr attr) {
+  rust::Vec<tensat::Vector> res;
+
+  // Get the shape of the tensor
+  auto shape = attr.getType().getShape();
+
+  if (shape.size() == 2) {
+    // 2D matrix case
+    auto numRows = shape[0];
+    auto numCols = shape[1];
+
+    // Iterator over the elements
+    auto it = attr.value_begin<llvm::APInt>();
+
+    for (int i = 0; i < numRows; ++i) {
+      tensat::Vector rowVector;
+      rowVector.vec.reserve(numCols);
+      for (int j = 0; j < numCols; ++j) {
+        // Extract the integer value from the APInt and push it into the row
+        rowVector.vec.push_back((*it).getSExtValue());
+        ++it;
+      }
+      res.push_back(std::move(rowVector)); // Push the row wrapped in the
+                                           // Vector struct
+    }
+  } else if (shape.size() == 1) {
+    // 1D vector case
+    tensat::Vector rowVector;
+    for (auto it = attr.value_begin<llvm::APInt>();
+         it != attr.value_end<llvm::APInt>(); ++it) {
+      rowVector.vec.push_back((*it).getSExtValue()); // Push elements into the
+                                                     // row vector
+    }
+    res.push_back(std::move(rowVector)); // Push the 1D vector wrapped in the
+                                         // Vector struct
+  } else {
+    llvm::errs() << "Unhandled tensor rank for DenseElementsAttr.";
+  }
+
+  return res;
+}
+
 // SHAPE INFERENCE
 rust::Vec<tensat::Tensor>
 tensat::get_shape(Ops op, rust::Vec<tensat::Tensor> enode_args,
                   rust::Vec<tensat::Vector> other_vector_args,
-                  rust::Vec<int64_t> int_args) {
+                  rust::Vec<int64_t> int_args,
+                  rust::Vec<tensat::Matrix> matrix_args) {
   auto context = OperationTimer::getContext();
   OpBuilder builder(context);
 
@@ -760,9 +1025,20 @@ tensat::get_shape(Ops op, rust::Vec<tensat::Tensor> enode_args,
   for (const auto &num : int_args)
     int_args_as_vec.push_back(num);
 
+  std::vector<std::vector<std::vector<int64_t>>> matrix_args_as_vec;
+  for (const auto &mat : matrix_args) {
+    std::vector<std::vector<int64_t>> current_matrix;
+    for (const auto &vec : mat.mat) {
+      std::vector<int64_t> converted_vec(vec.vec.begin(), vec.vec.end());
+      current_matrix.push_back(converted_vec);
+    }
+    matrix_args_as_vec.push_back(current_matrix);
+  }
+
   // Create the MLIR operation
-  Operation *mlirOp = createStableHloOp(builder, op, operands, other_vecs,
-                                        int_args_as_vec, context);
+  Operation *mlirOp =
+      createStableHloOp(builder, op, operands, other_vecs, int_args_as_vec,
+                        matrix_args_as_vec, context);
   if (mlirOp) {
     rust::Vec<tensat::Tensor> tensors;
     for (auto res : mlirOp->getResults()) {
@@ -987,32 +1263,98 @@ class EqualitySaturationPass
         }
       }
 
-      if (auto output_tensor =
-              dot_general.getResult().getType().cast<TensorType>()) {
-        auto shape = castArrayRefToInt32(output_tensor.getShape());
-        auto output_shape_slice =
-            rust::Slice<const int32_t>{shape.data(), shape.size()};
-
-        tensorInfo = graph
-                         ->new_dot_general_op(
-                             *handleOperandPartial(dot_general.getLhs()),
-                             *handleOperandPartial(dot_general.getRhs()),
-                             castArrayRefToRustVec(
-                                 dot_dim_attrs.getLhsBatchingDimensions()),
-                             castArrayRefToRustVec(
-                                 dot_dim_attrs.getRhsBatchingDimensions()),
-                             castArrayRefToRustVec(
-                                 dot_dim_attrs.getLhsContractingDimensions()),
-                             castArrayRefToRustVec(
-                                 dot_dim_attrs.getRhsContractingDimensions()),
-                             precision_configs,
-                             mlirValueToTensatTensor(dot_general.getResult()))
-                         .into_raw();
+      tensorInfo = graph
+                       ->new_dot_general_op(
+                           *handleOperandPartial(dot_general.getLhs()),
+                           *handleOperandPartial(dot_general.getRhs()),
+                           castArrayRefToRustVec(
+                               dot_dim_attrs.getLhsBatchingDimensions()),
+                           castArrayRefToRustVec(
+                               dot_dim_attrs.getRhsBatchingDimensions()),
+                           castArrayRefToRustVec(
+                               dot_dim_attrs.getLhsContractingDimensions()),
+                           castArrayRefToRustVec(
+                               dot_dim_attrs.getRhsContractingDimensions()),
+                           precision_configs,
+                           mlirValueToTensatTensor(dot_general.getResult()))
+                       .into_raw();
+    } else if (isa<stablehlo::ConvolutionOp>(op)) {
+      auto convolution = cast<stablehlo::ConvolutionOp>(op);
+      auto result_type =
+          convolution.getResult().getType().cast<ShapedType>();
+      if (result_type.hasRank()) {
+        auto shape = result_type.getShape();
+        llvm::errs() << "Result shape: [";
+        for (auto dim : shape) {
+          llvm::errs() << dim << " ";
+        }
+        llvm::errs() << "]\n";
       } else {
-        std::cout << "EqualitySaturationPass: result of "
-                     "stablehlo::DotGeneralOp has non-tensor type"
-                  << std::endl;
+        llvm::errs() << "Result is unranked.\n";
       }
+      auto dimNumbers = convolution.getDimensionNumbers();
+      mlir::ArrayAttr precision =
+          convolution.getPrecisionConfig().value_or(mlir::ArrayAttr());
+      rust::Vec<int64_t> precision_configs;
+      // Guard against a null ArrayAttr, which value_or above produces when no
+      // precision config is present.
+      for (size_t i = 0; precision && i < precision.size(); i++) {
+        auto precisionAttr =
+            precision[i].dyn_cast<mlir::stablehlo::PrecisionAttr>();
+        if (!precisionAttr)
+          continue; // Skip if it's not a PrecisionAttr, although such
+                    // attributes should not exist here
+        mlir::stablehlo::Precision val = precisionAttr.getValue();
+        switch (val) {
+        case mlir::stablehlo::Precision::DEFAULT:
+          precision_configs.push_back(0);
+          break;
+        case mlir::stablehlo::Precision::HIGH:
+          precision_configs.push_back(1);
+          break;
+        case mlir::stablehlo::Precision::HIGHEST:
+          precision_configs.push_back(2);
+          break;
+        }
+      }
+      auto windowStridesOpt = convolution.getWindowStrides();
+      assert(windowStridesOpt.has_value() &&
+             "Expected window strides to be present.");
+      auto paddingOpt = convolution.getPadding();
+      assert(paddingOpt.has_value() && "Expected padding to be present.");
+      auto lhsDilationOpt = convolution.getLhsDilation();
+      assert(lhsDilationOpt.has_value() &&
+             "Expected LHS dilation to be present.");
+      auto rhsDilationOpt = convolution.getRhsDilation();
+      assert(rhsDilationOpt.has_value() &&
+             "Expected RHS dilation to be present.");
+      auto windowReversalOpt = convolution.getWindowReversal();
+      assert(windowReversalOpt.has_value() &&
+             "Expected window reversal to be present.");
+
+      tensorInfo =
+          graph
+              ->new_convolution_op(
+                  *handleOperandPartial(convolution.getLhs()),
+                  *handleOperandPartial(convolution.getRhs()),
+                  castArrayRefToRustVec(windowStridesOpt.value()),
+                  castDenseIntElementsAttrToRustMatrix(paddingOpt.value()),
+                  castArrayRefToRustVec(lhsDilationOpt.value()),
+                  castArrayRefToRustVec(rhsDilationOpt.value()),
+                  castArrayRefToRustVec(windowReversalOpt.value()),
+                  dimNumbers.getInputBatchDimension(),
+                  dimNumbers.getInputFeatureDimension(),
+                  castArrayRefToRustVec(dimNumbers.getInputSpatialDimensions()),
+                  dimNumbers.getKernelInputFeatureDimension(),
+                  dimNumbers.getKernelOutputFeatureDimension(),
+                  castArrayRefToRustVec(
+                      dimNumbers.getKernelSpatialDimensions()),
+                  dimNumbers.getOutputBatchDimension(),
+                  dimNumbers.getOutputFeatureDimension(),
+                  castArrayRefToRustVec(
+                      dimNumbers.getOutputSpatialDimensions()),
+                  convolution.getFeatureGroupCount(),
+                  convolution.getBatchGroupCount(), precision_configs,
+                  mlirValueToTensatTensor(convolution.getResult()))
+              .into_raw();
     } else if (isa<stablehlo::ConcatenateOp>(op)) {
       auto concat = cast<stablehlo::ConcatenateOp>(op);
       auto output_tensor = concat->getResult(0).getType().cast<TensorType>();
@@ -1171,6 +1513,24 @@ class EqualitySaturationPass
     return result;
   }
 
+  /**
+   * Parse the Vec nodes with Vecs (e.g. Vec(Vec(128, 128), Vec(128, 128)))
+   * emitted by tensat node construction.
+   */
+  mlir::DenseIntElementsAttr
+  parseNumMatrixToDenseAttr(rust::Vec<tensat::Node> &nodes, tensat::Node &seq,
+                            mlir::Builder &builder) {
+    assert(seq.name == "Vec");
+
+    std::vector<std::vector<int64_t>> matrix;
+    for (auto i : seq.operands) {
+      assert(i < nodes.size());
+      assert(nodes[i].name == "Vec");
+      matrix.push_back(parseNumVec(nodes, nodes[i]));
+    }
+    return matrixToDenseAttr(matrix, builder);
+  }
+
   /**
    * Parse the Num nodes emitted by tensat node construction.
    * Our protocol is to encode integer values as operand indices.
@@ -1299,6 +1659,98 @@ class EqualitySaturationPass
         }
         break;
       }
+      case Ops::ConvolutionOp: {
+        auto lhs = opVals[node.operands[0]];
+        auto rhs = opVals[node.operands[1]];
+
+        auto windowStrides = parseNumVec(nodes, nodes[node.operands[2]]);
+        auto padding =
+            parseNumMatrixToDenseAttr(nodes, nodes[node.operands[3]], builder);
+        auto lhsDilation = parseNumVec(nodes, nodes[node.operands[4]]);
+        auto rhsDilation = parseNumVec(nodes, nodes[node.operands[5]]);
+        auto windowReversal = parseNumVec(nodes, nodes[node.operands[6]]);
+        // As above, use uint8_t rather than std::vector<bool> so the storage
+        // is contiguous and can be reinterpreted as bool.
+        std::vector<uint8_t> windowReversalBool;
+        windowReversalBool.reserve(windowReversal.size());
+        for (int64_t val : windowReversal) {
+          windowReversalBool.push_back(static_cast<uint8_t>(val != 0));
+        }
+        // dimension numbers
+        auto inputBatchDimension = parseNumNode(nodes, nodes[node.operands[7]]);
+        auto inputFeatureDimension =
+            parseNumNode(nodes, nodes[node.operands[8]]);
+        auto inputSpatialDimensions =
+            parseNumVec(nodes, nodes[node.operands[9]]);
+        auto kernelInputFeatureDimension =
+            parseNumNode(nodes, nodes[node.operands[10]]);
+        auto kernelOutputFeatureDimension =
+            parseNumNode(nodes, nodes[node.operands[11]]);
+        auto kernelSpatialDimensions =
+            parseNumVec(nodes, nodes[node.operands[12]]);
+        auto outputBatchDimension =
+            parseNumNode(nodes, nodes[node.operands[13]]);
+        auto outputFeatureDimension =
+            parseNumNode(nodes, nodes[node.operands[14]]);
+        auto outputSpatialDimensions =
+            parseNumVec(nodes, nodes[node.operands[15]]);
+        auto convolutionDimensionNumbersAttr =
+            stablehlo::ConvDimensionNumbersAttr::get(
+                context, inputBatchDimension, inputFeatureDimension,
+                inputSpatialDimensions, kernelInputFeatureDimension,
+                kernelOutputFeatureDimension, kernelSpatialDimensions,
+                outputBatchDimension, outputFeatureDimension,
+                outputSpatialDimensions);
+
+        ConvDimensionNumbers convDims;
+        convDims.lhs_spec = {inputBatchDimension, inputFeatureDimension};
+        convDims.lhs_spec.insert(convDims.lhs_spec.end(),
+                                 inputSpatialDimensions.begin(),
+                                 inputSpatialDimensions.end());
+
+        convDims.rhs_spec = {kernelOutputFeatureDimension,
+                             kernelInputFeatureDimension};
+        convDims.rhs_spec.insert(convDims.rhs_spec.end(),
+                                 kernelSpatialDimensions.begin(),
+                                 kernelSpatialDimensions.end());
+
+        convDims.out_spec = {outputBatchDimension, outputFeatureDimension};
+        convDims.out_spec.insert(convDims.out_spec.end(),
+                                 outputSpatialDimensions.begin(),
+                                 outputSpatialDimensions.end());
+        auto featureGroupCount = parseNumNode(nodes, nodes[node.operands[16]]);
+        auto batchGroupCount = parseNumNode(nodes, nodes[node.operands[17]]);
+        auto precisionConfig = parseNumVec(nodes, nodes[node.operands[18]]);
+
+        std::vector<mlir::Attribute> precisionVec;
+        for (auto &precision : precisionConfig) {
+          switch (precision) {
+          case 0:
+            precisionVec.push_back(stablehlo::PrecisionAttr::get(
+                context, stablehlo::Precision::DEFAULT));
+            break;
+          case 1:
+            precisionVec.push_back(stablehlo::PrecisionAttr::get(
+                context, stablehlo::Precision::HIGH));
+            break;
+          case 2:
+            precisionVec.push_back(stablehlo::PrecisionAttr::get(
+                context, stablehlo::Precision::HIGHEST));
+            break;
+          }
+        }
+
+        auto shape = convolutionShapeComputation(
+            getShape(lhs), getShape(rhs), windowStrides, padding, lhsDilation,
+            rhsDilation, convDims, featureGroupCount, batchGroupCount);
+
+        auto newType = deriveOutputType(lhs, shape);
+        newOp = builder.create<stablehlo::ConvolutionOp>(
+            location, newType, lhs, rhs,
+            mlir::DenseI64ArrayAttr::get(context,
+                                         llvm::ArrayRef(windowStrides)),
+            padding,
+            mlir::DenseI64ArrayAttr::get(context, llvm::ArrayRef(lhsDilation)),
+            mlir::DenseI64ArrayAttr::get(context, llvm::ArrayRef(rhsDilation)),
+            mlir::DenseBoolArrayAttr::get(
+                context, llvm::ArrayRef(reinterpret_cast<const bool *>(
+                                            windowReversalBool.data()),
+                                        windowReversalBool.size())),
+            convolutionDimensionNumbersAttr, featureGroupCount,
+            batchGroupCount,
+            mlir::ArrayAttr::get(context, llvm::ArrayRef(precisionVec)));
+        break;
+      }
       case Ops::DotGeneralOp: {
         auto lhs = opVals[node.operands[0]];
         auto rhs = opVals[node.operands[1]];
diff --git a/src/enzyme_ad/jax/Passes/EqualitySaturation.h b/src/enzyme_ad/jax/Passes/EqualitySaturation.h
index 96dc5450..1d52c9c0 100644
--- a/src/enzyme_ad/jax/Passes/EqualitySaturation.h
+++ b/src/enzyme_ad/jax/Passes/EqualitySaturation.h
@@ -9,6 +9,7 @@ namespace tensat {
 enum class Type : uint8_t;
 enum class Ops : uint8_t;
 struct Vector;
+struct Matrix;
 struct Tensor;
 
 /**
@@ -17,12 +18,14 @@ struct Tensor;
 uint64_t get_cost(Ops op, rust::Vec<Tensor> operands,
                   rust::Vec<Vector> other_vector_args,
-                  rust::Vec<int64_t> int_args);
+                  rust::Vec<int64_t> int_args,
+                  rust::Vec<Matrix> matrix_args);
 
 mlir::Type newTensorType(mlir::OpBuilder &builder, Tensor tensor);
 mlir::Type tensatTypeToMlirType(mlir::OpBuilder &builder, Type type);
 
 rust::Vec<Tensor> get_shape(Ops op, rust::Vec<Tensor> operands,
                             rust::Vec<Vector> other_vector_args,
-                            rust::Vec<int64_t> int_args);
+                            rust::Vec<int64_t> int_args,
+                            rust::Vec<Matrix> matrix_args);
 } // namespace tensat
diff --git a/src/enzyme_ad/jax/deps/tensat/Cargo.Bazel.lock b/src/enzyme_ad/jax/deps/tensat/Cargo.Bazel.lock
index dd319354..063b4d12 100644
--- a/src/enzyme_ad/jax/deps/tensat/Cargo.Bazel.lock
+++ b/src/enzyme_ad/jax/deps/tensat/Cargo.Bazel.lock
@@ -207,7 +207,7 @@ dependencies = [
 [[package]]
 name = "egg"
 version = "0.6.1-dev"
-source = "git+https://github.com/yycdavid/egg?rev=12cc1ee7731d37fe91901c81f59678fa1d08a2bb#12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
+source = "git+https://github.com/aryavohra/egg?rev=b30d14cff61bff97336323f6eb0978cc7769140d#b30d14cff61bff97336323f6eb0978cc7769140d"
 dependencies = [
  "indexmap",
  "instant",
diff --git a/src/enzyme_ad/jax/deps/tensat/Cargo.lock b/src/enzyme_ad/jax/deps/tensat/Cargo.lock
index 9e3810fb..78a76be3 100644
--- a/src/enzyme_ad/jax/deps/tensat/Cargo.lock
+++ b/src/enzyme_ad/jax/deps/tensat/Cargo.lock
@@ -207,7 +207,7 @@ dependencies = [
 [[package]]
 name = "egg"
 version = "0.6.1-dev"
-source = "git+https://github.com/yycdavid/egg?rev=12cc1ee7731d37fe91901c81f59678fa1d08a2bb#12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
+source = "git+https://github.com/aryavohra/egg?rev=b30d14cff61bff97336323f6eb0978cc7769140d#b30d14cff61bff97336323f6eb0978cc7769140d"
 dependencies = [
  "indexmap",
  "instant",
diff --git a/src/enzyme_ad/jax/deps/tensat/Cargo.toml b/src/enzyme_ad/jax/deps/tensat/Cargo.toml
index 93242369..11d431f9 100644
--- a/src/enzyme_ad/jax/deps/tensat/Cargo.toml
+++ b/src/enzyme_ad/jax/deps/tensat/Cargo.toml
@@ -25,8 +25,8 @@ serde_json = "1.0"
 serde = { version = "1.0", features = ["derive"] }
 
 [dependencies.egg]
-git = "https://github.com/yycdavid/egg"
-rev = "12cc1ee7731d37fe91901c81f59678fa1d08a2bb"
+git = "https://github.com/aryavohra/egg"
+rev = "b30d14cff61bff97336323f6eb0978cc7769140d"
 
 [package.metadata.cxx]
 library = "c++"
diff --git a/src/enzyme_ad/jax/deps/tensat/converted.txt b/src/enzyme_ad/jax/deps/tensat/converted.txt
index 36b9a906..b6372f5c 100644
--- a/src/enzyme_ad/jax/deps/tensat/converted.txt
+++ b/src/enzyme_ad/jax/deps/tensat/converted.txt
@@ -26,3 +26,9 @@
 (ConcatenateOp (Vec (MulOp ?x ?y) (MulOp ?z ?w)) ?i)<=>(MulOp (ConcatenateOp (Vec ?x ?z) ?i) (ConcatenateOp (Vec ?y ?w) ?i))
 (ConcatenateOp (Vec (ConcatenateOp (Vec ?x ?y) 1) (ConcatenateOp (Vec ?z ?w) 1)) 0)<=>(ConcatenateOp (Vec (ConcatenateOp (Vec ?x ?z) 0) (ConcatenateOp (Vec ?y ?w) 0)) 1)
+
+(ConvolutionOp (MulOp ?x ?w) ?y ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(ConvolutionOp ?x (MulOp ?y ?w) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)
+
+(ConvolutionOp ?lhs (ConcatenateOp (Vec ?x ?y) ?i) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(ConcatenateOp (Vec (ConvolutionOp ?lhs ?x ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig) (ConvolutionOp ?lhs ?y ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)) ?i)
+
+(ConvolutionOp ?lhs (MulOp ?rhs ?w) ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig)<=>(MulOp (ConvolutionOp ?lhs ?rhs ?windowstrides ?padding ?lhsdilation ?rhsdilation ?windowreversal ?inputbatchdimension ?inputfeaturedimension ?inputspatialdimensions ?kernelinputfeaturedimension ?kerneloutputfeaturedimension ?kernelspatialdimensions ?outputbatchdimension ?outputfeaturedimension ?outputspatialdimensions ?featuregroupcount ?batchgroupcount ?precisionconfig) ?w)
diff --git a/src/enzyme_ad/jax/deps/tensat/src/ffi_utils.rs b/src/enzyme_ad/jax/deps/tensat/src/ffi_utils.rs
index 9ed2e6de..48885770 100644
--- a/src/enzyme_ad/jax/deps/tensat/src/ffi_utils.rs
+++ b/src/enzyme_ad/jax/deps/tensat/src/ffi_utils.rs
@@ -1,21 +1,30 @@
 use crate::{
     input::ffi,
     model::*,
-    rewrites::{get_num_option, get_vec_of_nums_option, get_vec_option},
+    rewrites::{get_matrix_option, get_num_option, get_vec_of_nums_option, get_vec_option},
 };
 use egg::*;
 
 fn process_enode_args(
     egraph: &EGraph<Mdl, TensorAnalysis>,
     enode: &Mdl,
-) -> (Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>) {
+) -> (
+    Vec<ffi::Tensor>,
+    Vec<ffi::Vector>,
+    Vec<i64>,
+    Vec<ffi::Matrix>,
+) {
     let mut args: Vec<ffi::Tensor> = vec![];
     let mut other_vecs: Vec<ffi::Vector> = vec![];
     let mut int_args: Vec<i64> = vec![];
+    let mut matrix_args: Vec<ffi::Matrix> = vec![];
 
     for child in enode.children().iter() {
         if let Some(other_vec) = get_vec_of_nums_option(egraph, &egraph[*child]) {
            other_vecs.push(other_vec)
+        } else if let Some(mat) = get_matrix_option(egraph, &egraph[*child]) {
+            println!("{:?}", mat);
+            matrix_args.push(mat)
         } else if let Some(vec) = get_vec_option(&egraph[*child]) {
             vec.iter()
                 .for_each(|&id| args.push(egraph[id].data.tensors[0].clone()))
@@ -27,7 +36,7 @@ fn process_enode_args(
         }
     }
 
-    (args, other_vecs, int_args)
+    (args, other_vecs, int_args, matrix_args)
 }
 
 pub fn create_stablehlo_op<R, F>(
     egraph: &EGraph<Mdl, TensorAnalysis>,
     enode: &Mdl,
     process_output: F,
 ) -> R
 where
-    F: Fn(ffi::Ops, Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>) -> R,
+    F: Fn(ffi::Ops, Vec<ffi::Tensor>, Vec<ffi::Vector>, Vec<i64>, Vec<ffi::Matrix>) -> R,
 {
     let op = ffi::Ops::from_mdl(enode);
-    let (args, other_vecs, int_args) = process_enode_args(egraph, enode);
-    let res = process_output(op, args, other_vecs, int_args);
+    let (args, other_vecs, int_args, matrix_args) = process_enode_args(egraph, enode);
+    let res = process_output(op, args, other_vecs, int_args, matrix_args);
     res
 }
diff --git a/src/enzyme_ad/jax/deps/tensat/src/input.rs b/src/enzyme_ad/jax/deps/tensat/src/input.rs
index 422e9546..3e8e1364 100644
--- a/src/enzyme_ad/jax/deps/tensat/src/input.rs
+++ b/src/enzyme_ad/jax/deps/tensat/src/input.rs
@@ -34,6 +34,7 @@ pub mod ffi {
         SelectOp,
         ConcatenateOp,
         DotGeneralOp,
+        ConvolutionOp,
         PadOp,
         SliceOp,
         TransposeOp,
@@ -83,6 +84,12 @@ pub mod ffi {
         pub vec: Vec<i64>,
     }
 
+    // Similarly, we're creating a Matrix type for vecs of vecs (padding)
+    #[derive(Debug)]
+    struct Matrix {
+        pub mat: Vec<Vector>,
+    }
+
     // take floats from c++ and wrap them into f32s below
     extern "Rust" {
         type Mdl;
@@ -159,6 +166,29 @@ pub mod ffi {
             dimension: i64,
             output: Tensor,
         ) -> Box<TensorInfo>;
+        fn new_convolution_op(
+            self: &mut CppGraphConverter,
+            lhs: &TensorInfo,
+            rhs: &TensorInfo,
+            windowStrides: Vec<i64>,
+            padding: Vec<Vector>,
+            lhsDilation: Vec<i64>,
+            rhsDilation: Vec<i64>,
+            windowReversal: Vec<bool>,
+            inputBatchDimension: i64,
+            inputFeatureDimension: i64,
+            inputSpatialDimension: Vec<i64>,
+            kernelInputFeatureDimension: i64,
+            kernelOutputFeatureDimension: i64,
+            kernelSpatialDimension: Vec<i64>,
+            outputBatchDimension: i64,
+            outputFeatureDimension: i64,
+            outputSpatialDimension: Vec<i64>,
+            featureGroupCount: i64,
+            batchGroupCount: i64,
+            precision_config: Vec<i64>,
+            output: Tensor,
+        ) -> Box<TensorInfo>;
         fn new_dot_general_op(
             self: &mut CppGraphConverter,
             lhs: &TensorInfo,
@@ -274,7 +304,7 @@ pub mod ffi {
         fn new_blackbox_op(
             self: &mut CppGraphConverter,
             inpts: &[*mut TensorInfo],
-            captured: &[*mut TensorInfo], // values that appear in a block that was declared outside
+            captured: &[*mut TensorInfo], // values that appear in a block that was declared outside
             cpp_num: i64,
             outputs: &Vec<Type>,
         ) -> Box<TensorInfo>;
@@ -293,6 +323,7 @@ pub mod ffi {
             operands: Vec<Tensor>,
             other_vector_args: Vec<Vector>,
             int_args: Vec<i64>,
+            matrix_args: Vec<Matrix>,
         ) -> u64;
     }
 
@@ -304,6 +335,7 @@ pub mod ffi {
             operands: Vec<Tensor>,
             other_vector_args: Vec<Vector>,
             int_args: Vec<i64>,
+            matrix_args: Vec<Matrix>,
         ) -> Vec<Tensor>;
     }
 }
@@ -356,6 +388,7 @@ impl ffi::Ops {
             Mdl::PadOp(_) => Ops::PadOp,
             Mdl::SliceOp(_) => Ops::SliceOp,
             Mdl::TransposeOp(_) => Ops::TransposeOp,
+            Mdl::ConvolutionOp(_) => Ops::ConvolutionOp,
             Mdl::MulOp(_) => Ops::MulOp,
             Mdl::AddOp(_) => Ops::AddOp,
             Mdl::DivOp(_) => Ops::DivOp,
@@ -601,10 +634,7 @@ impl CppGraphConverter {
         Box::new(res)
     }
 
-    fn new_tensorinfo_vec(
-        &mut self,
-        inputs: &[*mut TensorInfo]
-    ) -> Id {
+    fn new_tensorinfo_vec(&mut self, inputs: &[*mut TensorInfo]) -> Id {
         let tensor_infos: Vec<&TensorInfo> = inputs.iter().map(|&ptr| unsafe { &*ptr }).collect();
         let inputs_node = Mdl::Vec(tensor_infos.iter().map(|i| i.id).collect());
         self.rec_expr.add(inputs_node)
@@ -630,6 +660,79 @@ impl CppGraphConverter {
         Box::new(res)
     }
 
+    pub fn new_convolution_op(
+        &mut self,
+        lhs: &TensorInfo,
+        rhs: &TensorInfo,
+        window_strides: Vec<i64>,
+        padding: Vec<ffi::Vector>,
+        lhs_dilation: Vec<i64>,
+        rhs_dilation: Vec<i64>,
+        window_reversal: Vec<bool>,
+        input_batch_dimension: i64,
+        input_feature_dimension: i64,
+        input_spatial_dimensions: Vec<i64>,
+        kernel_input_feature_dimension: i64,
+        kernel_output_feature_dimension: i64,
+        kernel_spatial_dimensions: Vec<i64>,
+        output_batch_dimension: i64,
+        output_feature_dimension: i64,
+        output_spatial_dimensions: Vec<i64>,
+        feature_group_count: i64,
+        batch_group_count: i64,
+        precision_config: Vec<i64>,
+        output: ffi::Tensor,
+    ) -> Box<TensorInfo> {
+        let window_strides_node_id = self.vec_node(window_strides);
+        let lhs_dilation_node_id = self.vec_node(lhs_dilation);
+        let rhs_dilation_node_id = self.vec_node(rhs_dilation);
+
+        // We could add a bool element type vec?
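+        // (Descriptive note on the line below, added for clarity): Vec nodes
+        // in the rec_expr carry i64 elements only, so the bool window_reversal
+        // flags are widened to 0/1 here; the C++ side recovers them with
+        // `val != 0` before building the DenseBoolArrayAttr.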
+        let window_reversal_node_id =
+            self.vec_node(window_reversal.iter().map(|x| *x as i64).collect());
+        let input_spatial_dimensions_node_id = self.vec_node(input_spatial_dimensions);
+        let kernel_spatial_dimensions_node_id = self.vec_node(kernel_spatial_dimensions);
+        let output_spatial_dimensions_node_id = self.vec_node(output_spatial_dimensions);
+        let precision_config_node_id = self.vec_node(precision_config);
+
+        let padding_node_ids: Vec<Id> = padding
+            .into_iter()
+            .map(|pad| self.vec_node(pad.vec))
+            .collect::<Vec<Id>>();
+        let padding_node_id = self.rec_expr.add(Mdl::Vec(padding_node_ids));
+
+        let new_node = Mdl::ConvolutionOp([
+            lhs.id,
+            rhs.id,
+            window_strides_node_id,
+            padding_node_id,
+            lhs_dilation_node_id,
+            rhs_dilation_node_id,
+            window_reversal_node_id,
+            self.add_or_get_val(input_batch_dimension),
+            self.add_or_get_val(input_feature_dimension),
+            input_spatial_dimensions_node_id,
+            self.add_or_get_val(kernel_input_feature_dimension),
+            self.add_or_get_val(kernel_output_feature_dimension),
+            kernel_spatial_dimensions_node_id,
+            self.add_or_get_val(output_batch_dimension),
+            self.add_or_get_val(output_feature_dimension),
+            output_spatial_dimensions_node_id,
+            self.add_or_get_val(feature_group_count),
+            self.add_or_get_val(batch_group_count),
+            precision_config_node_id,
+        ]);
+
+        let res = TensorInfo {
+            id: self.rec_expr.add(new_node),
+            tensor_data: TensorData {
+                tensors: vec![output],
+                name: None,
+            },
+        };
+        Box::new(res)
+    }
+
     pub fn new_dot_general_op(
         &mut self,
         lhs: &TensorInfo,
@@ -1043,6 +1146,7 @@ impl CppGraphConverter {
             Mdl::DotGeneralOp(ops) => new_node(ops),
             Mdl::SliceOp(ops) => new_node(ops),
             Mdl::TransposeOp(ops) => new_node(ops),
+            Mdl::ConvolutionOp(ops) => new_node(ops),
             Mdl::MulOp(ops) => new_node(ops),
             Mdl::AddOp(ops) => new_node(ops),
             Mdl::DivOp(ops) => new_node(ops),
@@ -1059,7 +1163,7 @@ impl CppGraphConverter {
             Mdl::SSplit0(ops) => new_node(ops),
             Mdl::SSplit1(ops) => new_node(ops),
             Mdl::MatchRank(ops) => new_node(ops),
-            _ => unimplemented!()
+            _ => unimplemented!(),
         };
 
         res.push(node);
@@ -1088,7 +1192,8 @@ impl CppGraphConverter {
             read_to_string(rule_file).expect("Something went wrong reading the rule file");
         let time_limit_sec = Duration::new(n_sec, 0);
         let pre_defined_rules = PRE_DEFINED_RULES.iter().map(|&x| x);
-        let split_rules: Vec<&str> = learned_rules.split("\n")
+        let split_rules: Vec<&str> = learned_rules
+            .split("\n")
             .filter(|x| !x.is_empty())
             .chain(pre_defined_rules)
             .collect();
@@ -1234,7 +1339,10 @@ fn extract_by_ilp(
     let class_constraint = true;
     let no_order = true;
     let initialise_with_greedy = false;
-    let fusion_costs: bool = std::env::var("FUSION_COSTS").unwrap_or(String::from("true")).parse().unwrap();
+    let fusion_costs: bool = std::env::var("FUSION_COSTS")
+        .unwrap_or(String::from("true"))
+        .parse()
+        .unwrap();
     let mut arg_vec = vec!["src/enzyme_ad/jax/deps/tensat/extractor/extract.py"];
     if order_var_int {
         arg_vec.push("--order_var_int");
diff --git a/src/enzyme_ad/jax/deps/tensat/src/model.rs b/src/enzyme_ad/jax/deps/tensat/src/model.rs
index 316371f3..bc246cd8 100644
--- a/src/enzyme_ad/jax/deps/tensat/src/model.rs
+++ b/src/enzyme_ad/jax/deps/tensat/src/model.rs
@@ -26,6 +26,7 @@ define_language! {
     "GatherOp" = GatherOp([Id; 10]),
     "SelectOp" = SelectOp([Id; 3]), // pred, on_true, on_false
     "ConcatenateOp" = ConcatenateOp([Id; 2]), // inputs, dimension
+    "ConvolutionOp" = ConvolutionOp([Id; 19]), // LOTS of inputs
     "DotGeneralOp" = DotGeneralOp([Id; 7]), // lhs, rhs, ..., shape
     "PadOp" = PadOp([Id; 5]), // input, padding_value, edge_padding_low,
                               // edge_padding_high, interior_padding
@@ -50,7 +51,7 @@ define_language! {
     // Complete pain, has arity 12
     "ScatterOp" = ScatterOp([Id; 4]), // input, scatter_indices, updates, dimension_numbers
     "ReturnOp" = ReturnOp([Id; 1]),
-    "BlackBox" = BlackBox([Id; 3]),  // id, args, captured values (last two should be vecs)
+    "BlackBox" = BlackBox([Id; 3]), // id, args, captured values (last two should be vecs)
     "Vec" = Vec(Vec<Id>),
     "Index" = Index([Id; 2]), // index, input. for indexing into ops with multiple result Values.
     // SHORTHANDS (not 1:1 with stablehlo)
diff --git a/src/enzyme_ad/jax/deps/tensat/src/rewrites.rs b/src/enzyme_ad/jax/deps/tensat/src/rewrites.rs
index 898a6cc0..9e0d87c7 100644
--- a/src/enzyme_ad/jax/deps/tensat/src/rewrites.rs
+++ b/src/enzyme_ad/jax/deps/tensat/src/rewrites.rs
@@ -102,7 +102,13 @@ pub fn rules<A: Analysis<Mdl>>() -> Vec<Rewrite<Mdl, A>> {
 vec![
     rw!("-concatenation-and-pooling-2" ;"(poolmax ?kx ?ky ?sx ?sy ?p (concat 1 ?x ?y))" => "(concat 1 (poolmax ?kx ?ky ?sx ?sy ?p ?x) (poolmax ?kx ?ky ?sx ?sy ?p ?y)) "),
 ]}
 
-fn add_to_rule_vec(rule_vec: &mut Vec<Rewrite<Mdl, TensorAnalysis>>, filter_after: bool, rule_name: String, lhs: &str, rhs: &str) {
+fn add_to_rule_vec(
+    rule_vec: &mut Vec<Rewrite<Mdl, TensorAnalysis>>,
+    filter_after: bool,
+    rule_name: String,
+    lhs: &str,
+    rhs: &str,
+) {
     let lhs: Pattern<Mdl> = lhs.parse().unwrap();
     let rhs: Pattern<Mdl> = rhs.parse().unwrap();
     rule_vec.push(rw!(rule_name; { lhs.clone() } => { CheckApply {
@@ -117,12 +123,30 @@ pub fn rules_from_str(rs: Vec<&str>, filter_after: bool) -> Vec<Rewrite<Mdl, TensorAnalysis>> {
     for (pos, rule) in rs.iter().enumerate() {
         if rule.contains("<=>") {
             let eqn: Vec<&str> = rule.split("<=>").collect();
-            add_to_rule_vec(&mut rule_vec, filter_after, format!("rule{}_l", pos), eqn[0], eqn[1]);
-            add_to_rule_vec(&mut rule_vec, filter_after, format!("rule{}_r", pos), eqn[1], eqn[0]);
+            add_to_rule_vec(
+                &mut rule_vec,
+                filter_after,
+                format!("rule{}_l", pos),
+                eqn[0],
+                eqn[1],
+            );
+            add_to_rule_vec(
+                &mut rule_vec,
+                filter_after,
+                format!("rule{}_r", pos),
+                eqn[1],
+                eqn[0],
+            );
         } else {
             assert!(rule.contains("=>"));
             let eqn: Vec<&str> = rule.split("=>").collect();
-            add_to_rule_vec(&mut rule_vec, filter_after, format!("rule{}", pos), eqn[0], eqn[1]);
+            add_to_rule_vec(
+                &mut rule_vec,
+                filter_after,
+                format!("rule{}", pos),
+                eqn[0],
+                eqn[1],
+            );
         }
     }
     rule_vec
@@ -152,6 +176,27 @@ pub static PRE_DEFINED_MULTI: &[&str] = &[
 
 // TODO: We should really clean these. Just dirty hacks for now to test out conditional rewrites
 
+pub fn get_matrix_option(
+    egraph: &EGraph<Mdl, TensorAnalysis>,
+    eclass: &EClass<Mdl, TensorData>,
+) -> Option<ffi::Matrix> {
+    // First, we get an optional Vec of Ids (one per inner Vec) from the eclass
+    get_vec_option(eclass)
+        .map(|outer_vec| {
+            // Iterate over each inner Vec eclass to convert it to a Vector
+            outer_vec
+                .iter()
+                .map(|&id| {
+                    // For each inner vector, get its corresponding vector of numbers
+                    get_vec_of_nums_option(egraph, &egraph[id])
+                        .map(|ffi_vec| ffi::Vector { vec: ffi_vec.vec })
+                })
+                .collect::<Option<Vec<ffi::Vector>>>()
+        })
+        // If all the vectors were successfully retrieved, wrap them in a Matrix
+        .and_then(|opt_mat| opt_mat.map(|mat| ffi::Matrix { mat }))
+}
+
 pub fn get_vec_of_nums(
     egraph: &EGraph<Mdl, TensorAnalysis>,
     eclass: &EClass<Mdl, TensorData>,