From c8de1ba2577680ab7391e843f7b1a7ac1f823006 Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Wed, 11 Dec 2024 14:43:38 -0800 Subject: [PATCH] Move OpRewritePatterns to the new file. Signed-off-by: Eric Schweitz --- .../Dialect/Quake/CanonicalPatterns.inc | 398 +++++++++++++++++ lib/Optimizer/Dialect/Quake/QuakeOps.cpp | 412 ------------------ 2 files changed, 398 insertions(+), 412 deletions(-) diff --git a/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc b/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc index b646bb3305..7b46909cf8 100644 --- a/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc +++ b/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc @@ -87,4 +87,402 @@ struct FuseConstantToExtractRefPattern } }; +// %4 = quake.concat %2, %3 : (!quake.ref, !quake.ref) -> !quake.veq<2> +// %7 = quake.extract_ref %4[0] : (!quake.veq<2>) -> !quake.ref +// ─────────────────────────────────────────── +// replace all use with %2 +struct ForwardConcatExtractPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ExtractRefOp extract, + PatternRewriter &rewriter) const override { + auto veq = extract.getVeq(); + auto concatOp = veq.getDefiningOp(); + if (concatOp && extract.hasConstantIndex()) { + // Don't run this canonicalization if any of the operands + // to concat are of type veq. + auto concatQubits = concatOp.getQbits(); + for (auto qOp : concatQubits) + if (isa(qOp.getType())) + return failure(); + + // concat only has ref type operands. + auto index = extract.getConstantIndex(); + if (index < concatQubits.size()) { + auto qOpValue = concatQubits[index]; + if (isa(qOpValue.getType())) { + rewriter.replaceOp(extract, {qOpValue}); + return success(); + } + } + } + return failure(); + } +}; + +// %2 = quake.concat %1 : (!quake.ref) -> !quake.veq<1> +// %3 = quake.extract_ref %2[0] : (!quake.veq<1>) -> !quake.ref +// quake.* %3 ... +// ─────────────────────────────────────────── +// quake.* %1 ... +struct ForwardConcatExtractSingleton + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ExtractRefOp extract, + PatternRewriter &rewriter) const override { + if (auto concat = extract.getVeq().getDefiningOp()) + if (concat.getType().getSize() == 1 && extract.hasConstantIndex() && + extract.getConstantIndex() == 0) { + assert(concat.getQbits().size() == 1 && concat.getQbits()[0]); + extract.getResult().replaceUsesWithIf( + concat.getQbits()[0], [&](OpOperand &use) { + if (Operation *user = use.getOwner()) + return isQuakeOperation(user); + return false; + }); + return success(); + } + return failure(); + } +}; + +// %7 = quake.concat %4 : (!quake.veq<2>) -> !quake.veq<2> +// ─────────────────────────────────────────── +// removed +struct ConcatNoOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ConcatOp concat, + PatternRewriter &rewriter) const override { + // Remove concat veq -> veq + // or + // concat ref -> ref + auto qubitsToConcat = concat.getQbits(); + if (qubitsToConcat.size() > 1) + return failure(); + + // We only want to handle veq -> veq here. + if (isa(qubitsToConcat.front().getType())) { + return failure(); + } + + // Do not handle anything where we don't know the sizes. + auto retTy = concat.getResult().getType(); + if (auto veqTy = dyn_cast(retTy)) + if (!veqTy.hasSpecifiedSize()) + // This could be a folded quake.relax_size op. + return failure(); + + rewriter.replaceOp(concat, qubitsToConcat); + return success(); + } +}; + +// %8 = quake.concat %4, %5, %6 : (!quake.ref, !quake.veq<4>, +// !quake.veq<2>) -> !quake.veq +// ─────────────────────────────────────────────────────────── +// %.8 = quake.concat %4, %5, %6 : (!quake.ref, !quake.veq<4>, +// !quake.veq<2>) -> !quake.veq<7> +// %8 = quake.relax_size %.8 : (!quake.veq<7>) -> !quake.veq +struct ConcatSizePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ConcatOp concat, + PatternRewriter &rewriter) const override { + if (concat.getType().hasSpecifiedSize()) + return failure(); + + // Walk the arguments and sum them, if possible. + std::size_t sum = 0; + for (auto opnd : concat.getQbits()) { + if (auto veqTy = dyn_cast(opnd.getType())) { + if (!veqTy.hasSpecifiedSize()) + return failure(); + sum += veqTy.getSize(); + continue; + } + assert(isa(opnd.getType())); + sum++; + } + + // Leans into the relax_size canonicalization pattern. + auto *ctx = rewriter.getContext(); + auto loc = concat.getLoc(); + auto newTy = quake::VeqType::get(ctx, sum); + Value newOp = + rewriter.create(loc, newTy, concat.getQbits()); + auto noSizeTy = quake::VeqType::getUnsized(ctx); + rewriter.replaceOpWithNewOp(concat, noSizeTy, newOp); + return success(); + } +}; + +// %7 = quake.make_struq %5, %6 : (!quake.veq, !quake.veq) -> +// !quake.struq, !quake.veq> +// %8 = quake.get_member %7[1] : (!quake.struq, +// !quake.veq>) -> !quake.veq +// ─────────────────────────────────────────────────────────── +// replace uses of %8 with %6 +struct BypassMakeStruq : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::GetMemberOp getMem, + PatternRewriter &rewriter) const override { + auto makeStruq = getMem.getStruq().getDefiningOp(); + if (!makeStruq) + return failure(); + auto toStrTy = cast(getMem.getStruq().getType()); + std::uint32_t idx = getMem.getIndex(); + Value from = makeStruq.getOperand(idx); + auto toTy = toStrTy.getMembers()[idx]; + if (from.getType() != toTy) + rewriter.replaceOpWithNewOp(getMem, toTy, from); + else + rewriter.replaceOp(getMem, from); + return success(); + } +}; + +// %22 = quake.init_state %1, %2 : (!quake.veq, T) -> !quake.veq +// ──────────────────────────────────────────────────────────────────── +// %.22 = quake.init_state %1, %2 : (!quake.veq, T) -> !quake.veq +// %22 = quake.relax_size %.22 : (!quake.veq) -> !quake.veq +struct ForwardAllocaTypePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::InitializeStateOp initState, + PatternRewriter &rewriter) const override { + if (auto isTy = dyn_cast(initState.getType())) + if (!isTy.hasSpecifiedSize()) { + auto targ = initState.getTargets(); + if (auto targTy = dyn_cast(targ.getType())) + if (targTy.hasSpecifiedSize()) { + auto newInit = rewriter.create( + initState.getLoc(), targTy, targ, initState.getState()); + rewriter.replaceOpWithNewOp(initState, isTy, + newInit); + return success(); + } + } + + // Remove any intervening cast to !cc.ptr> ops. + if (auto stateCast = + initState.getState().getDefiningOp()) + if (auto ptrTy = dyn_cast(stateCast.getType())) { + auto eleTy = ptrTy.getElementType(); + if (auto arrTy = dyn_cast(eleTy)) + if (arrTy.isUnknownSize()) { + rewriter.replaceOpWithNewOp( + initState, initState.getTargets().getType(), + initState.getTargets(), stateCast.getValue()); + return success(); + } + } + return failure(); + } +}; + +// %3 = quake.subveq %0, 4, 10 : (!quake.veq<12>, i64, i64) -> !quake.veq +// ────────────────────────────────────────────────────────────────────────── +// %.3 = quake.subveq %0, 4, 10 : (!quake.veq<12>, i64, i64) -> !quake.veq<7> +// %3 = quake.relax_size %.3 : (!quake.veq<7>) -> !quake.veq +struct FixUnspecifiedSubveqPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::SubVeqOp subveq, + PatternRewriter &rewriter) const override { + auto veqTy = dyn_cast(subveq.getType()); + if (veqTy && veqTy.hasSpecifiedSize()) + return failure(); + if (!(subveq.hasConstantLowerBound() && subveq.hasConstantUpperBound())) + return failure(); + auto *ctx = rewriter.getContext(); + std::size_t size = + subveq.getConstantUpperBound() - subveq.getConstantLowerBound() + 1u; + auto szVecTy = quake::VeqType::get(ctx, size); + auto loc = subveq.getLoc(); + auto subv = rewriter.create( + loc, szVecTy, subveq.getVeq(), subveq.getLower(), subveq.getUpper(), + subveq.getRawLower(), subveq.getRawUpper()); + rewriter.replaceOpWithNewOp(subveq, veqTy, subv); + return success(); + } +}; + +// %1 = constant 4 : i64 +// %2 = constant 10 : i64 +// %3 = quake.subveq %0, %1, %2 : (!quake.veq<12>, i64, i64) -> !quake.veq +// ────────────────────────────────────────────────────────────────────────── +// %3 = quake.subveq %0, 4, 10 : (!quake.veq<12>, i64, i64) -> !quake.veq<7> +struct FuseConstantToSubveqPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::SubVeqOp subveq, + PatternRewriter &rewriter) const override { + if (subveq.hasConstantLowerBound() && subveq.hasConstantUpperBound()) + return failure(); + bool regen = false; + std::int64_t lo = subveq.getConstantLowerBound(); + Value loVal = subveq.getLower(); + if (!subveq.hasConstantLowerBound()) + if (auto olo = cudaq::opt::factory::getIntIfConstant(subveq.getLower())) { + regen = true; + loVal = nullptr; + lo = *olo; + } + + std::int64_t hi = subveq.getConstantUpperBound(); + Value hiVal = subveq.getUpper(); + if (!subveq.hasConstantUpperBound()) + if (auto ohi = cudaq::opt::factory::getIntIfConstant(subveq.getUpper())) { + regen = true; + hiVal = nullptr; + hi = *ohi; + } + + if (!regen) + return failure(); + rewriter.replaceOpWithNewOp( + subveq, subveq.getType(), subveq.getVeq(), loVal, hiVal, lo, hi); + return success(); + } +}; + +// Replace subveq operations that extract the entire original register with the +// original register. +struct RemoveSubVeqNoOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::SubVeqOp subVeqOp, + PatternRewriter &rewriter) const override { + auto origVeq = subVeqOp.getVeq(); + // The original veq size must be known + auto veqType = dyn_cast(origVeq.getType()); + if (!veqType.hasSpecifiedSize()) + return failure(); + if (!(subVeqOp.hasConstantLowerBound() && subVeqOp.hasConstantUpperBound())) + return failure(); + + // If the subveq is the whole register, than the start value must be 0. + if (subVeqOp.getConstantLowerBound() != 0) + return failure(); + + // If the sizes are equal, then replace + if (veqType.getSize() != subVeqOp.getConstantUpperBound() + 1) + return failure(); + + // this subveq is the whole original register, hence a no-op + rewriter.replaceOp(subVeqOp, origVeq); + return success(); + } +}; + +// %11 = quake.init_state %_, %_ : (!quake.veq<2>, T1) -> !quake.veq +// %12 = quake.veq_size %11 : (!quake.veq) -> i64 +// ──────────────────────────────────────────────────────────────────── +// %11 = quake.init_state %_, %_ : (!quake.veq<2>, T1) -> !quake.veq +// %12 = constant 2 : i64 +struct FoldInitStateSizePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::VeqSizeOp veqSize, + PatternRewriter &rewriter) const override { + Value veq = veqSize.getVeq(); + if (auto initState = veq.getDefiningOp()) + if (auto veqTy = + dyn_cast(initState.getTargets().getType())) + if (veqTy.hasSpecifiedSize()) { + std::size_t numQubits = veqTy.getSize(); + rewriter.replaceOpWithNewOp(veqSize, numQubits, + veqSize.getType()); + return success(); + } + return failure(); + } +}; + +// If there is no operation that modifies the wire after it gets unwrapped and +// before it is wrapped, then the wrap operation is a nop and can be +// eliminated. +struct KillDeadWrapPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::WrapOp wrap, + PatternRewriter &rewriter) const override { + if (auto unwrap = wrap.getWireValue().getDefiningOp()) + rewriter.eraseOp(wrap); + return success(); + } +}; + +template +struct MergeRotationPattern : public OpRewritePattern { + using Base = OpRewritePattern; + using Base::Base; + + LogicalResult matchAndRewrite(OP rotate, + PatternRewriter &rewriter) const override { + auto wireTy = quake::WireType::get(rewriter.getContext()); + if (rotate.getTarget(0).getType() != wireTy || + !rotate.getControls().empty()) + return failure(); + assert(!rotate.getNegatedQubitControls()); + auto input = rotate.getTarget(0).template getDefiningOp(); + if (!input || !input.getControls().empty()) + return failure(); + assert(!input.getNegatedQubitControls()); + + // At this point, we have + // %input = quake.rotate %angle1, %wire + // %rotate = quake.rotate %angle2, %input + // Replace those ops with + // %new = quake.rotate (%angle1 + %angle2), %wire + auto loc = rotate.getLoc(); + auto angle1 = input.getParameter(0); + auto angle2 = rotate.getParameter(0); + if (angle1.getType() != angle2.getType()) + return failure(); + auto adjAttr = rotate.getIsAdjAttr(); + auto newAngle = [&]() -> Value { + if (input.isAdj() == rotate.isAdj()) + return rewriter.create(loc, angle1, angle2); + // One is adjoint, so it should be subtracted from the other. + if (input.isAdj()) + return rewriter.create(loc, angle2, angle1); + adjAttr = input.getIsAdjAttr(); + return rewriter.create(loc, angle1, angle2); + }(); + rewriter.replaceOpWithNewOp(rotate, rotate.getResultTypes(), adjAttr, + ValueRange{newAngle}, ValueRange{}, + ValueRange{input.getTarget(0)}, + rotate.getNegatedQubitControlsAttr()); + return success(); + } +}; + +// Forward the argument to a relax_size to the users for all users that are +// quake operations. All quake ops that take a sized veq argument are +// polymorphic on all veq types. If the op is not a quake op, then maintain +// strong typing. +struct ForwardRelaxedSizePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::RelaxSizeOp relax, + PatternRewriter &rewriter) const override { + auto inpVec = relax.getInputVec(); + Value result = relax.getResult(); + bool replaced = true; + result.replaceUsesWithIf(inpVec, [&](OpOperand &use) { + if (Operation *user = use.getOwner()) + return isQuakeOperation(user) && !isa(user); + replaced = false; + return false; + }); + return replaced ? success() : failure(); + }; +}; + } // namespace diff --git a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp index 852e44814b..d2da75d99b 100644 --- a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp +++ b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp @@ -309,79 +309,6 @@ LogicalResult quake::BorrowWireOp::verify() { // Concat //===----------------------------------------------------------------------===// -namespace { -// %7 = quake.concat %4 : (!quake.veq<2>) -> !quake.veq<2> -// ─────────────────────────────────────────── -// removed -struct ConcatNoOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::ConcatOp concat, - PatternRewriter &rewriter) const override { - // Remove concat veq -> veq - // or - // concat ref -> ref - auto qubitsToConcat = concat.getQbits(); - if (qubitsToConcat.size() > 1) - return failure(); - - // We only want to handle veq -> veq here. - if (isa(qubitsToConcat.front().getType())) { - return failure(); - } - - // Do not handle anything where we don't know the sizes. - auto retTy = concat.getResult().getType(); - if (auto veqTy = dyn_cast(retTy)) - if (!veqTy.hasSpecifiedSize()) - // This could be a folded quake.relax_size op. - return failure(); - - rewriter.replaceOp(concat, qubitsToConcat); - return success(); - } -}; - -// %8 = quake.concat %4, %5, %6 : (!quake.ref, !quake.veq<4>, -// !quake.veq<2>) -> !quake.veq -// ─────────────────────────────────────────────────────────── -// %.8 = quake.concat %4, %5, %6 : (!quake.ref, !quake.veq<4>, -// !quake.veq<2>) -> !quake.veq<7> -// %8 = quake.relax_size %.8 : (!quake.veq<7>) -> !quake.veq -struct ConcatSizePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::ConcatOp concat, - PatternRewriter &rewriter) const override { - if (concat.getType().hasSpecifiedSize()) - return failure(); - - // Walk the arguments and sum them, if possible. - std::size_t sum = 0; - for (auto opnd : concat.getQbits()) { - if (auto veqTy = dyn_cast(opnd.getType())) { - if (!veqTy.hasSpecifiedSize()) - return failure(); - sum += veqTy.getSize(); - continue; - } - assert(isa(opnd.getType())); - sum++; - } - - // Leans into the relax_size canonicalization pattern. - auto *ctx = rewriter.getContext(); - auto loc = concat.getLoc(); - auto newTy = quake::VeqType::get(ctx, sum); - Value newOp = - rewriter.create(loc, newTy, concat.getQbits()); - auto noSizeTy = quake::VeqType::getUnsized(ctx); - rewriter.replaceOpWithNewOp(concat, noSizeTy, newOp); - return success(); - }; -}; -} // namespace - void quake::ConcatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); @@ -422,69 +349,6 @@ void printRawIndex(OpAsmPrinter &printer, OP refOp, Value index, printer << rawIndex.getValue(); } -namespace { -// %4 = quake.concat %2, %3 : (!quake.ref, !quake.ref) -> !quake.veq<2> -// %7 = quake.extract_ref %4[0] : (!quake.veq<2>) -> !quake.ref -// ─────────────────────────────────────────── -// replace all use with %2 -struct ForwardConcatExtractPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::ExtractRefOp extract, - PatternRewriter &rewriter) const override { - auto veq = extract.getVeq(); - auto concatOp = veq.getDefiningOp(); - if (concatOp && extract.hasConstantIndex()) { - // Don't run this canonicalization if any of the operands - // to concat are of type veq. - auto concatQubits = concatOp.getQbits(); - for (auto qOp : concatQubits) - if (isa(qOp.getType())) - return failure(); - - // concat only has ref type operands. - auto index = extract.getConstantIndex(); - if (index < concatQubits.size()) { - auto qOpValue = concatQubits[index]; - if (isa(qOpValue.getType())) { - rewriter.replaceOp(extract, {qOpValue}); - return success(); - } - } - } - return failure(); - } -}; - -// %2 = quake.concat %1 : (!quake.ref) -> !quake.veq<1> -// %3 = quake.extract_ref %2[0] : (!quake.veq<1>) -> !quake.ref -// quake.* %3 ... -// ─────────────────────────────────────────── -// quake.* %1 ... -struct ForwardConcatExtractSingleton - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::ExtractRefOp extract, - PatternRewriter &rewriter) const override { - if (auto concat = extract.getVeq().getDefiningOp()) - if (concat.getType().getSize() == 1 && extract.hasConstantIndex() && - extract.getConstantIndex() == 0) { - assert(concat.getQbits().size() == 1 && concat.getQbits()[0]); - extract.getResult().replaceUsesWithIf( - concat.getQbits()[0], [&](OpOperand &use) { - if (Operation *user = use.getOwner()) - return isQuakeOperation(user); - return false; - }); - return success(); - } - return failure(); - } -}; -} // namespace - void quake::ExtractRefOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::GetMemberOp getMem, - PatternRewriter &rewriter) const override { - if (auto makeStruq = - getMem.getStruq().getDefiningOp()) { - auto toStrTy = cast(getMem.getStruq().getType()); - std::uint32_t idx = getMem.getIndex(); - Value from = makeStruq.getOperand(idx); - auto toTy = toStrTy.getMembers()[idx]; - if (from.getType() != toTy) { - rewriter.replaceOpWithNewOp(getMem, toTy, from); - } else { - rewriter.replaceOp(getMem, from); - } - return success(); - } - return failure(); - } -}; -} // namespace - void quake::GetMemberOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); @@ -579,48 +419,6 @@ LogicalResult quake::InitializeStateOp::verify() { return success(); } -namespace { -// %22 = quake.init_state %1, %2 : (!quake.veq, T) -> !quake.veq -// ──────────────────────────────────────────────────────────────────── -// %22' = quake.init_state %1, %2 : (!quake.veq, T) -> !quake.veq -// %22 = quake.relax_size %22' : (!quake.veq) -> !quake.veq -struct ForwardAllocaTypePattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::InitializeStateOp initState, - PatternRewriter &rewriter) const override { - if (auto isTy = dyn_cast(initState.getType())) - if (!isTy.hasSpecifiedSize()) { - auto targ = initState.getTargets(); - if (auto targTy = dyn_cast(targ.getType())) - if (targTy.hasSpecifiedSize()) { - auto newInit = rewriter.create( - initState.getLoc(), targTy, targ, initState.getState()); - rewriter.replaceOpWithNewOp(initState, isTy, - newInit); - return success(); - } - } - - // Remove any intervening cast to !cc.ptr> ops. - if (auto stateCast = - initState.getState().getDefiningOp()) - if (auto ptrTy = dyn_cast(stateCast.getType())) { - auto eleTy = ptrTy.getElementType(); - if (auto arrTy = dyn_cast(eleTy)) - if (arrTy.isUnknownSize()) { - rewriter.replaceOpWithNewOp( - initState, initState.getTargets().getType(), - initState.getTargets(), stateCast.getValue()); - return success(); - } - } - return failure(); - } -}; -} // namespace - void quake::InitializeStateOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); @@ -656,30 +454,6 @@ LogicalResult quake::RelaxSizeOp::verify() { return success(); } -namespace { -// Forward the argument to a relax_size to the users for all users that are -// quake operations. All quake ops that take a sized veq argument are -// polymorphic on all veq types. If the op is not a quake op, then maintain -// strong typing. -struct ForwardRelaxedSizePattern : public RewritePattern { - ForwardRelaxedSizePattern(MLIRContext *context) - : RewritePattern("quake.relax_size", 1, context, {}) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto relax = cast(op); - auto inpVec = relax.getInputVec(); - Value result = relax.getResult(); - result.replaceUsesWithIf(inpVec, [&](OpOperand &use) { - if (Operation *user = use.getOwner()) - return isQuakeOperation(user) && !isa(user); - return false; - }); - return success(); - }; -}; -} // namespace - void quake::RelaxSizeOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); @@ -713,103 +487,6 @@ LogicalResult quake::SubVeqOp::verify() { return success(); } -namespace { -// %3 = quake.subveq %0, 4, 10 : (!quake.veq<12>, i64, i64) -> !quake.veq -// ───────────────────────────────────────────────────────────────────────────── -// %new3 = quake.subveq %0, 4, 10 : (!quake.veq<12>, i64, i64) -> !quake.veq<7> -// %3 = quake.relax_size %new3 : (!quake.veq<7>) -> !quake.veq -struct FixUnspecifiedSubveqPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::SubVeqOp subveq, - PatternRewriter &rewriter) const override { - auto veqTy = dyn_cast(subveq.getType()); - if (veqTy && veqTy.hasSpecifiedSize()) - return failure(); - if (!(subveq.hasConstantLowerBound() && subveq.hasConstantUpperBound())) - return failure(); - auto *ctx = rewriter.getContext(); - std::size_t size = - subveq.getConstantUpperBound() - subveq.getConstantLowerBound() + 1u; - auto szVecTy = quake::VeqType::get(ctx, size); - auto loc = subveq.getLoc(); - auto subv = rewriter.create( - loc, szVecTy, subveq.getVeq(), subveq.getLower(), subveq.getUpper(), - subveq.getRawLower(), subveq.getRawUpper()); - rewriter.replaceOpWithNewOp(subveq, veqTy, subv); - return success(); - } -}; - -// %1 = constant 4 : i64 -// %2 = constant 10 : i64 -// %3 = quake.subveq %0, %1, %2 : (!quake.veq<12>, i64, i64) -> !quake.veq -// ───────────────────────────────────────────────────────────────────────────── -// %3 = quake.subveq %0, 4, 10 : (!quake.veq<12>, i64, i64) -> !quake.veq<7> -struct FuseConstantToSubveqPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::SubVeqOp subveq, - PatternRewriter &rewriter) const override { - if (subveq.hasConstantLowerBound() && subveq.hasConstantUpperBound()) - return failure(); - bool regen = false; - std::int64_t lo = subveq.getConstantLowerBound(); - Value loVal = subveq.getLower(); - if (!subveq.hasConstantLowerBound()) - if (auto olo = cudaq::opt::factory::getIntIfConstant(subveq.getLower())) { - regen = true; - loVal = nullptr; - lo = *olo; - } - - std::int64_t hi = subveq.getConstantUpperBound(); - Value hiVal = subveq.getUpper(); - if (!subveq.hasConstantUpperBound()) - if (auto ohi = cudaq::opt::factory::getIntIfConstant(subveq.getUpper())) { - regen = true; - hiVal = nullptr; - hi = *ohi; - } - - if (!regen) - return failure(); - rewriter.replaceOpWithNewOp( - subveq, subveq.getType(), subveq.getVeq(), loVal, hiVal, lo, hi); - return success(); - } -}; - -// Replace subveq operations that extract the entire original register with the -// original register. -struct RemoveSubVeqNoOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::SubVeqOp subVeqOp, - PatternRewriter &rewriter) const override { - auto origVeq = subVeqOp.getVeq(); - // The original veq size must be known - auto veqType = dyn_cast(origVeq.getType()); - if (!veqType.hasSpecifiedSize()) - return failure(); - if (!(subVeqOp.hasConstantLowerBound() && subVeqOp.hasConstantUpperBound())) - return failure(); - - // If the subveq is the whole register, than the start value must be 0. - if (subVeqOp.getConstantLowerBound() != 0) - return failure(); - - // If the sizes are equal, then replace - if (veqType.getSize() != subVeqOp.getConstantUpperBound() + 1) - return failure(); - - // this subveq is the whole original register, hence a no-op - rewriter.replaceOp(subVeqOp, origVeq); - return success(); - } -}; -} // namespace - void quake::SubVeqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add { - using OpRewritePattern::OpRewritePattern; - - // %11 = quake.init_state %_, %_ : (!quake.veq<2>, T1) -> !quake.veq - // %12 = quake.veq_size %11 : (!quake.veq) -> i64 - // ──────────────────────────────────────────────────────────────────── - // %11 = quake.init_state %_, %_ : (!quake.veq<2>, T1) -> !quake.veq - // %12 = constant 2 : i64 - LogicalResult matchAndRewrite(quake::VeqSizeOp veqSize, - PatternRewriter &rewriter) const override { - Value veq = veqSize.getVeq(); - if (auto initState = veq.getDefiningOp()) - if (auto veqTy = - dyn_cast(initState.getTargets().getType())) - if (veqTy.hasSpecifiedSize()) { - std::size_t numQubits = veqTy.getSize(); - rewriter.replaceOpWithNewOp(veqSize, numQubits, - veqSize.getType()); - return success(); - } - return failure(); - } -}; -} // namespace - void quake::VeqSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add( @@ -856,22 +507,6 @@ void quake::VeqSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // WrapOp //===----------------------------------------------------------------------===// -namespace { -// If there is no operation that modifies the wire after it gets unwrapped and -// before it is wrapped, then the wrap operation is a nop and can be -// eliminated. -struct KillDeadWrapPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(quake::WrapOp wrap, - PatternRewriter &rewriter) const override { - if (auto unwrap = wrap.getWireValue().getDefiningOp()) - rewriter.eraseOp(wrap); - return success(); - } -}; -} // namespace - void quake::WrapOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); @@ -1044,53 +679,6 @@ void quake::RxOp::getOperatorMatrix(Matrix &matrix) { -1i * std::sin(theta / 2.), std::cos(theta / 2.)}); } -namespace { -template -struct MergeRotationPattern : public OpRewritePattern { - using Base = OpRewritePattern; - using Base::Base; - - LogicalResult matchAndRewrite(OP rotate, - PatternRewriter &rewriter) const override { - auto wireTy = quake::WireType::get(rewriter.getContext()); - if (rotate.getTarget(0).getType() != wireTy || - !rotate.getControls().empty()) - return failure(); - assert(!rotate.getNegatedQubitControls()); - auto input = rotate.getTarget(0).template getDefiningOp(); - if (!input || !input.getControls().empty()) - return failure(); - assert(!input.getNegatedQubitControls()); - - // At this point, we have - // %input = quake.rotate %angle1, %wire - // %rotate = quake.rotate %angle2, %input - // Replace those ops with - // %new = quake.rotate (%angle1 + %angle2), %wire - auto loc = rotate.getLoc(); - auto angle1 = input.getParameter(0); - auto angle2 = rotate.getParameter(0); - if (angle1.getType() != angle2.getType()) - return failure(); - auto adjAttr = rotate.getIsAdjAttr(); - auto newAngle = [&]() -> Value { - if (input.isAdj() == rotate.isAdj()) - return rewriter.create(loc, angle1, angle2); - // One is adjoint, so it should be subtracted from the other. - if (input.isAdj()) - return rewriter.create(loc, angle2, angle1); - adjAttr = input.getIsAdjAttr(); - return rewriter.create(loc, angle1, angle2); - }(); - rewriter.replaceOpWithNewOp(rotate, rotate.getResultTypes(), adjAttr, - ValueRange{newAngle}, ValueRange{}, - ValueRange{input.getTarget(0)}, - rotate.getNegatedQubitControlsAttr()); - return success(); - } -}; -} // namespace - void quake::RxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add>(context);