diff --git a/example/ExampleDialect.cpp b/example/ExampleDialect.cpp index cf84460..1b69d0f 100644 --- a/example/ExampleDialect.cpp +++ b/example/ExampleDialect.cpp @@ -52,3 +52,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &out, VectorKind x) { #define GET_DIALECT_DEFS #include "ExampleDialect.cpp.inc" + +using namespace llvm; +using namespace xd; + +void XdVectorType::customizeTypeClass(TargetExtTypeClass *typeClass) { + typeClass->setGetLayoutType([](TargetExtType *type) -> Type * { + auto *vt = cast(type); + if (vt->getKind() == VectorKind::MiddleEndian) + return Type::getVoidTy(vt->getContext()); + return FixedVectorType::get(vt->getElementType(), vt->getNumElements()); + }); +} diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index fb0c840..cef0e2b 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -44,6 +44,8 @@ def XdVectorType : DialectType { let typeArguments = (args AttrVectorKind:$kind, type:$element_type, AttrI32:$num_elements); + let customizeTypeClass = true; + let summary = "a custom vector type"; let description = [{ Unlike LLVM's built-in vector type, this vector can have arbitrary element diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index 7844b16..2b1c5b9 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -86,6 +86,11 @@ void createFunctionExample(Module &module, const Twine &name) { BasicBlock *bb = BasicBlock::Create(module.getContext(), "entry", fn); b.SetInsertPoint(bb); + Type *vt1 = + xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 4); + + Value *alloca1 = b.CreateAlloca(vt1); + Value *x1 = b.create(b.getInt32Ty()); Value *sizeOf = b.create(b.getHalfTy()); Value *sizeOf32 = b.create(b.getInt32Ty(), sizeOf); @@ -103,8 +108,10 @@ void createFunctionExample(Module &module, const Twine &name) { Value *y1 = b.create( xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 4)); - Value *y2 = b.create(y1, x1); - Value *y3 = b.create(y1, b.getInt32(2)); + b.CreateStore(y1, alloca1); + Value *y1l = b.CreateLoad(vt1, alloca1); + Value *y2 = b.create(y1l, x1); + Value *y3 = b.create(y1l, b.getInt32(2)); Value *y4 = b.CreateAdd(y2, y3); Value *y5 = b.create(q2, y4, x1); auto *y6 = b.create(y5, y2, b.getInt32(5)); diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index 67a9277..78d9234 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -194,10 +194,19 @@ class DialectType : Type, Predicate { Dialect dialect = dialect_; string mnemonic = mnemonic_; - /// Whether the Type::get method has an explicit LLVMContext reference as the + /// Whether the $Type::get method has an explicit LLVMContext reference as the /// first argument. bit defaultGetterHasExplicitContextArgument = false; + /// If set to true, a method of signature + /// + /// static void customizeTypeClass(llvm::TargetExtTypeClass *typeClass); + /// + /// is declared, to be defined manually by the user. The function is called + /// only once and may adjust the type class. The type class is pre-populated + /// with the type name and the default verifier. + bit customizeTypeClass = false; + string summary = ?; string description = ?; } diff --git a/include/llvm-dialects/TableGen/DialectType.h b/include/llvm-dialects/TableGen/DialectType.h index 94a189a..6663775 100644 --- a/include/llvm-dialects/TableGen/DialectType.h +++ b/include/llvm-dialects/TableGen/DialectType.h @@ -46,9 +46,12 @@ class DialectType : public BaseCppPredicate { bool defaultGetterHasExplicitContextArgument() const { return m_defaultGetterHasExplicitContextArgument; } + bool customizeTypeClass() const { return m_customizeTypeClass; } llvm::StringRef getSummary() const { return m_summary; } llvm::StringRef getDescription() const { return m_description; } + void emitTypeClass(llvm::raw_ostream &out, GenDialect *dialect, + FmtContext &fmt) const; void emitDeclaration(llvm::raw_ostream &out, GenDialect *dialect) const; void emitDefinition(llvm::raw_ostream &out, GenDialect *dialect) const; @@ -62,6 +65,7 @@ class DialectType : public BaseCppPredicate { std::string m_name; std::string m_mnemonic; bool m_defaultGetterHasExplicitContextArgument = false; + bool m_customizeTypeClass = false; std::string m_summary; std::string m_description; diff --git a/lib/TableGen/DialectType.cpp b/lib/TableGen/DialectType.cpp index 48b41c9..b961c45 100644 --- a/lib/TableGen/DialectType.cpp +++ b/lib/TableGen/DialectType.cpp @@ -49,6 +49,8 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context, m_defaultGetterHasExplicitContextArgument = record->getValueAsBit("defaultGetterHasExplicitContextArgument"); + m_customizeTypeClass = record->getValueAsBit("customizeTypeClass"); + FmtContext fmt; fmt.addSubst("_type", m_name); { @@ -141,6 +143,35 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context, return true; } +void DialectType::emitTypeClass(llvm::raw_ostream &out, GenDialect *dialect, + FmtContext &fmt) const { + FmtContextScope scope(fmt); + fmt.addSubst("dialect", dialect->name); + fmt.addSubst("_type", getName()); + fmt.addSubst("mnemonic", getMnemonic()); + + if (m_customizeTypeClass) { + fmt.addSubst("customize", + tgfmt("$_type::customizeTypeClass(&theClass);", &fmt).str()); + } else { + fmt.addSubst("customize", ""); + } + + out << tgfmt(R"( + static const auto class$_type = ([]() { + ::llvm::TargetExtTypeClass theClass("$dialect.$mnemonic"); + theClass.setVerifier( + [](::llvm::TargetExtType *T, ::llvm::raw_ostream &errs) { + return ::llvm::cast<$_type>(T)->verifier(errs); + }); + $customize + return theClass; + })(); + $_context.registerTargetExtTypeClass(&class$_type); + )", + &fmt); +} + void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { FmtContext fmt; fmt.withContext(m_context); @@ -174,6 +205,12 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { } out << ");\n\n"; + if (m_customizeTypeClass) { + out << R"( + static void customizeTypeClass(::llvm::TargetExtTypeClass *typeClass); + )"; + } + for (const auto &argument : typeArguments()) { out << tgfmt("$0 get$1() const;\n", &fmt, argument.type->getCppType(), convertToCamelFromSnakeCase(argument.name, true)); diff --git a/lib/TableGen/GenDialect.cpp b/lib/TableGen/GenDialect.cpp index d4a9313..4ef4033 100644 --- a/lib/TableGen/GenDialect.cpp +++ b/lib/TableGen/GenDialect.cpp @@ -73,6 +73,17 @@ void llvm_dialects::genDialectDecls(raw_ostream& out, RecordKeeper& records) { #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" +)"; + + if (!dialect->types.empty()) { + out << R"( +namespace llvm { +class TargetExtTypeClass; +} // namespace llvm +)"; + } + + out << R"( namespace llvm { class raw_ostream; } // namespace llvm @@ -236,6 +247,12 @@ void llvm_dialects::genDialectDefs(raw_ostream& out, RecordKeeper& records) { )"; } + if (!dialect->types.empty()) { + out << R"( +#include "llvm/IR/TargetExtType.h" + )"; + } + out << R"( #include "llvm/Support/raw_ostream.h" #endif // GET_INCLUDES @@ -322,6 +339,12 @@ void llvm_dialects::genDialectDefs(raw_ostream& out, RecordKeeper& records) { } } + FmtContextScope scope{fmt}; + fmt.withContext("context"); + + for (DialectType *type : dialect->types) + type->emitTypeClass(out, dialect, fmt); + out << "}\n\n"; // Type class definitions. diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2d09a4e..0280c1d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -23,7 +23,7 @@ # ####################################################################################################################### -set(LLVM_DIALECTS_TEST_DEPENDS FileCheck count not llvm-dialects-example) +set(LLVM_DIALECTS_TEST_DEPENDS FileCheck count not split-file llvm-dialects-example) add_custom_target(llvm-dialects-test-depends DEPENDS ${LLVM_DIALECTS_TEST_DEPENDS}) set_target_properties(llvm-dialects-test-depends PROPERTIES FOLDER "Tests") diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 9923497..3c1f454 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -12,6 +12,8 @@ #include "llvm/Support/ModRef.h" +#include "llvm/IR/TargetExtType.h" + #include "llvm/Support/raw_ostream.h" #endif // GET_INCLUDES @@ -138,7 +140,29 @@ attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::ModRef)); m_attributeLists[3] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } -} + + static const auto classXdHandleType = ([]() { + ::llvm::TargetExtTypeClass theClass("xd.handle"); + theClass.setVerifier( + [](::llvm::TargetExtType *T, ::llvm::raw_ostream &errs) { + return ::llvm::cast(T)->verifier(errs); + }); + + return theClass; + })(); + context.registerTargetExtTypeClass(&classXdHandleType); + + static const auto classXdVectorType = ([]() { + ::llvm::TargetExtTypeClass theClass("xd.vector"); + theClass.setVerifier( + [](::llvm::TargetExtType *T, ::llvm::raw_ostream &errs) { + return ::llvm::cast(T)->verifier(errs); + }); + XdVectorType::customizeTypeClass(&theClass); + return theClass; + })(); + context.registerTargetExtTypeClass(&classXdVectorType); + } XdHandleType* XdHandleType::get(::llvm::LLVMContext & ctx) { ::std::array<::llvm::Type *, 0> types = { diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 21e59e2..758ff7a 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -7,6 +7,11 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" + +namespace llvm { +class TargetExtTypeClass; +} // namespace llvm + namespace llvm { class raw_ostream; } // namespace llvm @@ -80,7 +85,9 @@ namespace xd { static XdVectorType *get(VectorKind kind, ::llvm::Type * elementType, uint32_t numElements); -VectorKind getKind() const; + + static void customizeTypeClass(::llvm::TargetExtTypeClass *typeClass); + VectorKind getKind() const; ::llvm::Type * getElementType() const; uint32_t getNumElements() const; }; diff --git a/test/example/test-builder.test b/test/example/test-builder.test index d15efca..fd41e0d 100644 --- a/test/example/test-builder.test +++ b/test/example/test-builder.test @@ -5,32 +5,35 @@ ; CHECK-LABEL: @example( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.read.i32() -; CHECK-NEXT: [[TMP1:%.*]] = call i64 (...) @xd.sizeof(double poison) -; CHECK-NEXT: [[TMP2:%.*]] = call i32 (...) @xd.itrunc.i32(i64 [[TMP1]]) -; CHECK-NEXT: [[TMP3:%.*]] = call i32 @xd.add32(i32 [[TMP0]], i32 [[TMP2]], i32 7) -; CHECK-NEXT: [[TMP4:%.*]] = call i32 (...) @xd.combine.i32(i32 [[TMP3]], i32 [[TMP0]]) -; CHECK-NEXT: [[TMP5:%.*]] = call i64 (...) @xd.iext.i64(i32 [[TMP4]]) -; CHECK-NEXT: call void (...) @xd.write(i64 [[TMP5]]) -; CHECK-NEXT: [[TMP6:%.*]] = call <2 x i32> @xd.read.v2i32() -; CHECK-NEXT: [[TMP7:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.fromfixedvector.txd.vector_i32_1_2t(<2 x i32> [[TMP6]]) -; CHECK-NEXT: [[TMP8:%.*]] = call target("xd.vector", i32, 1, 4) @xd.read.txd.vector_i32_1_4t() -; CHECK-NEXT: [[TMP9:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP8]], i32 [[TMP0]]) -; CHECK-NEXT: [[TMP10:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP8]], i32 2) -; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP9]], [[TMP10]] -; CHECK-NEXT: [[TMP12:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP7]], i32 [[TMP11]], i32 [[TMP0]]) -; CHECK-NEXT: [[TMP13:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP12]], i32 [[TMP9]], i32 1) -; CHECK-NEXT: call void (...) @xd.write(target("xd.vector", i32, 1, 2) [[TMP13]]) -; CHECK-NEXT: [[TMP14:%.*]] = call ptr @xd.read.p0() -; CHECK-NEXT: [[TMP15:%.*]] = call i8 (...) @xd.stream.add.i8(ptr [[TMP14]], i64 14, i8 0) -; CHECK-NEXT: call void (...) @xd.write(i8 [[TMP15]]) -; CHECK-NEXT: [[TMP16:%.*]] = call target("xd.handle") @xd.handle.get() -; CHECK-NEXT: [[TMP17:%.*]] = call [[TMP0]] @xd.read.s_s() -; CHECK-NEXT: [[TMP18:%.*]] = call [[TMP1]] @xd.read.s_s_0() -; CHECK-NEXT: [[TMP19:%.*]] = call [[TMP2]] @xd.read.s_s_1() -; CHECK-NEXT: call void (...) @xd.write([[TMP0]] [[TMP17]]) -; CHECK-NEXT: call void (...) @xd.write([[TMP1]] [[TMP18]]) -; CHECK-NEXT: call void (...) @xd.write([[TMP2]] [[TMP19]]) +; CHECK-NEXT: [[TMP0:%.*]] = alloca target("xd.vector", i32, 1, 4), align 16 +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @xd.read.i32() +; CHECK-NEXT: [[TMP2:%.*]] = call i64 (...) @xd.sizeof(double poison) +; CHECK-NEXT: [[TMP3:%.*]] = call i32 (...) @xd.itrunc.i32(i64 [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = call i32 @xd.add32(i32 [[TMP1]], i32 [[TMP3]], i32 7) +; CHECK-NEXT: [[TMP5:%.*]] = call i32 (...) @xd.combine.i32(i32 [[TMP4]], i32 [[TMP1]]) +; CHECK-NEXT: [[TMP6:%.*]] = call i64 (...) @xd.iext.i64(i32 [[TMP5]]) +; CHECK-NEXT: call void (...) @xd.write(i64 [[TMP6]]) +; CHECK-NEXT: [[TMP7:%.*]] = call <2 x i32> @xd.read.v2i32() +; CHECK-NEXT: [[TMP8:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.fromfixedvector.txd.vector_i32_1_2t(<2 x i32> [[TMP7]]) +; CHECK-NEXT: [[TMP9:%.*]] = call target("xd.vector", i32, 1, 4) @xd.read.txd.vector_i32_1_4t() +; CHECK-NEXT: store target("xd.vector", i32, 1, 4) [[TMP9]], ptr [[TMP0]], align 16 +; CHECK-NEXT: [[TMP10:%.*]] = load target("xd.vector", i32, 1, 4), ptr [[TMP0]], align 16 +; CHECK-NEXT: [[TMP11:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP10]], i32 [[TMP1]]) +; CHECK-NEXT: [[TMP12:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP10]], i32 2) +; CHECK-NEXT: [[TMP13:%.*]] = add i32 [[TMP11]], [[TMP12]] +; CHECK-NEXT: [[TMP14:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP8]], i32 [[TMP13]], i32 [[TMP1]]) +; CHECK-NEXT: [[TMP15:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP14]], i32 [[TMP11]], i32 1) +; CHECK-NEXT: call void (...) @xd.write(target("xd.vector", i32, 1, 2) [[TMP15]]) +; CHECK-NEXT: [[TMP16:%.*]] = call ptr @xd.read.p0() +; CHECK-NEXT: [[TMP17:%.*]] = call i8 (...) @xd.stream.add.i8(ptr [[TMP16]], i64 14, i8 0) +; CHECK-NEXT: call void (...) @xd.write(i8 [[TMP17]]) +; CHECK-NEXT: [[TMP18:%.*]] = call target("xd.handle") @xd.handle.get() +; CHECK-NEXT: [[TMP19:%.*]] = call [[TMP0]] @xd.read.s_s() +; CHECK-NEXT: [[TMP20:%.*]] = call [[TMP1]] @xd.read.s_s_0() +; CHECK-NEXT: [[TMP21:%.*]] = call [[TMP2]] @xd.read.s_s_1() +; CHECK-NEXT: call void (...) @xd.write([[TMP0]] [[TMP19]]) +; CHECK-NEXT: call void (...) @xd.write([[TMP1]] [[TMP20]]) +; CHECK-NEXT: call void (...) @xd.write([[TMP2]] [[TMP21]]) ; CHECK-NEXT: ret void ; ; diff --git a/test/example/verifier-type.ll b/test/example/verifier-type.ll new file mode 100644 index 0000000..3d7ea63 --- /dev/null +++ b/test/example/verifier-type.ll @@ -0,0 +1,18 @@ +; RUN: split-file %s %t +; RUN: not llvm-dialects-example -verify %t/bad-parameters.ll 2>&1 | FileCheck --check-prefixes=CHECK %t/bad-parameters.ll +; RUN: not llvm-dialects-example -verify %t/bad-type-info.ll 2>&1 | FileCheck --check-prefixes=CHECK %t/bad-type-info.ll + +;--- bad-parameters.ll + +; CHECK: [[@LINE+4]]:35: error: target type failed validation: +; CHECK: wrong number of int parameters +; CHECK: expected: 2 +; CHECK: actual: 3 +declare void @test_bad_parameters(target("xd.vector", i32, 1, 4, 5)) + +;--- bad-type-info.ll + +; CHECK: [[@LINE+1]]:1: error: target type has wrong layout type +type target("xd.vector", i32, 1, 2) { + layout: type <4 x i32>, +}