diff --git a/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/a.out b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/a.out new file mode 100755 index 000000000000..37c15ce432cc Binary files /dev/null and b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/a.out differ diff --git a/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/voiceActivityDetection.c b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/voiceActivityDetection.c index ca54d6e91b10..b673267c140c 100644 --- a/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/voiceActivityDetection.c +++ b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/CCode/voiceActivityDetection.c @@ -3,7 +3,8 @@ #include #define PI 3.14159265359 -#define INPUT_LENGTH 100000000 +#define INPUT_LENGTH 1000000000 +#define INCREMENT 0.000137 double* getRangeOfVector(double start, int length, double increment); void gain(double* output, const double* input, double multiplier, int length); @@ -64,19 +65,32 @@ void threshold(double* output, const double* input, double thresholdValue, int l } } +inline int sign(int x) { + return (x > 0) - (x < 0); +} + +// Function to count zero crossings int zeroCrossCount(const double* input, int length) { int count = 0; - for (int i = 1; i < length; i++) { - if ((input[i-1] > 0 && input[i] <= 0) || (input[i-1] < 0 && input[i] >= 0)) { - count++; + int previous_sign = 0; + + for (int i = 0; i < length; i++) { + int current_sign = sign(input[i]); + + if (current_sign != 0) { + if (previous_sign != 0 && current_sign != previous_sign) { + count++; + } + previous_sign = current_sign; } } + return count; } int main() { int fs = 1000; - double* input = getRangeOfVector(0, INPUT_LENGTH, 1); + double* input = getRangeOfVector(0, INPUT_LENGTH, INCREMENT); double getMultiplier = 2 * PI * 5; double* getSinDuration = malloc(INPUT_LENGTH * sizeof(double)); @@ -97,10 +111,10 @@ int main() { int zcr = zeroCrossCount(GetThresholdReal, INPUT_LENGTH); - for (int i = 0; i < INPUT_LENGTH; i++) { - printf("%f ", GetThresholdReal[i]); - } - printf("\n"); + // for (int i = 0; i < INPUT_LENGTH; i++) { + // printf("%f ", GetThresholdReal[i]); + // } + // printf("\n"); // Print zero-crossing count printf("Zero-crossing count: %d\n", zcr); diff --git a/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/opt_voice.py b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/opt_voice.py new file mode 100644 index 000000000000..c6f295745607 --- /dev/null +++ b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/opt_voice.py @@ -0,0 +1,16 @@ +def main() { + var fs = 1000; + var input = getRangeOfVector(0, 10000, 0.000125); + var sep = getRangeOfVector(0, 1, 0.5); + var pi = 3.14159265359; + var getMultiplier = 2 * pi * 5; + var getSinDuration = gain(input, getMultiplier); + var signal = sin(getSinDuration ); + + var noise = delay(signal, 5); + var noisy_sig = signal + noise; + var threshold = 0.8; + var zeroOpt = zero_cross_threshold_opt(noisy_sig, threshold); + print(noisy_sig); + print(zeroOpt); +} diff --git a/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/voiceActivityDetection.py b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/voiceActivityDetection.py index 3a6d44065eb7..c6b7e6ff6973 100644 --- a/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/voiceActivityDetection.py +++ b/mlir/examples/dsp/SimpleBlocks/Output/TryDSPApps/BenchmarkTest/DSP-DSL/voiceActivityDetection.py @@ -1,19 +1,19 @@ def main() { - var fs = 1000; - # var step = 1/fs; - # print(step); - var input = getRangeOfVector(0, 100000000, 1); - var pi = 3.14159265359; - var getMultiplier = 2 * pi * 5; - # print(getMultiplier); - var getSinDuration = gain(input, getMultiplier); - var signal = sin(getSinDuration ); + var fs = 1000; + var input = getRangeOfVector(0, 100000000, 0.000137); + var sep = getRangeOfVector(0, 1, 0.5); + var pi = 3.14159265359; + var getMultiplier = 2 * pi * 5; + var getSinDuration = gain(input, getMultiplier); + var signal = sin(getSinDuration ); - var noise = delay(signal, 5); - var noisy_sig = signal + noise; - var threshold = 0.8; - var GetThresholdReal = threshold( noisy_sig , threshold); - var zcr = zeroCrossCount(GetThresholdReal); - print(GetThresholdReal); - print(zcr); -} \ No newline at end of file + var noise = delay(signal, 5); + var noisy_sig = signal + noise; + var threshold = 0.8; + # print(sep); + var GetThresholdReal = threshold( noisy_sig , threshold); + # print(GetThresholdReal); + var zcr = zeroCrossCount(GetThresholdReal); + print(zcr); + # print(noisy_sig); +} diff --git a/mlir/examples/dsp/SimpleBlocks/include/toy/Ops.td b/mlir/examples/dsp/SimpleBlocks/include/toy/Ops.td index 2339c41b50ce..548f5e984673 100644 --- a/mlir/examples/dsp/SimpleBlocks/include/toy/Ops.td +++ b/mlir/examples/dsp/SimpleBlocks/include/toy/Ops.td @@ -643,25 +643,13 @@ def zeroCrossCountOp : Dsp_Op<"zeroCrossCount" , let arguments = (ins F64Tensor:$lhs); //working -- F64 let results = (outs F64Tensor); - // let results = (outs I64); - // Indicate that the operation has a custom parser and printer method. - // let hasCustomAssemblyFormat = 1; - // let assemblyFormat = [{ - // `(` $input `:` type($input1 , $input2) `)` attr-dict `to` type(results) - // }]; - // Allow building a zeroCrossCountOp with from the one input operands. let builders = [ OpBuilder<(ins "Value":$lhs)> ]; - // Indicate that the operation has a custom parser and printer method. - // let hasCustomAssemblyFormat = 1; - - // Enable registering canonicalization patterns with this operation. - //let hasCanonicalizer = 1; + let hasCanonicalizer = 1; - // let hasVerifier = 1; } @@ -2679,6 +2667,25 @@ def FIRFilterResSymmThresholdUpOptimizedOp : Dsp_Op<"FIRFilterResSymmThresholdUp } +//===----------------------------------------------------------------------===// +// zeroCntOptimizeOp +//===----------------------------------------------------------------------===// + +def zeroCntOptimizeOp : Dsp_Op<"zero_cross_threshold_opt", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "optimze op for zero cross count + threshold op"; + let description = [{ + val between given threshold are bound to 0, otherwise count + 1 if consecutive 2 elements have different sign. + }]; + + let arguments = (ins F64Tensor:$input, F64Tensor:$threshold); + let results = (outs F64Tensor); + + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$threshold)> + ]; +} + #endif // TOY_OPS diff --git a/mlir/examples/dsp/SimpleBlocks/mlir/Dialect.cpp b/mlir/examples/dsp/SimpleBlocks/mlir/Dialect.cpp index aa2b10b98b8c..8ce0c3bce833 100644 --- a/mlir/examples/dsp/SimpleBlocks/mlir/Dialect.cpp +++ b/mlir/examples/dsp/SimpleBlocks/mlir/Dialect.cpp @@ -879,15 +879,17 @@ mlir::LogicalResult PowOp::verify() { void zeroCrossCountOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value lhs) { state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - // state.addTypes(builder.getF64Type())); - // state.addTypes(builder.getI64Type()); state.addOperands({lhs}); } /// Infer the output shape of the zeroCrossCountOp, this is required by the /// shape inference interface. void zeroCrossCountOp::inferShapes() { - getResult().setType(getLhs().getType()); + auto tensorInput = getLhs().getType(); + std::vector shapeForOutput; + mlir::TensorType manipulatedType = mlir::RankedTensorType::get( + shapeForOutput, tensorInput.getElementType()); + getResult().setType(manipulatedType); } //===----------------------------------------------------------------------===// @@ -3400,6 +3402,20 @@ void FIRFilterResSymmThresholdUpOptimizedOp::inferShapes() { getResult().setType(manipulatedType); } +//===----------------------------------------------------------------------===// +// zeroCntOptimizeOp +//===----------------------------------------------------------------------===// + +void zeroCntOptimizeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input, mlir::Value threshold) { + state.addTypes({UnrankedTensorType::get(builder.getF64Type())}); + state.addOperands({input, threshold}); +} + +void zeroCntOptimizeOp::inferShapes() { + getResult().setType(getThreshold().getType()); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/examples/dsp/SimpleBlocks/mlir/LowerToAffineLoops.cpp b/mlir/examples/dsp/SimpleBlocks/mlir/LowerToAffineLoops.cpp index 69daeff2ca94..6faa9ea79ef3 100644 --- a/mlir/examples/dsp/SimpleBlocks/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/dsp/SimpleBlocks/mlir/LowerToAffineLoops.cpp @@ -6873,7 +6873,7 @@ struct BitwiseAndOpLowering : public ConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: BitwiseAndOp operations +// ToyToAffine RewritePatterns: zeroCrossCountOpLowering operations //===----------------------------------------------------------------------===// struct zeroCrossCountOpLowering : public ConversionPattern { @@ -6897,15 +6897,15 @@ struct zeroCrossCountOpLowering : public ConversionPattern { // output for result type auto tensorType = llvm::cast((*op->result_type_begin())); Type integerType = rewriter.getI64Type(); + auto memrefType = convertTensorToMemRef(tensorType); + + zeroCrossCountOpAdaptor zeroCrossCountOpAdaptor(operands); + auto inputType = llvm::cast(zeroCrossCountOpAdaptor.getLhs().getType()); // allocation & deallocation for the result of this operation // auto memRefType = convertTensorToMemRef(tensorType); // Force the result to be a tensor of size 1 - auto alloc = insertAllocAndDealloc( - MemRefType::get(ArrayRef(1), tensorType.getElementType()), loc, - rewriter); - zeroCrossCountOpAdaptor zeroCrossCountOpAdaptor(operands); - DEBUG_PRINT_NO_ARGS(); + auto alloc = insertAllocAndDealloc(memrefType, loc, rewriter); // Define constants Value constant0 = rewriter.create( @@ -6923,7 +6923,7 @@ struct zeroCrossCountOpLowering : public ConversionPattern { Value ub = rewriter.create( loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), - tensorType.getShape()[0])); + inputType.getShape()[0])); Value step = rewriter.create(loc, 1); // Set up for loop @@ -7002,13 +7002,14 @@ struct zeroCrossCountOpLowering : public ConversionPattern { Value finalCountArgFloat = rewriter.create( loc, rewriter.getF64Type(), finalCountArg); - rewriter.create(loc, finalCountArgFloat, alloc, Indx0); + rewriter.create(loc, finalCountArgFloat, alloc, ValueRange{}); rewriter.replaceOp(op, alloc); return success(); }; }; + //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Binary operations //===----------------------------------------------------------------------===// @@ -10367,6 +10368,102 @@ Value constant00 = rewriter.create( } }; +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: zeroCntOptimize operations +//===----------------------------------------------------------------------===// + +struct zeroCntOptimizeOpLowering : public ConversionPattern { + zeroCntOptimizeOpLowering(MLIRContext *ctx) + : ConversionPattern(dsp::zeroCntOptimizeOp::getOperationName(), 1, + ctx) {} +#define DUMP(x) llvm::errs() << "here " << x << "\n"; + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + auto tensorType = llvm::dyn_cast(*op->result_type_begin()); + auto memrefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memrefType, loc, rewriter); + // DUMP("alloc"); + + // acumulator and threshold + Value zero = rewriter.create(loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0)); + Value one = rewriter.create(loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(1)); + Value negOne = rewriter.create(loc, one); + // DUMP("constant"); + + zeroCntOptimizeOpAdaptor adaptor(operands); + auto inputType = llvm::cast(adaptor.getInput().getType()); + auto thresholdMem = adaptor.getThreshold(); + auto threshold = rewriter.create(loc, thresholdMem, ValueRange{}); + auto negThreshold = rewriter.create(loc, threshold); + + int64_t lb=0, ub=inputType.getShape()[0], step=1; + affine::AffineForOp forOp = rewriter.create(loc, lb, ub, step, ValueRange{zero, zero}); + auto iv = forOp.getInductionVar(); + rewriter.setInsertionPointToStart(forOp.getBody()); + + + auto ele = rewriter.create(loc, adaptor.getInput(), ValueRange{iv}); + auto prev_sign = forOp.getBody()->getArgument(1); + auto zero_cnt = forOp.getBody()->getArgument(2); + + // lt threshold + auto lt = rewriter.create(loc, arith::CmpFPredicate::OLE, ele, threshold); + // gt threshold + auto gt = rewriter.create(loc, arith::CmpFPredicate::OGE, ele, negThreshold); + + auto cmp = rewriter.create(loc, lt, gt); + // DUMP("cmp"); + + auto ifOp = rewriter.create(loc, TypeRange{rewriter.getF64Type(), rewriter.getF64Type()}, cmp, true); + // if ele in range, yield stored cnt out + rewriter.setInsertionPointToStart(ifOp.thenBlock()); + // DUMP("1st if"); + rewriter.create(loc, ValueRange{prev_sign, zero_cnt}); + + // else if prev_sign !=0 and cur_sign != prev_sign + rewriter.setInsertionPointToStart(ifOp.elseBlock()); + // DUMP("1st else"); + + auto cur_sign_cmp = rewriter.create(loc, lt, negOne, one); + auto sign_add = rewriter.create(loc, cur_sign_cmp, prev_sign); + auto sign_diff = rewriter.create(loc, arith::CmpFPredicate::OEQ, sign_add, zero); + // DUMP("sign diff"); + + auto valid = rewriter.create(loc, TypeRange{rewriter.getF64Type()}, sign_diff, true); + // DUMP("2nd if"); + rewriter.setInsertionPointToStart(valid.thenBlock()); + auto incre_cnt = rewriter.create(loc, zero_cnt, one); + rewriter.create(loc, ValueRange{incre_cnt}); + rewriter.setInsertionPointToStart(valid.elseBlock()); + // DUMP("2nd else"); + rewriter.create(loc, ValueRange{zero_cnt}); + rewriter.setInsertionPointAfter(valid); + + // DUMP("get result"); + auto cntResult = valid.getResults()[0]; + + rewriter.create(loc, ValueRange{cur_sign_cmp, cntResult}); + rewriter.setInsertionPointAfter(ifOp); + + auto new_sign = ifOp.getResults()[0]; + auto new_cnt = ifOp.getResults()[1]; + + rewriter.create(loc, ValueRange{new_sign, new_cnt}); + rewriter.setInsertionPointAfter(forOp); + // DUMP("get answer"); + + auto result = forOp.getResult(1); + rewriter.create(loc, result, alloc, ValueRange{}); + + rewriter.replaceOp(op, alloc); + return mlir::success(); + } + +}; + } // namespace //===----------------------------------------------------------------------===// @@ -10448,7 +10545,7 @@ void ToyToAffineLoweringPass::runOnOperation() { NormalizeOpLowering, AbsOpLowering, MedianFilterOpLowering, LMS2FindPeaksOptimizedOpLowering, FindPeaks2Diff2MeanOptimizedOpLowering, NormLMSFilterResponseOptimizeOpLowering, - FIRFilterResSymmThresholdUpOptimizedOpLowering>(&getContext()); + FIRFilterResSymmThresholdUpOptimizedOpLowering, zeroCntOptimizeOpLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/mlir/examples/dsp/SimpleBlocks/mlir/MLIRGen.cpp b/mlir/examples/dsp/SimpleBlocks/mlir/MLIRGen.cpp index 54085395a422..9467669f5518 100644 --- a/mlir/examples/dsp/SimpleBlocks/mlir/MLIRGen.cpp +++ b/mlir/examples/dsp/SimpleBlocks/mlir/MLIRGen.cpp @@ -404,7 +404,7 @@ class MLIRGenImpl { if (callee == "fftReal") { if (call.getArgs().size() != 1) { emitError(location, - "MLIR codegen encountered an error: dsp.zeroCrossCount " + "MLIR codegen encountered an error: dsp.fftReal " "accepts only 1 arguments"); return nullptr; } @@ -414,7 +414,7 @@ class MLIRGenImpl { if (callee == "fftImag") { if (call.getArgs().size() != 1) { emitError(location, - "MLIR codegen encountered an error: dsp.zeroCrossCount " + "MLIR codegen encountered an error: dsp.fftImg " "accepts only 1 arguments"); return nullptr; } @@ -882,6 +882,7 @@ class MLIRGenImpl { operands[2], operands[3], operands[4]); } + // threshold if (callee == "threshold") { if (call.getArgs().size() != 2) { emitError(location, @@ -892,6 +893,18 @@ class MLIRGenImpl { return builder.create(location, operands[0], operands[1]); } + // zero cross count optimize + if (callee == "zero_cross_threshold_opt") { + if (call.getArgs().size() != 2) { + emitError(location, + "MLIR codegen encountered an error: dsp.zero_cross_threshold_opt " + "accepts only 2 arguments"); + return nullptr; + } + return builder.create(location, operands[0], operands[1]); + } + + // quantization if (callee == "quantization") { if (call.getArgs().size() != 4) { emitError(location, diff --git a/mlir/examples/dsp/SimpleBlocks/mlir/ToyCombine.cpp b/mlir/examples/dsp/SimpleBlocks/mlir/ToyCombine.cpp index 374991cdaf5d..2106421ba2ef 100644 --- a/mlir/examples/dsp/SimpleBlocks/mlir/ToyCombine.cpp +++ b/mlir/examples/dsp/SimpleBlocks/mlir/ToyCombine.cpp @@ -944,6 +944,36 @@ Value input2 = prev_FIRFilterSymmOp->getOperand(1); return mlir::success(); } + }; + +struct SimplifyZTpass : public mlir::OpRewritePattern { + SimplifyZTpass(mlir::MLIRContext *ctx) : OpRewritePattern(ctx, 1) {} + + mlir::LogicalResult + matchAndRewrite(zeroCrossCountOp op, mlir::PatternRewriter &rewriter) const override { + +#define CHECK(x) if(!x) return failure(); +#define REMOVE(x) if(x->use_empty()) rewriter.eraseOp(x); +#define DEBUG(x) {llvm::errs() << "check for " << x << "\n";} +#define PASS llvm::errs() << "pass\n"; + + auto loc = op.getLoc(); + + // pattern -> CHECK() + Operation* thresholdOp = op.getOperand().getDefiningOp(); + CHECK(thresholdOp); + + Value input = thresholdOp->getOperand(0); + Value threshold = thresholdOp->getOperand(1); + + auto zeroCntOpt = rewriter.create(loc, input, threshold); + + rewriter.replaceOp(op, zeroCntOpt); + + REMOVE(thresholdOp); + + return mlir::success(); + } }; @@ -1089,3 +1119,9 @@ void ThresholdUpOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIR results.add(ctx); } } + +void zeroCrossCountOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *ctx) { + if(getEnableCanonicalOpt()) { + results.add(ctx); + } +} diff --git a/mlir/test/Examples/DspExample/test_zerocross.py b/mlir/test/Examples/DspExample/test_zerocross.py new file mode 100644 index 000000000000..8d6e7a3b5cd9 --- /dev/null +++ b/mlir/test/Examples/DspExample/test_zerocross.py @@ -0,0 +1,8 @@ +def main() { + var testcase = [1, -0.000000001, 0.00000001, 0, 0, -3, 0, 0, 0, 2, 0, 0, 2]; + # cross 2 + var threshold = 0; + # var ans = zeroCrossCount(testcase); + var ans = zero_cross_threshold_opt(testcase, threshold); + print(ans); +}