diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index ca01caa..a372248 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -152,6 +152,7 @@ std::unique_ptr createModuleExample(LLVMContext &context) { struct VisitorInnermost { int counter = 0; + raw_ostream *out = nullptr; }; struct VisitorNest { @@ -177,6 +178,13 @@ struct llvm_dialects::VisitorPayloadProjection { static raw_ostream &project(VisitorNest &nest) { return *nest.out; } }; +template <> +struct llvm_dialects::VisitorPayloadProjection { + static raw_ostream &project(VisitorInnermost &innerMost) { + return *innerMost.out; + } +}; + LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorContainer, nest) LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorNest, inner) @@ -215,8 +223,8 @@ template const Visitor &getExampleVisitor() { b.addSet(complexSet, [](VisitorNest &self, llvm::Instruction &op) { assert((op.getOpcode() == Instruction::Ret || (isa(&op) && - cast(&op)->getIntrinsicID() == - Intrinsic::umin)) && + cast(&op)->getIntrinsicID() == + Intrinsic::umin)) && "Unexpected operation detected while visiting OpSet!"); if (op.getOpcode() == Instruction::Ret) { @@ -249,10 +257,36 @@ template const Visitor &getExampleVisitor() { Intrinsic::umax, [](raw_ostream &out, IntrinsicInst &umax) { out << "visiting umax intrinsic: " << umax << '\n'; }); + b.addPreVisitCallback( + [](raw_ostream &out, llvm::Instruction &inst) { + if (isa(inst)) + out << "Will visit ReadOp next: " << inst << '\n'; + else if (isa(inst)) + out << "Will visit WriteOp next: " << inst << '\n'; + else + llvm_unreachable("Unexpected op!"); + }); + + b.addPreVisitCallback([](raw_ostream &out, Instruction &inst) { + if (isa(inst)) + out << "Pre-visiting intrinsic instruction: " << inst << '\n'; + }); }); b.nest([](VisitorBuilder &b) { - b.add([](VisitorInnermost &inner, - xd::ITruncOp &op) { inner.counter++; }); + b.add( + [](VisitorInnermost &inner, xd::ITruncOp &op) { + inner.counter++; + *inner.out + << "Counter after visiting ITruncOp: " << inner.counter + << '\n'; + }); + + b.addPreVisitCallback( + [](VisitorInnermost &inner, Instruction &op) { + if (isa(op)) + *inner.out << "Counter before visiting ITruncOp: " + << inner.counter << '\n'; + }); }); }) .setStrategy(rpot ? VisitorStrategy::ReversePostOrder @@ -267,6 +301,7 @@ void exampleVisit(Module &module) { VisitorContainer container; container.nest.out = &outs(); + container.nest.inner.out = &outs(); visitor.visit(container, module); outs() << "inner.counter = " << container.nest.inner.counter << '\n'; diff --git a/include/llvm-dialects/Dialect/OpSet.h b/include/llvm-dialects/Dialect/OpSet.h index 6f39599..230779a 100644 --- a/include/llvm-dialects/Dialect/OpSet.h +++ b/include/llvm-dialects/Dialect/OpSet.h @@ -91,7 +91,7 @@ class OpSet final { // arguments. template static const OpSet get() { static OpSet set; - (... && appendT(set)); + (void)(... && appendT(set)); return set; } @@ -153,6 +153,11 @@ class OpSet final { return isMatchingDialectOp(func.getName()); } + bool empty() const { + return m_coreOpcodes.empty() && m_intrinsicIDs.empty() && + m_dialectOps.empty(); + } + // ------------------------------------------------------------- // Convenience getters to access the internal data structures. // ------------------------------------------------------------- diff --git a/include/llvm-dialects/Dialect/Visitor.h b/include/llvm-dialects/Dialect/Visitor.h index bce34ca..bb81f61 100644 --- a/include/llvm-dialects/Dialect/Visitor.h +++ b/include/llvm-dialects/Dialect/Visitor.h @@ -238,15 +238,27 @@ class VisitorTemplate { friend class VisitorBuilderBase; public: + enum class VisitorCallbackType : uint8_t { PreVisit = 0, Visit = 1 }; + void setStrategy(VisitorStrategy strategy); void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data, - VisitorHandler::Projection projection); + VisitorHandler::Projection projection, + VisitorCallbackType visitorCallbackTy = VisitorCallbackType::Visit); private: + void storeHandlersInOpMap(const VisitorKey &key, unsigned handlerIdx, + VisitorCallbackType callbackTy); + VisitorStrategy m_strategy = VisitorStrategy::Default; std::vector m_projections; std::vector m_handlers; - OpMap> m_opMap; + + struct Handlers { + llvm::SmallVector PreVisitHandlers; + llvm::SmallVector VisitHandlers; + }; + + OpMap m_opMap; }; /// @brief Base class for VisitorBuilders @@ -279,6 +291,9 @@ class VisitorBuilderBase { void setStrategy(VisitorStrategy strategy); + void addPreVisitCallback(VisitorKey key, VisitorCallback *fn, + VisitorCallbackData data); + void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data); VisitorBase build(); @@ -307,6 +322,11 @@ class VisitorBase { class BuildHelper; using HandlerRange = std::pair; + struct MappedHandlers { + HandlerRange PreVisitCallbacks; + HandlerRange VisitCallbacks; + }; + void call(HandlerRange handlers, void *payload, llvm::Instruction &inst) const; VisitorResult call(const VisitorHandler &handler, void *payload, @@ -319,7 +339,7 @@ class VisitorBase { VisitorStrategy m_strategy; std::vector m_projections; std::vector m_handlers; - OpMap m_opMap; + OpMap m_opMap; }; } // namespace detail @@ -386,6 +406,20 @@ class VisitorBuilder : private detail::VisitorBuilderBase { return *this; } + VisitorBuilder & + addPreVisitCallback(const OpSet &opSet, + VisitorResult (*fn)(PayloadT &, llvm::Instruction &I)) { + addPreVisitCase(detail::VisitorKey::opSet(opSet), fn); + return *this; + } + + template + VisitorBuilder &addPreVisitCallback(void (*fn)(PayloadT &, + llvm::Instruction &I)) { + addPreVisitCase(detail::VisitorKey::opSet(), fn); + return *this; + } + Visitor build() { return VisitorBuilderBase::build(); } template @@ -510,6 +544,16 @@ class VisitorBuilder : private detail::VisitorBuilderBase { VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data); } + template + void addPreVisitCase(detail::VisitorKey key, + ReturnT (*fn)(PayloadT &, llvm::Instruction &)) { + detail::VisitorCallbackData data{}; + static_assert(sizeof(fn) <= sizeof(data.data)); + memcpy(&data.data, &fn, sizeof(fn)); + VisitorBuilderBase::addPreVisitCallback( + key, &VisitorBuilder::setForwarder, data); + } + template void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) { detail::VisitorCallbackData data{}; diff --git a/lib/Dialect/Visitor.cpp b/lib/Dialect/Visitor.cpp index 69e9e4f..d3f3983 100644 --- a/lib/Dialect/Visitor.cpp +++ b/lib/Dialect/Visitor.cpp @@ -19,7 +19,6 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" @@ -44,50 +43,76 @@ void VisitorTemplate::setStrategy(VisitorStrategy strategy) { m_strategy = strategy; } -void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn, - VisitorCallbackData data, - VisitorHandler::Projection projection) { - VisitorHandler handler; - handler.callback = fn; - handler.data = data; - handler.projection = projection; - - m_handlers.emplace_back(handler); +void VisitorTemplate::storeHandlersInOpMap( + const VisitorKey &key, unsigned handlerIdx, + VisitorCallbackType visitorCallbackTy) { + const auto HandlerList = + [&](const OpDescription &opDescription) -> llvm::SmallVector & { + if (visitorCallbackTy == VisitorCallbackType::PreVisit) + return m_opMap[opDescription].PreVisitHandlers; - const unsigned handlerIdx = m_handlers.size() - 1; + return m_opMap[opDescription].VisitHandlers; + }; if (key.m_kind == VisitorKey::Kind::Intrinsic) { - m_opMap[OpDescription::fromIntrinsic(key.m_intrinsicId)].push_back( - handlerIdx); + HandlerList(OpDescription::fromIntrinsic(key.m_intrinsicId)) + .push_back(handlerIdx); } else if (key.m_kind == VisitorKey::Kind::OpDescription) { const OpDescription *opDesc = key.m_description; if (opDesc->isCoreOp()) { for (const unsigned op : opDesc->getOpcodes()) - m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx); + HandlerList(OpDescription::fromCoreOp(op)).push_back(handlerIdx); } else if (opDesc->isIntrinsic()) { for (const unsigned op : opDesc->getOpcodes()) - m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx); + HandlerList(OpDescription::fromIntrinsic(op)).push_back(handlerIdx); } else { - m_opMap[*opDesc].push_back(handlerIdx); + HandlerList(*opDesc).push_back(handlerIdx); } } else if (key.m_kind == VisitorKey::Kind::OpSet) { const OpSet *opSet = key.m_set; + if (visitorCallbackTy == VisitorCallbackType::PreVisit && opSet->empty()) { + // This adds a handler for every stored op. + // Note: should be used with caution. + for (auto it : m_opMap) + it.second.PreVisitHandlers.push_back(handlerIdx); + + return; + } + for (unsigned opcode : opSet->getCoreOpcodes()) - m_opMap[OpDescription::fromCoreOp(opcode)].push_back(handlerIdx); + HandlerList(OpDescription::fromCoreOp(opcode)).push_back(handlerIdx); for (unsigned intrinsicID : opSet->getIntrinsicIDs()) - m_opMap[OpDescription::fromIntrinsic(intrinsicID)].push_back(handlerIdx); + HandlerList(OpDescription::fromIntrinsic(intrinsicID)) + .push_back(handlerIdx); - for (const auto &dialectOpPair : opSet->getDialectOps()) { - m_opMap[OpDescription::fromDialectOp(dialectOpPair.isOverload, - dialectOpPair.mnemonic)] + for (const auto &dialectOpPair : opSet->getDialectOps()) + HandlerList(OpDescription::fromDialectOp(dialectOpPair.isOverload, + dialectOpPair.mnemonic)) .push_back(handlerIdx); - } } } +void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn, + VisitorCallbackData data, + VisitorHandler::Projection projection, + VisitorCallbackType visitorCallbackTy) { + assert(visitorCallbackTy != VisitorCallbackType::PreVisit || key.m_set); + + VisitorHandler handler; + handler.callback = fn; + handler.data = data; + handler.projection = projection; + + m_handlers.emplace_back(handler); + + const unsigned handlerIdx = m_handlers.size() - 1; + + storeHandlersInOpMap(key, handlerIdx, visitorCallbackTy); +} + VisitorBuilderBase::VisitorBuilderBase() : m_template(&m_ownedTemplate) {} VisitorBuilderBase::VisitorBuilderBase(VisitorBuilderBase *parent, @@ -144,6 +169,13 @@ void VisitorBuilderBase::setStrategy(VisitorStrategy strategy) { m_template->setStrategy(strategy); } +void VisitorBuilderBase::addPreVisitCallback(VisitorKey key, + VisitorCallback *fn, + VisitorCallbackData data) { + m_template->add(key, fn, data, m_projection, + VisitorTemplate::VisitorCallbackType::PreVisit); +} + void VisitorBuilderBase::add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data) { m_template->add(key, fn, data, m_projection); @@ -192,9 +224,12 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ) BuildHelper helper(*this, templ.m_handlers); m_opMap.reserve(templ.m_opMap); - - for (auto it : templ.m_opMap) - m_opMap[it.first] = helper.mapHandlers(it.second); + for (auto it : templ.m_opMap) { + m_opMap[it.first].PreVisitCallbacks = + helper.mapHandlers(it.second.PreVisitHandlers); + m_opMap[it.first].VisitCallbacks = + helper.mapHandlers(it.second.VisitHandlers); + } } void VisitorBase::call(HandlerRange handlers, void *payload, @@ -223,11 +258,14 @@ VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload, } void VisitorBase::visit(void *payload, Instruction &inst) const { - auto handlers = m_opMap.find(inst); - if (!handlers) + auto mappedHandlers = m_opMap.find(inst); + if (!mappedHandlers) return; - call(*handlers.val(), payload, inst); + auto &callbacks = *mappedHandlers.val(); + + call(callbacks.PreVisitCallbacks, payload, inst); + call(callbacks.VisitCallbacks, payload, inst); } template @@ -241,19 +279,23 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module, LLVM_DEBUG(dbgs() << "visit " << decl.getName() << '\n'); - auto handlers = m_opMap.find(decl); - if (!handlers) { + auto mappedHandlers = m_opMap.find(decl); + if (!mappedHandlers) { // Neither a matched intrinsic nor a matched dialect op; skip. continue; } + auto &callbacks = *mappedHandlers.val(); + for (Use &use : make_early_inc_range(decl.uses())) { if (auto *inst = dyn_cast(use.getUser())) { if (!filter(*inst)) continue; if (auto *callInst = dyn_cast(inst)) { - if (&use == &callInst->getCalledOperandUse()) - call(*handlers.val(), payload, *callInst); + if (&use == &callInst->getCalledOperandUse()) { + call(callbacks.PreVisitCallbacks, payload, *callInst); + call(callbacks.VisitCallbacks, payload, *callInst); + } } } } diff --git a/test/example/visitor-basic.ll b/test/example/visitor-basic.ll index f173888..f595ee3 100644 --- a/test/example/visitor-basic.ll +++ b/test/example/visitor-basic.ll @@ -1,10 +1,15 @@ ; RUN: llvm-dialects-example -visit %s | FileCheck --check-prefixes=DEFAULT %s -; DEFAULT: visiting ReadOp: %v = call i32 @xd.read__i32() +; DEFAULT: Will visit ReadOp next: %v = call i32 @xd.read__i32() +; DEFAULT-NEXT: visiting ReadOp: %v = call i32 @xd.read__i32() ; DEFAULT-NEXT: visiting UnaryInstruction (pre): %w = load i32, ptr %p ; DEFAULT-NEXT: visiting UnaryInstruction (pre): %q = load i32, ptr %p1 ; DEFAULT-NEXT: visiting BinaryOperator: %v1 = add i32 %v, %w +; DEFAULT-NEXT: Pre-visiting intrinsic instruction: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q) ; DEFAULT-NEXT: visiting umax intrinsic: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q) +; DEFAULT-NEXT: Counter before visiting ITruncOp: 0 +; DEFAULT-NEXT: Counter after visiting ITruncOp: 1 +; DEFAULT-NEXT: Will visit WriteOp next: call void (...) @xd.write(i8 %t) ; DEFAULT-NEXT: visiting WriteOp: call void (...) @xd.write(i8 %t) ; DEFAULT-NEXT: visiting SetReadOp: %v.0 = call i1 @xd.set.read__i1() ; DEFAULT-NEXT: visiting SetReadOp: %v.1 = call i32 @xd.set.read__i32() @@ -15,6 +20,7 @@ ; DEFAULT-NEXT: visiting WriteVarArgOp: call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q) ; DEFAULT-NEXT: %v2 = ; DEFAULT-NEXT: %q = +; DEFAULT-NEXT: Pre-visiting intrinsic instruction: %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) ; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) ; DEFAULT-NEXT: visiting Ret (set): ret void ; DEFAULT-NEXT: visiting ReturnInst: ret void