From 2244169a82f8d3eeaf86a9415fe49efb907d244b Mon Sep 17 00:00:00 2001 From: Lawrence Lim Date: Tue, 10 Sep 2024 09:32:26 -0700 Subject: [PATCH] Lower Tosa Sigmoid to Secret Arith via 3-degree Polynomial Approximation Fixes #942 PiperOrigin-RevId: 672995519 --- lib/Conversion/TosaToSecretArith/BUILD | 51 ++++++ .../TosaToSecretArith/TosaToSecretArith.cpp | 152 ++++++++++++++++++ .../TosaToSecretArith/TosaToSecretArith.h | 20 +++ .../TosaToSecretArith/TosaToSecretArith.td | 22 +++ tests/tosa_to_secret_arith/BUILD | 10 ++ .../tosa_sigmoid_to_arith.mlir | 26 +++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 8 files changed, 284 insertions(+) create mode 100644 lib/Conversion/TosaToSecretArith/BUILD create mode 100644 lib/Conversion/TosaToSecretArith/TosaToSecretArith.cpp create mode 100644 lib/Conversion/TosaToSecretArith/TosaToSecretArith.h create mode 100644 lib/Conversion/TosaToSecretArith/TosaToSecretArith.td create mode 100644 tests/tosa_to_secret_arith/BUILD create mode 100644 tests/tosa_to_secret_arith/tosa_sigmoid_to_arith.mlir diff --git a/lib/Conversion/TosaToSecretArith/BUILD b/lib/Conversion/TosaToSecretArith/BUILD new file mode 100644 index 000000000..9d855d7e6 --- /dev/null +++ b/lib/Conversion/TosaToSecretArith/BUILD @@ -0,0 +1,51 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "TosaToSecretArith", + srcs = ["TosaToSecretArith.cpp"], + hdrs = [ + "TosaToSecretArith.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Analysis/SecretnessAnalysis", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", + ], + alwayslink = 1, +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TosaToSecretArith", + ], + "TosaToSecretArith.h.inc", + ), + ( + ["-gen-pass-doc"], + "TosaToSecretArith.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TosaToSecretArith.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/lib/Conversion/TosaToSecretArith/TosaToSecretArith.cpp b/lib/Conversion/TosaToSecretArith/TosaToSecretArith.cpp new file mode 100644 index 000000000..55b53c929 --- /dev/null +++ b/lib/Conversion/TosaToSecretArith/TosaToSecretArith.cpp @@ -0,0 +1,152 @@ +#include "lib/Conversion/TosaToSecretArith/TosaToSecretArith.h" + +#include + +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "lib/Dialect/TensorExt/IR/TensorExtOps.h" +#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project +#include "llvm/include/llvm/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +#define DEBUG_TYPE "tosa-to-secret-arith" + +namespace mlir { +namespace heir { +namespace tosa { + +#define GEN_PASS_DEF_TOSATOSECRETARITH +#include "lib/Conversion/TosaToSecretArith/TosaToSecretArith.h.inc" + +Value createConstantFloat(ImplicitLocOpBuilder &b, double floatValue, + RankedTensorType type) { + auto elementType = type.getElementType(); + + // Create APFloat based on the float type width + APFloat value(0.0); // Default initialization + if (elementType.isF32()) { + value = APFloat(static_cast( + floatValue)); // Convert double to float if necessary + } else if (elementType.isF64()) { + value = APFloat(floatValue); // Use the double value directly + } else { + llvm_unreachable("Expected a valid float type for constant creation"); + } + + auto constantValuesAttr = SplatElementsAttr::get(type, value); + return b.create(constantValuesAttr); +} + +struct ConvertTosaSigmoid : public OpRewritePattern { + private: + DataFlowSolver *solver; + + public: + ConvertTosaSigmoid(DataFlowSolver *solver, mlir::MLIRContext *context) + : OpRewritePattern(context), solver(solver) {} + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::tosa::SigmoidOp op, + PatternRewriter &rewriter) const override { + auto isSecret = [&](Value value) { + auto *operandLookup = solver->lookupState(value); + Secretness operandSecretness = + operandLookup ? operandLookup->getValue() : Secretness(); + return (operandSecretness.isInitialized() && + operandSecretness.getSecretness()); + }; + + // Do not support lowering for non-secret operands + bool operandIsSecret = isSecret(op.getOperand()); + if (!operandIsSecret) { + return failure(); + } + + auto inputTensorType = + dyn_cast(op.getOperand().getType()); + if (!inputTensorType) { + return failure(); + } + + auto dimensions = inputTensorType.getShape(); + auto dataType = inputTensorType.getElementType(); + + // Do not support lowering for non-float types + if (!dyn_cast(dataType)) { + return failure(); + } + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + // Calculates -0.004 * x^3 + 0.197 * x + 0.5 + auto rankedTensorType = RankedTensorType::get(dimensions, dataType); + auto coefficientDegreeZero = createConstantFloat(b, 0.5, rankedTensorType); + auto coefficientDegreeOne = createConstantFloat(b, 0.197, rankedTensorType); + auto coefficientDegreeThree = + createConstantFloat(b, -0.004, rankedTensorType); + + auto coefficientMultiplyDegreeOne = + b.create(coefficientDegreeOne, op.getOperand()); + auto calculateDegreeTwo = + b.create(op.getOperand(), op.getOperand()); + auto calculateDegreeThree = + b.create(calculateDegreeTwo, op.getOperand()); + auto coefficientMultiplyDegreeThree = + b.create(calculateDegreeThree, coefficientDegreeThree); + + auto sumDegreeZeroAndOne = b.create( + coefficientDegreeZero, coefficientMultiplyDegreeOne); + auto totalSum = b.create(sumDegreeZeroAndOne, + coefficientMultiplyDegreeThree); + rewriter.replaceOp(op, totalSum); + return success(); + } +}; + +struct TosaToSecretArith + : public impl::TosaToSecretArithBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(); + + auto result = solver.initializeAndRun(module); + + if (failed(result)) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } + + RewritePatternSet patterns(context); + + patterns.add(&solver, context); + + // Run pattern matching and conversion + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace tosa +} // namespace heir +} // namespace mlir diff --git a/lib/Conversion/TosaToSecretArith/TosaToSecretArith.h b/lib/Conversion/TosaToSecretArith/TosaToSecretArith.h new file mode 100644 index 000000000..096884925 --- /dev/null +++ b/lib/Conversion/TosaToSecretArith/TosaToSecretArith.h @@ -0,0 +1,20 @@ +#ifndef LIB_CONVERSION_TOSATOSECRETARITH_TOSATOSECRETARITH_H_ +#define LIB_CONVERSION_TOSATOSECRETARITH_TOSATOSECRETARITH_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tosa { + +#define GEN_PASS_DECL +#include "lib/Conversion/TosaToSecretArith/TosaToSecretArith.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Conversion/TosaToSecretArith/TosaToSecretArith.h.inc" + +} // namespace tosa +} // namespace heir +} // namespace mlir + +#endif // LIB_CONVERSION_TOSATOSECRETARITH_TOSATOSECRETARITH_H_ diff --git a/lib/Conversion/TosaToSecretArith/TosaToSecretArith.td b/lib/Conversion/TosaToSecretArith/TosaToSecretArith.td new file mode 100644 index 000000000..86e1cfd57 --- /dev/null +++ b/lib/Conversion/TosaToSecretArith/TosaToSecretArith.td @@ -0,0 +1,22 @@ +#ifndef LIB_CONVERSION_TOSATOSECRETARITH_TOSATOSECRETARITH_TD_ +#define LIB_CONVERSION_TOSATOSECRETARITH_TOSATOSECRETARITH_TD_ + +include "mlir/Pass/PassBase.td" + +def TosaToSecretArith : Pass<"tosa-to-secret-arith"> { + let summary = "Lower `tosa.sigmoid` to secret arith dialects."; + + let description = [{ + This pass lowers the `tosa.sigmoid` dialect to the polynomial approximation + -0.004 * x^3 + 0.197 * x + 0.5 (composed of arith, affine, and tensor operations). + + This polynomial approximation of sigmoid only works over the range [-5, 5] + and is taken from the paper ['Logisitic regression over encrypted data from + fully homomorphic encryption' by Chen et al.](https://eprint.iacr.org/2018/462.pdf). + }]; + let dependentDialects = [ + "mlir::heir::tensor_ext::TensorExtDialect", + ]; +} + +#endif // LIB_CONVERSION_TOSATOSECRETARITH_TOSATOSECRETARITH_TD_ diff --git a/tests/tosa_to_secret_arith/BUILD b/tests/tosa_to_secret_arith/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/tosa_to_secret_arith/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/tosa_to_secret_arith/tosa_sigmoid_to_arith.mlir b/tests/tosa_to_secret_arith/tosa_sigmoid_to_arith.mlir new file mode 100644 index 000000000..eb1a1da14 --- /dev/null +++ b/tests/tosa_to_secret_arith/tosa_sigmoid_to_arith.mlir @@ -0,0 +1,26 @@ +// RUN: heir-opt %s --tosa-to-secret-arith | FileCheck %s + +// CHECK: func.func @test_tosa_sigmoid_to_secret_arith(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[COEFF_0:.*]] = arith.constant dense<5.{{0*}}e-01> : tensor<1x16xf32> +// CHECK-DAG: %[[COEFF_1:.*]] = arith.constant dense<1.97{{0*}}e-01> : tensor<1x16xf32> +// CHECK-DAG: %[[COEFF_3:.*]] = arith.constant dense<-4.{{0*}}e-03> : tensor<1x16xf32> +// CHECK: %[[RET:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK-NEXT: ^bb0(%[[CONVERTED_ARG:.*]]: tensor<1x16xf32>): +// CHECK: %[[COEFF_MUL_DEGREE_1:.*]] = arith.mulf %[[CONVERTED_ARG]], %[[COEFF_1]] +// CHECK: %[[DEGREE_2:.*]] = arith.mulf %[[CONVERTED_ARG]], %[[CONVERTED_ARG]] +// CHECK: %[[DEGREE_3:.*]] = arith.mulf %[[DEGREE_2]], %[[CONVERTED_ARG]] +// CHECK: %[[COEFF_MUL_DEGREE_3:.*]] = arith.mulf %[[DEGREE_3]], %[[COEFF_3]] +// CHECK: %[[SUM_1:.*]] = arith.addf %[[COEFF_MUL_DEGREE_1]], %[[COEFF_0]] +// CHECK: %[[TOTAL_SUM:.*]] = arith.addf %[[SUM_1]], %[[COEFF_MUL_DEGREE_3]] +// CHECK: secret.yield %[[TOTAL_SUM]] : tensor<1x16xf32> +// CHECK: return %[[RET]] : !secret.secret> +module { +func.func @test_tosa_sigmoid_to_secret_arith(%vec : !secret.secret>) -> !secret.secret> { + %out = secret.generic ins (%vec : !secret.secret>) { + ^bb0(%converted_vec: tensor<1x16xf32>): + %0 = tosa.sigmoid %converted_vec : (tensor<1x16xf32>) -> tensor<1x16xf32> + secret.yield %0 : tensor<1x16xf32> + } -> !secret.secret> + return %out : !secret.secret> +} +} diff --git a/tools/BUILD b/tools/BUILD index f69f09c42..ffeac540e 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -46,6 +46,7 @@ cc_binary( "@heir//lib/Conversion/PolynomialToStandard", "@heir//lib/Conversion/SecretToBGV", "@heir//lib/Conversion/SecretToCKKS", + "@heir//lib/Conversion/TosaToSecretArith", "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/CGGI/IR:Dialect", "@heir//lib/Dialect/CGGI/Transforms", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 14fc6a51c..a9d790f73 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -16,6 +16,7 @@ #include "lib/Conversion/PolynomialToStandard/PolynomialToStandard.h" #include "lib/Conversion/SecretToBGV/SecretToBGV.h" #include "lib/Conversion/SecretToCKKS/SecretToCKKS.h" +#include "lib/Conversion/TosaToSecretArith/TosaToSecretArith.h" #include "lib/Dialect/BGV/IR/BGVDialect.h" #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CGGI/Transforms/Passes.h" @@ -645,6 +646,7 @@ int main(int argc, char **argv) { registerCGGIToTfheRustBoolPasses(); registerSecretToBGVPasses(); registerSecretToCKKSPasses(); + mlir::heir::tosa::registerTosaToSecretArithPasses(); // Interfaces in HEIR secret::registerBufferizableOpInterfaceExternalModels(registry);