Skip to content

Commit

Permalink
WIP: tablegen: add hasTypeInfo field to DialectType
Browse files Browse the repository at this point in the history
TODO:
- align with the required upstream changes

This allows users to hook into the (proposed) TargetExtTypeClass
infrastructure to set the properties of dialect types.

See: https://discourse.llvm.org/t/rfc-target-type-classes-for-extensibility-of-llvm-ir/69813
See: https://reviews.llvm.org/D147697
  • Loading branch information
nhaehnle committed May 11, 2023
1 parent c72c998 commit d3c2785
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 32 deletions.
12 changes: 12 additions & 0 deletions example/ExampleDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &out, VectorKind x) {

#define GET_DIALECT_DEFS
#include "ExampleDialect.cpp.inc"

using namespace llvm;
using namespace xd;

void XdVectorType::customizeTypeClass(TargetExtTypeClass *typeClass) {
typeClass->setGetLayoutType([](TargetExtType *type) -> Type * {
auto *vt = cast<XdVectorType>(type);
if (vt->getKind() == VectorKind::MiddleEndian)
return Type::getVoidTy(vt->getContext());
return FixedVectorType::get(vt->getElementType(), vt->getNumElements());
});
}
2 changes: 2 additions & 0 deletions example/ExampleDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def XdVectorType : DialectType<ExampleDialect, "vector"> {
let typeArguments = (args AttrVectorKind:$kind, type:$element_type,
AttrI32:$num_elements);

let customizeTypeClass = true;

let summary = "a custom vector type";
let description = [{
Unlike LLVM's built-in vector type, this vector can have arbitrary element
Expand Down
11 changes: 9 additions & 2 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ void createFunctionExample(Module &module, const Twine &name) {
BasicBlock *bb = BasicBlock::Create(module.getContext(), "entry", fn);
b.SetInsertPoint(bb);

Type *vt1 =
xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 4);

Value *alloca1 = b.CreateAlloca(vt1);

Value *x1 = b.create<xd::ReadOp>(b.getInt32Ty());
Value *sizeOf = b.create<xd::SizeOfOp>(b.getHalfTy());
Value *sizeOf32 = b.create<xd::ITruncOp>(b.getInt32Ty(), sizeOf);
Expand All @@ -103,8 +108,10 @@ void createFunctionExample(Module &module, const Twine &name) {

Value *y1 = b.create<xd::ReadOp>(
xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 4));
Value *y2 = b.create<xd::ExtractElementOp>(y1, x1);
Value *y3 = b.create<xd::ExtractElementOp>(y1, b.getInt32(2));
b.CreateStore(y1, alloca1);
Value *y1l = b.CreateLoad(vt1, alloca1);
Value *y2 = b.create<xd::ExtractElementOp>(y1l, x1);
Value *y3 = b.create<xd::ExtractElementOp>(y1l, b.getInt32(2));
Value *y4 = b.CreateAdd(y2, y3);
Value *y5 = b.create<xd::InsertElementOp>(q2, y4, x1);
auto *y6 = b.create<xd::InsertElementOp>(y5, y2, b.getInt32(5));
Expand Down
11 changes: 10 additions & 1 deletion include/llvm-dialects/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,19 @@ class DialectType<Dialect dialect_, string mnemonic_> : Type, Predicate {
Dialect dialect = dialect_;
string mnemonic = mnemonic_;

/// Whether the Type::get method has an explicit LLVMContext reference as the
/// Whether the $Type::get method has an explicit LLVMContext reference as the
/// first argument.
bit defaultGetterHasExplicitContextArgument = false;

/// If set to true, a method of signature
///
/// static void customizeTypeClass(llvm::TargetExtTypeClass *typeClass);
///
/// is declared, to be defined manually by the user. The function is called
/// only once and may adjust the type class. The type class is pre-populated
/// with the type name and the default verifier.
bit customizeTypeClass = false;

string summary = ?;
string description = ?;
}
Expand Down
4 changes: 4 additions & 0 deletions include/llvm-dialects/TableGen/DialectType.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ class DialectType : public BaseCppPredicate {
bool defaultGetterHasExplicitContextArgument() const {
return m_defaultGetterHasExplicitContextArgument;
}
bool customizeTypeClass() const { return m_customizeTypeClass; }
llvm::StringRef getSummary() const { return m_summary; }
llvm::StringRef getDescription() const { return m_description; }

void emitTypeClass(llvm::raw_ostream &out, GenDialect *dialect,
FmtContext &fmt) const;
void emitDeclaration(llvm::raw_ostream &out, GenDialect *dialect) const;
void emitDefinition(llvm::raw_ostream &out, GenDialect *dialect) const;

Expand All @@ -62,6 +65,7 @@ class DialectType : public BaseCppPredicate {
std::string m_name;
std::string m_mnemonic;
bool m_defaultGetterHasExplicitContextArgument = false;
bool m_customizeTypeClass = false;
std::string m_summary;
std::string m_description;

Expand Down
37 changes: 37 additions & 0 deletions lib/TableGen/DialectType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context,
m_defaultGetterHasExplicitContextArgument =
record->getValueAsBit("defaultGetterHasExplicitContextArgument");

m_customizeTypeClass = record->getValueAsBit("customizeTypeClass");

FmtContext fmt;
fmt.addSubst("_type", m_name);
{
Expand Down Expand Up @@ -141,6 +143,35 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context,
return true;
}

void DialectType::emitTypeClass(llvm::raw_ostream &out, GenDialect *dialect,
FmtContext &fmt) const {
FmtContextScope scope(fmt);
fmt.addSubst("dialect", dialect->name);
fmt.addSubst("_type", getName());
fmt.addSubst("mnemonic", getMnemonic());

if (m_customizeTypeClass) {
fmt.addSubst("customize",
tgfmt("$_type::customizeTypeClass(&theClass);", &fmt).str());
} else {
fmt.addSubst("customize", "");
}

out << tgfmt(R"(
static const auto class$_type = ([]() {
::llvm::TargetExtTypeClass theClass("$dialect.$mnemonic");
theClass.setVerifier(
[](::llvm::TargetExtType *T, ::llvm::raw_ostream &errs) {
return ::llvm::cast<$_type>(T)->verifier(errs);
});
$customize
return theClass;
})();
$_context.registerTargetExtTypeClass(&class$_type);
)",
&fmt);
}

void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const {
FmtContext fmt;
fmt.withContext(m_context);
Expand Down Expand Up @@ -174,6 +205,12 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const {
}
out << ");\n\n";

if (m_customizeTypeClass) {
out << R"(
static void customizeTypeClass(::llvm::TargetExtTypeClass *typeClass);
)";
}

for (const auto &argument : typeArguments()) {
out << tgfmt("$0 get$1() const;\n", &fmt, argument.type->getCppType(),
convertToCamelFromSnakeCase(argument.name, true));
Expand Down
23 changes: 23 additions & 0 deletions lib/TableGen/GenDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ void llvm_dialects::genDialectDecls(raw_ostream& out, RecordKeeper& records) {
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
)";

if (!dialect->types.empty()) {
out << R"(
namespace llvm {
class TargetExtTypeClass;
} // namespace llvm
)";
}

out << R"(
namespace llvm {
class raw_ostream;
} // namespace llvm
Expand Down Expand Up @@ -236,6 +247,12 @@ void llvm_dialects::genDialectDefs(raw_ostream& out, RecordKeeper& records) {
)";
}

if (!dialect->types.empty()) {
out << R"(
#include "llvm/IR/TargetExtType.h"
)";
}

out << R"(
#include "llvm/Support/raw_ostream.h"
#endif // GET_INCLUDES
Expand Down Expand Up @@ -322,6 +339,12 @@ void llvm_dialects::genDialectDefs(raw_ostream& out, RecordKeeper& records) {
}
}

FmtContextScope scope{fmt};
fmt.withContext("context");

for (DialectType *type : dialect->types)
type->emitTypeClass(out, dialect, fmt);

out << "}\n\n";

// Type class definitions.
Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#
#######################################################################################################################

set(LLVM_DIALECTS_TEST_DEPENDS FileCheck count not llvm-dialects-example)
set(LLVM_DIALECTS_TEST_DEPENDS FileCheck count not split-file llvm-dialects-example)
add_custom_target(llvm-dialects-test-depends DEPENDS ${LLVM_DIALECTS_TEST_DEPENDS})
set_target_properties(llvm-dialects-test-depends PROPERTIES FOLDER "Tests")

Expand Down
26 changes: 25 additions & 1 deletion test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "llvm/Support/ModRef.h"

#include "llvm/IR/TargetExtType.h"

#include "llvm/Support/raw_ostream.h"
#endif // GET_INCLUDES

Expand Down Expand Up @@ -138,7 +140,29 @@ attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::ModRef));
m_attributeLists[3] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
}

static const auto classXdHandleType = ([]() {
::llvm::TargetExtTypeClass theClass("xd.handle");
theClass.setVerifier(
[](::llvm::TargetExtType *T, ::llvm::raw_ostream &errs) {
return ::llvm::cast<XdHandleType>(T)->verifier(errs);
});

return theClass;
})();
context.registerTargetExtTypeClass(&classXdHandleType);

static const auto classXdVectorType = ([]() {
::llvm::TargetExtTypeClass theClass("xd.vector");
theClass.setVerifier(
[](::llvm::TargetExtType *T, ::llvm::raw_ostream &errs) {
return ::llvm::cast<XdVectorType>(T)->verifier(errs);
});
XdVectorType::customizeTypeClass(&theClass);
return theClass;
})();
context.registerTargetExtTypeClass(&classXdVectorType);
}

XdHandleType* XdHandleType::get(::llvm::LLVMContext & ctx) {
::std::array<::llvm::Type *, 0> types = {
Expand Down
9 changes: 8 additions & 1 deletion test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"


namespace llvm {
class TargetExtTypeClass;
} // namespace llvm

namespace llvm {
class raw_ostream;
} // namespace llvm
Expand Down Expand Up @@ -80,7 +85,9 @@ namespace xd {

static XdVectorType *get(VectorKind kind, ::llvm::Type * elementType, uint32_t numElements);

VectorKind getKind() const;

static void customizeTypeClass(::llvm::TargetExtTypeClass *typeClass);
VectorKind getKind() const;
::llvm::Type * getElementType() const;
uint32_t getNumElements() const;
};
Expand Down
55 changes: 29 additions & 26 deletions test/example/test-builder.test
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,35 @@

; CHECK-LABEL: @example(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.read.i32()
; CHECK-NEXT: [[TMP1:%.*]] = call i64 (...) @xd.sizeof(double poison)
; CHECK-NEXT: [[TMP2:%.*]] = call i32 (...) @xd.itrunc.i32(i64 [[TMP1]])
; CHECK-NEXT: [[TMP3:%.*]] = call i32 @xd.add32(i32 [[TMP0]], i32 [[TMP2]], i32 7)
; CHECK-NEXT: [[TMP4:%.*]] = call i32 (...) @xd.combine.i32(i32 [[TMP3]], i32 [[TMP0]])
; CHECK-NEXT: [[TMP5:%.*]] = call i64 (...) @xd.iext.i64(i32 [[TMP4]])
; CHECK-NEXT: call void (...) @xd.write(i64 [[TMP5]])
; CHECK-NEXT: [[TMP6:%.*]] = call <2 x i32> @xd.read.v2i32()
; CHECK-NEXT: [[TMP7:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.fromfixedvector.txd.vector_i32_1_2t(<2 x i32> [[TMP6]])
; CHECK-NEXT: [[TMP8:%.*]] = call target("xd.vector", i32, 1, 4) @xd.read.txd.vector_i32_1_4t()
; CHECK-NEXT: [[TMP9:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP8]], i32 [[TMP0]])
; CHECK-NEXT: [[TMP10:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP8]], i32 2)
; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP9]], [[TMP10]]
; CHECK-NEXT: [[TMP12:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP7]], i32 [[TMP11]], i32 [[TMP0]])
; CHECK-NEXT: [[TMP13:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP12]], i32 [[TMP9]], i32 1)
; CHECK-NEXT: call void (...) @xd.write(target("xd.vector", i32, 1, 2) [[TMP13]])
; CHECK-NEXT: [[TMP14:%.*]] = call ptr @xd.read.p0()
; CHECK-NEXT: [[TMP15:%.*]] = call i8 (...) @xd.stream.add.i8(ptr [[TMP14]], i64 14, i8 0)
; CHECK-NEXT: call void (...) @xd.write(i8 [[TMP15]])
; CHECK-NEXT: [[TMP16:%.*]] = call target("xd.handle") @xd.handle.get()
; CHECK-NEXT: [[TMP17:%.*]] = call [[TMP0]] @xd.read.s_s()
; CHECK-NEXT: [[TMP18:%.*]] = call [[TMP1]] @xd.read.s_s_0()
; CHECK-NEXT: [[TMP19:%.*]] = call [[TMP2]] @xd.read.s_s_1()
; CHECK-NEXT: call void (...) @xd.write([[TMP0]] [[TMP17]])
; CHECK-NEXT: call void (...) @xd.write([[TMP1]] [[TMP18]])
; CHECK-NEXT: call void (...) @xd.write([[TMP2]] [[TMP19]])
; CHECK-NEXT: [[TMP0:%.*]] = alloca target("xd.vector", i32, 1, 4), align 16
; CHECK-NEXT: [[TMP1:%.*]] = call i32 @xd.read.i32()
; CHECK-NEXT: [[TMP2:%.*]] = call i64 (...) @xd.sizeof(double poison)
; CHECK-NEXT: [[TMP3:%.*]] = call i32 (...) @xd.itrunc.i32(i64 [[TMP2]])
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @xd.add32(i32 [[TMP1]], i32 [[TMP3]], i32 7)
; CHECK-NEXT: [[TMP5:%.*]] = call i32 (...) @xd.combine.i32(i32 [[TMP4]], i32 [[TMP1]])
; CHECK-NEXT: [[TMP6:%.*]] = call i64 (...) @xd.iext.i64(i32 [[TMP5]])
; CHECK-NEXT: call void (...) @xd.write(i64 [[TMP6]])
; CHECK-NEXT: [[TMP7:%.*]] = call <2 x i32> @xd.read.v2i32()
; CHECK-NEXT: [[TMP8:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.fromfixedvector.txd.vector_i32_1_2t(<2 x i32> [[TMP7]])
; CHECK-NEXT: [[TMP9:%.*]] = call target("xd.vector", i32, 1, 4) @xd.read.txd.vector_i32_1_4t()
; CHECK-NEXT: store target("xd.vector", i32, 1, 4) [[TMP9]], ptr [[TMP0]], align 16
; CHECK-NEXT: [[TMP10:%.*]] = load target("xd.vector", i32, 1, 4), ptr [[TMP0]], align 16
; CHECK-NEXT: [[TMP11:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP10]], i32 [[TMP1]])
; CHECK-NEXT: [[TMP12:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP10]], i32 2)
; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP11]], [[TMP12]]
; CHECK-NEXT: [[TMP14:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP8]], i32 [[TMP13]], i32 [[TMP1]])
; CHECK-NEXT: [[TMP15:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP14]], i32 [[TMP11]], i32 1)
; CHECK-NEXT: call void (...) @xd.write(target("xd.vector", i32, 1, 2) [[TMP15]])
; CHECK-NEXT: [[TMP16:%.*]] = call ptr @xd.read.p0()
; CHECK-NEXT: [[TMP17:%.*]] = call i8 (...) @xd.stream.add.i8(ptr [[TMP16]], i64 14, i8 0)
; CHECK-NEXT: call void (...) @xd.write(i8 [[TMP17]])
; CHECK-NEXT: [[TMP18:%.*]] = call target("xd.handle") @xd.handle.get()
; CHECK-NEXT: [[TMP19:%.*]] = call [[TMP0]] @xd.read.s_s()
; CHECK-NEXT: [[TMP20:%.*]] = call [[TMP1]] @xd.read.s_s_0()
; CHECK-NEXT: [[TMP21:%.*]] = call [[TMP2]] @xd.read.s_s_1()
; CHECK-NEXT: call void (...) @xd.write([[TMP0]] [[TMP19]])
; CHECK-NEXT: call void (...) @xd.write([[TMP1]] [[TMP20]])
; CHECK-NEXT: call void (...) @xd.write([[TMP2]] [[TMP21]])
; CHECK-NEXT: ret void
;
;
Expand Down
18 changes: 18 additions & 0 deletions test/example/verifier-type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; RUN: split-file %s %t
; RUN: not llvm-dialects-example -verify %t/bad-parameters.ll 2>&1 | FileCheck --check-prefixes=CHECK %t/bad-parameters.ll
; RUN: not llvm-dialects-example -verify %t/bad-type-info.ll 2>&1 | FileCheck --check-prefixes=CHECK %t/bad-type-info.ll

;--- bad-parameters.ll

; CHECK: [[@LINE+4]]:35: error: target type failed validation:
; CHECK: wrong number of int parameters
; CHECK: expected: 2
; CHECK: actual: 3
declare void @test_bad_parameters(target("xd.vector", i32, 1, 4, 5))

;--- bad-type-info.ll

; CHECK: [[@LINE+1]]:1: error: target type has wrong layout type
type target("xd.vector", i32, 1, 2) {
layout: type <4 x i32>,
}

0 comments on commit d3c2785

Please sign in to comment.