From 8e251c6813a847fe5b7971ef291f6f1794057c1d Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Wed, 11 Dec 2024 13:04:27 -0800 Subject: [PATCH] [core] Move canonical patterns out of tablegen. Fix #476. Eliminate the file Canonical.td and move/rewrite the patterns from that file. Signed-off-by: Eric Schweitz --- .../Optimizer/Dialect/Quake/CMakeLists.txt | 4 - .../Optimizer/Dialect/Quake/Canonical.td | 67 -------------- lib/Optimizer/Dialect/Quake/CMakeLists.txt | 1 - .../Dialect/Quake/CanonicalPatterns.inc | 90 +++++++++++++++++++ lib/Optimizer/Dialect/Quake/QuakeOps.cpp | 10 ++- 5 files changed, 97 insertions(+), 75 deletions(-) delete mode 100644 include/cudaq/Optimizer/Dialect/Quake/Canonical.td create mode 100644 lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc diff --git a/include/cudaq/Optimizer/Dialect/Quake/CMakeLists.txt b/include/cudaq/Optimizer/Dialect/Quake/CMakeLists.txt index d038abd040..6dca96dc83 100644 --- a/include/cudaq/Optimizer/Dialect/Quake/CMakeLists.txt +++ b/include/cudaq/Optimizer/Dialect/Quake/CMakeLists.txt @@ -9,7 +9,3 @@ add_cudaq_dialect(Quake quake) add_cudaq_interface(QuakeInterfaces) add_cudaq_dialect_doc(QuakeDialect quake) - -set(LLVM_TARGET_DEFINITIONS Canonical.td) -mlir_tablegen(Canonical.inc -gen-rewriters) -add_public_tablegen_target(CanonicalIncGen) diff --git a/include/cudaq/Optimizer/Dialect/Quake/Canonical.td b/include/cudaq/Optimizer/Dialect/Quake/Canonical.td deleted file mode 100644 index d7aec89e6f..0000000000 --- a/include/cudaq/Optimizer/Dialect/Quake/Canonical.td +++ /dev/null @@ -1,67 +0,0 @@ -/********************************************************** -*- tablegen -*- *** - * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * - * All rights reserved. * - * * - * This source code and the accompanying materials are made available under * - * the terms of the Apache License 2.0 which accompanies this distribution. * - ******************************************************************************/ - -#ifndef NVQPP_OPTIMIZER_DIALECT_QUAKE_CANONICAL -#define NVQPP_OPTIMIZER_DIALECT_QUAKE_CANONICAL - -include "mlir/IR/OpBase.td" -include "mlir/IR/PatternBase.td" -include "mlir/Dialect/Arith/IR/ArithOps.td" -include "cudaq/Optimizer/Dialect/Quake/QuakeOps.td" - -def KnownSizePred : Constraint< - CPred<"$0.getType().isa() && " - "$0.getType().cast().hasSpecifiedSize()">>; - -def UnknownSizePred : Constraint< - CPred<"$0.getType().isa() && " - "!$0.getType().cast().hasSpecifiedSize()">>; - -def createConstantOp : NativeCodeCall< - "$_builder.create($_loc, $0.getType()," - " $_builder.getIntegerAttr($0.getType()," - " $1.getType().cast().getSize()))">; - -// %4 = quake.veq_size %3 : (!quake.veq<10>) -> 164 -// ──────────────────────────────────────────────── -// %4 = constant 10 : i64 -def ForwardConstantVeqSizePattern : Pat< - (quake_VeqSizeOp:$res $veq), (createConstantOp $res, $veq), - [(KnownSizePred $veq)]>; - -def SizeIsPresentPred : Constraint(" - " $0[0].getDefiningOp())">>; - -def createAllocaOp : NativeCodeCall< - "quake::createConstantAlloca($_builder, $_loc, $0, $1)">; - -// %2 = constant 10 : i32 -// %3 = quake.alloca !quake.veq[%2 : i32] -// ─────────────────────────────────────────── -// %3 = quake.alloca !quake.veq<10> -def FuseConstantToAllocaPattern : Pat< - (quake_AllocaOp:$alloca $optSize), (createAllocaOp $alloca, $optSize), - [(SizeIsPresentPred $optSize)]>; - -def createExtractRefOp : NativeCodeCall< - "$_builder.create($_loc, $0," - " cast($1[0].getDefiningOp()).getValue()." - " cast().getInt())">; - -// %2 = constant 10 : i32 -// %3 = quake.extract_ref %1[%2] : (!quake.veq, i32) -> !quake.ref -// ─────────────────────────────────────────── -// %3 = quake.extract_ref %1[10] : (!quake.veq) -> !quake.ref -def FuseConstantToExtractRefPattern : Pat< - (quake_ExtractRefOp $veq, $index, $rawIndex), - (createExtractRefOp $veq, $index), - [(SizeIsPresentPred $index)]>; - -#endif diff --git a/lib/Optimizer/Dialect/Quake/CMakeLists.txt b/lib/Optimizer/Dialect/Quake/CMakeLists.txt index 87733716c4..bc55e40d52 100644 --- a/lib/Optimizer/Dialect/Quake/CMakeLists.txt +++ b/lib/Optimizer/Dialect/Quake/CMakeLists.txt @@ -16,7 +16,6 @@ add_cudaq_dialect_library(QuakeDialect QuakeDialectIncGen QuakeOpsIncGen QuakeTypesIncGen - CanonicalIncGen LINK_LIBS CCDialect diff --git a/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc b/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc new file mode 100644 index 0000000000..b7ad227685 --- /dev/null +++ b/lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc @@ -0,0 +1,90 @@ +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// These canonicalization patterns are used by the canonicalize pass and not +// shared for other uses. Generally speaking, these patterns should be trivial +// peephole optimizations that reduce the size and complexity of the input IR. + +// This file must be included after a `using namespace mlir;` as it uses bare +// identifiers from that namespace. + +namespace { + +// %4 = quake.veq_size %3 : (!quake.veq<10>) -> 164 +// ──────────────────────────────────────────────── +// %4 = constant 10 : i64 +struct ForwardConstantVeqSizePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::VeqSizeOp veqSize, + PatternRewriter &rewriter) const override { + auto veqTy = dyn_cast(veqSize.getVeq().getType()); + if (!veqTy) + return failure(); + if (!veqTy.hasSpecifiedSize()) + return failure(); + auto resTy = veqSize.getType(); + rewriter.replaceOpWithNewOp(veqSize, veqTy.getSize(), + resTy); + return success(); + } +}; + +// %2 = constant 10 : i32 +// %3 = quake.alloca !quake.veq[%2 : i32] +// ─────────────────────────────────────────── +// %3 = quake.alloca !quake.veq<10> +struct FuseConstantToAllocaPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::AllocaOp alloc, + PatternRewriter &rewriter) const override { + auto size = alloc.getSize(); + if (!size) + return failure(); + auto intCon = cudaq::opt::factory::getIntIfConstant(size); + if (!intCon) + return failure(); + auto veqTy = dyn_cast(alloc.getType()); + if (!veqTy) + return failure(); + if (veqTy.hasSpecifiedSize()) + return failure(); + auto loc = alloc.getLoc(); + auto resTy = alloc.getType(); + auto newAlloc = rewriter.create( + loc, static_cast(*intCon)); + rewriter.replaceOpWithNewOp(alloc, resTy, newAlloc); + return success(); + } +}; + +// %2 = constant 10 : i32 +// %3 = quake.extract_ref %1[%2] : (!quake.veq, i32) -> !quake.ref +// ─────────────────────────────────────────── +// %3 = quake.extract_ref %1[10] : (!quake.veq) -> !quake.ref +struct FuseConstantToExtractRefPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::ExtractRefOp extract, + PatternRewriter &rewriter) const override { + auto index = extract.getIndex(); + if (!index) + return failure(); + auto intCon = cudaq::opt::factory::getIntIfConstant(index); + if (!intCon) + return failure(); + rewriter.replaceOpWithNewOp( + extract, extract.getVeq(), static_cast(*intCon)); + return success(); + } +}; + +} // namespace diff --git a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp index 1f2c3bd06d..852e44814b 100644 --- a/lib/Optimizer/Dialect/Quake/QuakeOps.cpp +++ b/lib/Optimizer/Dialect/Quake/QuakeOps.cpp @@ -23,9 +23,7 @@ using namespace mlir; -namespace { -#include "cudaq/Optimizer/Dialect/Quake/Canonical.inc" -} // namespace +#include "CanonicalPatterns.inc" static LogicalResult verifyWireResultsAreLinear(Operation *op) { for (Value v : op->getOpResults()) @@ -344,6 +342,12 @@ struct ConcatNoOpPattern : public OpRewritePattern { } }; +// %8 = quake.concat %4, %5, %6 : (!quake.ref, !quake.veq<4>, +// !quake.veq<2>) -> !quake.veq +// ─────────────────────────────────────────────────────────── +// %.8 = quake.concat %4, %5, %6 : (!quake.ref, !quake.veq<4>, +// !quake.veq<2>) -> !quake.veq<7> +// %8 = quake.relax_size %.8 : (!quake.veq<7>) -> !quake.veq struct ConcatSizePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern;