Skip to content

Commit

Permalink
Implement support for pre-visitor callbacks.
Browse files Browse the repository at this point in the history
This change implements a mechanism to run pre-visit callbacks on each
operation it stumbles upon. This can be generalized for every op that is
being visited, for the current nesting level, or for a given `OpSet`.

The code simplifies writing visitors in cases where generic code is written
for a set operations, like resetting the insertion point of an
`IRBuilder`.
  • Loading branch information
tsymalla authored and tsymalla-AMD committed Feb 1, 2024
1 parent daa38d7 commit 9a37c93
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 41 deletions.
43 changes: 39 additions & 4 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ std::unique_ptr<Module> createModuleExample(LLVMContext &context) {

struct VisitorInnermost {
int counter = 0;
raw_ostream *out = nullptr;
};

struct VisitorNest {
Expand All @@ -164,6 +165,13 @@ struct llvm_dialects::VisitorPayloadProjection<VisitorNest, raw_ostream> {
static raw_ostream &project(VisitorNest &nest) { return *nest.out; }
};

template <>
struct llvm_dialects::VisitorPayloadProjection<VisitorInnermost, raw_ostream> {
static raw_ostream &project(VisitorInnermost &innerMost) {
return *innerMost.out;
}
};

LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorContainer, nest)
LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorNest, inner)

Expand Down Expand Up @@ -202,8 +210,8 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
b.addSet(complexSet, [](VisitorNest &self, llvm::Instruction &op) {
assert((op.getOpcode() == Instruction::Ret ||
(isa<IntrinsicInst>(&op) &&
cast<IntrinsicInst>(&op)->getIntrinsicID() ==
Intrinsic::umin)) &&
cast<IntrinsicInst>(&op)->getIntrinsicID() ==
Intrinsic::umin)) &&
"Unexpected operation detected while visiting OpSet!");

if (op.getOpcode() == Instruction::Ret) {
Expand Down Expand Up @@ -236,10 +244,36 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
Intrinsic::umax, [](raw_ostream &out, IntrinsicInst &umax) {
out << "visiting umax intrinsic: " << umax << '\n';
});
b.addPreVisitCallback<xd::ReadOp, xd::WriteOp>(
[](raw_ostream &out, llvm::Instruction &inst) {
if (isa<xd::ReadOp>(inst))
out << "Will visit ReadOp next: " << inst << '\n';
else if (isa<xd::WriteOp>(inst))
out << "Will visit WriteOp next: " << inst << '\n';
else
llvm_unreachable("Unexpected op!");
});

b.addPreVisitCallback([](raw_ostream &out, Instruction &inst) {
if (isa<IntrinsicInst>(inst))
out << "Pre-visiting intrinsic instruction: " << inst << '\n';
});
});
b.nest<VisitorInnermost>([](VisitorBuilder<VisitorInnermost> &b) {
b.add<xd::ITruncOp>([](VisitorInnermost &inner,
xd::ITruncOp &op) { inner.counter++; });
b.add<xd::ITruncOp>(
[](VisitorInnermost &inner, xd::ITruncOp &op) {
inner.counter++;
*inner.out
<< "Counter after visiting ITruncOp: " << inner.counter
<< '\n';
});

b.addPreVisitCallback<xd::ITruncOp>(
[](VisitorInnermost &inner, Instruction &op) {
if (isa<xd::ITruncOp>(op))
*inner.out << "Counter before visiting ITruncOp: "
<< inner.counter << '\n';
});
});
})
.setStrategy(rpot ? VisitorStrategy::ReversePostOrder
Expand All @@ -254,6 +288,7 @@ void exampleVisit(Module &module) {

VisitorContainer container;
container.nest.out = &outs();
container.nest.inner.out = &outs();
visitor.visit(container, module);

outs() << "inner.counter = " << container.nest.inner.counter << '\n';
Expand Down
7 changes: 6 additions & 1 deletion include/llvm-dialects/Dialect/OpSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class OpSet final {
// arguments.
template <typename... OpTs> static const OpSet get() {
static OpSet set;
(... && appendT<OpTs>(set));
(void)(... && appendT<OpTs>(set));
return set;
}

Expand Down Expand Up @@ -153,6 +153,11 @@ class OpSet final {
return isMatchingDialectOp(func.getName());
}

bool empty() const {
return m_coreOpcodes.empty() && m_intrinsicIDs.empty() &&
m_dialectOps.empty();
}

// -------------------------------------------------------------
// Convenience getters to access the internal data structures.
// -------------------------------------------------------------
Expand Down
50 changes: 47 additions & 3 deletions include/llvm-dialects/Dialect/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,27 @@ class VisitorTemplate {
friend class VisitorBuilderBase;

public:
enum class VisitorCallbackType : uint8_t { PreVisit = 0, Visit = 1 };

void setStrategy(VisitorStrategy strategy);
void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data,
VisitorHandler::Projection projection);
VisitorHandler::Projection projection,
VisitorCallbackType visitorCallbackTy = VisitorCallbackType::Visit);

private:
void storeHandlersInOpMap(const VisitorKey &key, unsigned handlerIdx,
VisitorCallbackType callbackTy);

VisitorStrategy m_strategy = VisitorStrategy::Default;
std::vector<PayloadProjection> m_projections;
std::vector<VisitorHandler> m_handlers;
OpMap<llvm::SmallVector<unsigned>> m_opMap;

struct Handlers {
llvm::SmallVector<unsigned> PreVisitHandlers;
llvm::SmallVector<unsigned> VisitHandlers;
};

OpMap<Handlers> m_opMap;
};

/// @brief Base class for VisitorBuilders
Expand Down Expand Up @@ -279,6 +291,9 @@ class VisitorBuilderBase {

void setStrategy(VisitorStrategy strategy);

void addPreVisitCallback(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data);

void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data);

VisitorBase build();
Expand Down Expand Up @@ -307,6 +322,11 @@ class VisitorBase {
class BuildHelper;
using HandlerRange = std::pair<unsigned, unsigned>;

struct MappedHandlers {
HandlerRange PreVisitCallbacks;
HandlerRange VisitCallbacks;
};

void call(HandlerRange handlers, void *payload,
llvm::Instruction &inst) const;
VisitorResult call(const VisitorHandler &handler, void *payload,
Expand All @@ -319,7 +339,7 @@ class VisitorBase {
VisitorStrategy m_strategy;
std::vector<PayloadProjection> m_projections;
std::vector<VisitorHandler> m_handlers;
OpMap<HandlerRange> m_opMap;
OpMap<MappedHandlers> m_opMap;
};

} // namespace detail
Expand Down Expand Up @@ -386,6 +406,20 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
return *this;
}

VisitorBuilder &
addPreVisitCallback(const OpSet &opSet,
VisitorResult (*fn)(PayloadT &, llvm::Instruction &I)) {
addPreVisitCase(detail::VisitorKey::opSet(opSet), fn);
return *this;
}

template <typename... OpTs>
VisitorBuilder &addPreVisitCallback(void (*fn)(PayloadT &,
llvm::Instruction &I)) {
addPreVisitCase(detail::VisitorKey::opSet<OpTs...>(), fn);
return *this;
}

Visitor<PayloadT> build() { return VisitorBuilderBase::build(); }

template <typename OpT>
Expand Down Expand Up @@ -510,6 +544,16 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder<ReturnT>, data);
}

template <typename ReturnT>
void addPreVisitCase(detail::VisitorKey key,
ReturnT (*fn)(PayloadT &, llvm::Instruction &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::addPreVisitCallback(
key, &VisitorBuilder::setForwarder<ReturnT>, data);
}

template <typename OpT, typename ReturnT>
void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) {
detail::VisitorCallbackData data{};
Expand Down
106 changes: 74 additions & 32 deletions lib/Dialect/Visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"

Expand All @@ -44,50 +43,76 @@ void VisitorTemplate::setStrategy(VisitorStrategy strategy) {
m_strategy = strategy;
}

void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data,
VisitorHandler::Projection projection) {
VisitorHandler handler;
handler.callback = fn;
handler.data = data;
handler.projection = projection;

m_handlers.emplace_back(handler);
void VisitorTemplate::storeHandlersInOpMap(
const VisitorKey &key, unsigned handlerIdx,
VisitorCallbackType visitorCallbackTy) {
const auto HandlerList =
[&](const OpDescription &opDescription) -> llvm::SmallVector<unsigned> & {
if (visitorCallbackTy == VisitorCallbackType::PreVisit)
return m_opMap[opDescription].PreVisitHandlers;

const unsigned handlerIdx = m_handlers.size() - 1;
return m_opMap[opDescription].VisitHandlers;
};

if (key.m_kind == VisitorKey::Kind::Intrinsic) {
m_opMap[OpDescription::fromIntrinsic(key.m_intrinsicId)].push_back(
handlerIdx);
HandlerList(OpDescription::fromIntrinsic(key.m_intrinsicId))
.push_back(handlerIdx);
} else if (key.m_kind == VisitorKey::Kind::OpDescription) {
const OpDescription *opDesc = key.m_description;

if (opDesc->isCoreOp()) {
for (const unsigned op : opDesc->getOpcodes())
m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx);
HandlerList(OpDescription::fromCoreOp(op)).push_back(handlerIdx);
} else if (opDesc->isIntrinsic()) {
for (const unsigned op : opDesc->getOpcodes())
m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx);
HandlerList(OpDescription::fromIntrinsic(op)).push_back(handlerIdx);
} else {
m_opMap[*opDesc].push_back(handlerIdx);
HandlerList(*opDesc).push_back(handlerIdx);
}
} else if (key.m_kind == VisitorKey::Kind::OpSet) {
const OpSet *opSet = key.m_set;

if (visitorCallbackTy == VisitorCallbackType::PreVisit && opSet->empty()) {
// This adds a handler for every stored op.
// Note: should be used with caution.
for (auto it : m_opMap)
it.second.PreVisitHandlers.push_back(handlerIdx);

return;
}

for (unsigned opcode : opSet->getCoreOpcodes())
m_opMap[OpDescription::fromCoreOp(opcode)].push_back(handlerIdx);
HandlerList(OpDescription::fromCoreOp(opcode)).push_back(handlerIdx);

for (unsigned intrinsicID : opSet->getIntrinsicIDs())
m_opMap[OpDescription::fromIntrinsic(intrinsicID)].push_back(handlerIdx);
HandlerList(OpDescription::fromIntrinsic(intrinsicID))
.push_back(handlerIdx);

for (const auto &dialectOpPair : opSet->getDialectOps()) {
m_opMap[OpDescription::fromDialectOp(dialectOpPair.isOverload,
dialectOpPair.mnemonic)]
for (const auto &dialectOpPair : opSet->getDialectOps())
HandlerList(OpDescription::fromDialectOp(dialectOpPair.isOverload,
dialectOpPair.mnemonic))
.push_back(handlerIdx);
}
}
}

void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data,
VisitorHandler::Projection projection,
VisitorCallbackType visitorCallbackTy) {
assert(visitorCallbackTy != VisitorCallbackType::PreVisit || key.m_set);

VisitorHandler handler;
handler.callback = fn;
handler.data = data;
handler.projection = projection;

m_handlers.emplace_back(handler);

const unsigned handlerIdx = m_handlers.size() - 1;

storeHandlersInOpMap(key, handlerIdx, visitorCallbackTy);
}

VisitorBuilderBase::VisitorBuilderBase() : m_template(&m_ownedTemplate) {}

VisitorBuilderBase::VisitorBuilderBase(VisitorBuilderBase *parent,
Expand Down Expand Up @@ -144,6 +169,13 @@ void VisitorBuilderBase::setStrategy(VisitorStrategy strategy) {
m_template->setStrategy(strategy);
}

void VisitorBuilderBase::addPreVisitCallback(VisitorKey key,
VisitorCallback *fn,
VisitorCallbackData data) {
m_template->add(key, fn, data, m_projection,
VisitorTemplate::VisitorCallbackType::PreVisit);
}

void VisitorBuilderBase::add(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data) {
m_template->add(key, fn, data, m_projection);
Expand Down Expand Up @@ -192,9 +224,12 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)
BuildHelper helper(*this, templ.m_handlers);

m_opMap.reserve(templ.m_opMap);

for (auto it : templ.m_opMap)
m_opMap[it.first] = helper.mapHandlers(it.second);
for (auto it : templ.m_opMap) {
m_opMap[it.first].PreVisitCallbacks =
helper.mapHandlers(it.second.PreVisitHandlers);
m_opMap[it.first].VisitCallbacks =
helper.mapHandlers(it.second.VisitHandlers);
}
}

void VisitorBase::call(HandlerRange handlers, void *payload,
Expand Down Expand Up @@ -223,11 +258,14 @@ VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
}

void VisitorBase::visit(void *payload, Instruction &inst) const {
auto handlers = m_opMap.find(inst);
if (!handlers)
auto mappedHandlers = m_opMap.find(inst);
if (!mappedHandlers)
return;

call(*handlers.val(), payload, inst);
auto &callbacks = *mappedHandlers.val();

call(callbacks.PreVisitCallbacks, payload, inst);
call(callbacks.VisitCallbacks, payload, inst);
}

template <typename FilterT>
Expand All @@ -241,19 +279,23 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module,

LLVM_DEBUG(dbgs() << "visit " << decl.getName() << '\n');

auto handlers = m_opMap.find(decl);
if (!handlers) {
auto mappedHandlers = m_opMap.find(decl);
if (!mappedHandlers) {
// Neither a matched intrinsic nor a matched dialect op; skip.
continue;
}

auto &callbacks = *mappedHandlers.val();

for (Use &use : make_early_inc_range(decl.uses())) {
if (auto *inst = dyn_cast<Instruction>(use.getUser())) {
if (!filter(*inst))
continue;
if (auto *callInst = dyn_cast<CallInst>(inst)) {
if (&use == &callInst->getCalledOperandUse())
call(*handlers.val(), payload, *callInst);
if (&use == &callInst->getCalledOperandUse()) {
call(callbacks.PreVisitCallbacks, payload, *callInst);
call(callbacks.VisitCallbacks, payload, *callInst);
}
}
}
}
Expand Down
Loading

0 comments on commit 9a37c93

Please sign in to comment.