From c38dfcf64681b3498e9658de97ed1d2c63b20165 Mon Sep 17 00:00:00 2001 From: Dmitry Vodopyanov Date: Thu, 19 Oct 2023 18:02:13 +0200 Subject: [PATCH] [SYCL][Joint Matrix][NFC] Add SYCLPropagateJointMatrixUsage pass (#11508) This patch adds a pass which propagates optional kernel features metadata through a module call graph for sycl_ext_oneapi_matrix extension. According to the extension spec, optional kernel features are the `joint_matrix` type and the `joint_matrix_mad` function. --- clang/lib/CodeGen/BackendUtil.cpp | 13 +- .../SYCLPropagateJointMatrixUsage.h | 30 +++ llvm/lib/Passes/PassBuilder.cpp | 3 +- llvm/lib/Passes/PassRegistry.def | 1 + llvm/lib/SYCLLowerIR/CMakeLists.txt | 1 + .../SYCLPropagateJointMatrixUsage.cpp | 244 ++++++++++++++++++ .../call-graph-joint-matrix.ll | 118 +++++++++ .../oneapi/matrix/matrix-unified-utils.hpp | 20 ++ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 24 +- .../include/sycl/ext/oneapi/matrix/matrix.hpp | 1 - .../sycl/ext/oneapi/matrix/query-types.hpp | 52 ++++ .../matrix-check-types-in-attributes.cpp | 52 ++++ sycl/test/matrix/matrix-int8-test.cpp | 3 + 13 files changed, 552 insertions(+), 10 deletions(-) create mode 100644 llvm/include/llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h create mode 100644 llvm/lib/SYCLLowerIR/SYCLPropagateJointMatrixUsage.cpp create mode 100644 llvm/test/SYCLLowerIR/PropagateJointMatrixUsage/call-graph-joint-matrix.ll create mode 100644 sycl/test/matrix/matrix-check-types-in-attributes.cpp diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp index 81eee4ab4e2c0..df3a2e568438c 100644 --- a/clang/lib/CodeGen/BackendUtil.cpp +++ b/clang/lib/CodeGen/BackendUtil.cpp @@ -53,6 +53,7 @@ #include "llvm/SYCLLowerIR/RenameKernelSYCLNativeCPU.h" #include "llvm/SYCLLowerIR/SYCLAddOptLevelAttribute.h" #include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h" +#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h" #include "llvm/Support/BuryPointer.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" @@ -966,12 +967,12 @@ void EmitAssemblyHelper::RunOptimizationPipeline( OptimizationLevel Level = mapToLevel(CodeGenOpts); if (LangOpts.SYCLIsDevice) - PB.registerPipelineStartEPCallback( - [&](ModulePassManager &MPM, OptimizationLevel Level) { - MPM.addPass(ESIMDVerifierPass(LangOpts.SYCLESIMDForceStatelessMem)); - MPM.addPass( - SYCLPropagateAspectsUsagePass(/*ExcludeAspects=*/{"fp64"})); - }); + PB.registerPipelineStartEPCallback([&](ModulePassManager &MPM, + OptimizationLevel Level) { + MPM.addPass(ESIMDVerifierPass(LangOpts.SYCLESIMDForceStatelessMem)); + MPM.addPass(SYCLPropagateAspectsUsagePass(/*ExcludeAspects=*/{"fp64"})); + MPM.addPass(SYCLPropagateJointMatrixUsagePass()); + }); else if (LangOpts.SYCLIsHost && !LangOpts.SYCLESIMDBuildHostCode) PB.registerPipelineStartEPCallback( [&](ModulePassManager &MPM, OptimizationLevel Level) { diff --git a/llvm/include/llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h b/llvm/include/llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h new file mode 100644 index 0000000000000..66430bea9d343 --- /dev/null +++ b/llvm/include/llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h @@ -0,0 +1,30 @@ +//===- SYCLPropagateJointMatrixUsage.cpp - SYCLPropagateJointMatrixUsage Pass +//-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Pass propagates optional kernel features metadata through a module call graph +// for sycl_ext_oneapi_matrix extension +// +//===----------------------------------------------------------------------===// +// +#ifndef LLVM_SYCL_PROPAGATE_JOINT_MATRIX_USAGE_H +#define LLVM_SYCL_PROPAGATE_JOINT_MATRIX_USAGE_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class SYCLPropagateJointMatrixUsagePass + : public PassInfoMixin { +public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &); +}; + +} // namespace llvm + +#endif // LLVM_SYCL_PROPAGATE_JOINT_MATRIX_USAGE_H diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index cc5db9365a420..c47e469c469cf 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -92,6 +92,7 @@ #include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h" #include "llvm/SYCLLowerIR/SYCLAddOptLevelAttribute.h" #include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h" +#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -244,8 +245,8 @@ #include "llvm/Transforms/Utils/CanonicalizeAliases.h" #include "llvm/Transforms/Utils/CanonicalizeFreezeInLoops.h" #include "llvm/Transforms/Utils/CountVisits.h" -#include "llvm/Transforms/Utils/Debugify.h" #include "llvm/Transforms/Utils/DXILUpgrade.h" +#include "llvm/Transforms/Utils/Debugify.h" #include "llvm/Transforms/Utils/EntryExitInstrumenter.h" #include "llvm/Transforms/Utils/FixIrreducible.h" #include "llvm/Transforms/Utils/HelloWorld.h" diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index db3c820418c88..f422a0d47ee4d 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -140,6 +140,7 @@ MODULE_PASS("sycllowerwglocalmemory", SYCLLowerWGLocalMemoryPass()) MODULE_PASS("lower-esimd-kernel-attrs", SYCLFixupESIMDKernelWrapperMDPass()) MODULE_PASS("esimd-remove-host-code", ESIMDRemoveHostCodePass()); MODULE_PASS("sycl-propagate-aspects-usage", SYCLPropagateAspectsUsagePass()) +MODULE_PASS("sycl-propagate-joint-matrix-usage", SYCLPropagateJointMatrixUsagePass()) MODULE_PASS("sycl-add-opt-level-attribute", SYCLAddOptLevelAttributePass()) MODULE_PASS("compile-time-properties", CompileTimePropertiesPass()) MODULE_PASS("cleanup-sycl-metadata", CleanupSYCLMetadataPass()) diff --git a/llvm/lib/SYCLLowerIR/CMakeLists.txt b/llvm/lib/SYCLLowerIR/CMakeLists.txt index b02efc582663a..f43bd78cdd70a 100644 --- a/llvm/lib/SYCLLowerIR/CMakeLists.txt +++ b/llvm/lib/SYCLLowerIR/CMakeLists.txt @@ -65,6 +65,7 @@ add_llvm_component_library(LLVMSYCLLowerIR MutatePrintfAddrspace.cpp SYCLAddOptLevelAttribute.cpp SYCLPropagateAspectsUsage.cpp + SYCLPropagateJointMatrixUsage.cpp SYCLUtils.cpp LocalAccessorToSharedMemory.cpp diff --git a/llvm/lib/SYCLLowerIR/SYCLPropagateJointMatrixUsage.cpp b/llvm/lib/SYCLLowerIR/SYCLPropagateJointMatrixUsage.cpp new file mode 100644 index 0000000000000..5edc38e9f2be5 --- /dev/null +++ b/llvm/lib/SYCLLowerIR/SYCLPropagateJointMatrixUsage.cpp @@ -0,0 +1,244 @@ +//===------------------ SYCLPropagateJointMatrixUsage.cpp -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Pass propagates optional kernel features metadata through a module call graph +// for sycl_ext_oneapi_matrix extension +// +// The pass consists of three main steps: +// +// I. It builds Function -> string of joint matrix types and sizes values +// mapping for usage in step II +// II. Propagates all the values from step I. to the top of the call graph +// III. Generates metadata with values of joint matrix types and sizes +// +//===----------------------------------------------------------------------===// + +#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h" + +#include "llvm/ADT/SmallString.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" + +#include + +using namespace llvm; + +namespace { + +/// Returns true if the function is a SPIRV or SYCL builtin, e.g. +/// _Z28__spirv_GlobalInvocationId_xv +/// NB! This function was copied from sycl-post-link/ModuleSplitter.cpp and the +/// definition of entry point (i.e. implementation of the function) should be in +/// sync between those two. +bool isSpirvSyclBuiltin(StringRef FName) { + if (!FName.consume_front("_Z")) + return false; + // now skip the digits + FName = FName.drop_while([](char C) { return std::isdigit(C); }); + + return FName.startswith("__spirv_") || FName.startswith("__sycl_"); +} + +bool isEntryPoint(const Function &F) { + // Skip declarations, we can't analyze them + if (F.isDeclaration()) { + // F.print(outs()); + return false; + } + + // Kernels are always considered to be entry points + if (CallingConv::SPIR_KERNEL == F.getCallingConv()) + return true; + + // FIXME: sycl-post-link allows to disable treating SYCL_EXTERNAL's as entry + // points - do we need similar flag here? + // SYCL_EXTERNAL functions with sycl-module-id attribute + // are also considered as entry points (except __spirv_* and __sycl_* + // functions) + return F.hasFnAttribute("sycl-module-id") && !isSpirvSyclBuiltin(F.getName()); +} + +using CallGraphTy = DenseMap>; + +/// Updates call graph with the information from function @F +void fillCallGraph(Function *F, CallGraphTy &CG) { + for (Instruction &I : instructions(F)) { + if (const auto *CI = dyn_cast(&I)) { + if (!CI->isIndirectCall() && CI->getCalledFunction()) + CG[F].insert(CI->getCalledFunction()); + } + } +} + +using JointMatrixValueStringTy = SmallString<40>; +using JointMatrixValuesSetTy = std::set; +using FunctionToJointMatrixValuesMapTy = + DenseMap; + +/// Creates mapping between a function and an information about matrix types and +/// sizes of sycl::ext::oneapi::experimental::matrix::joint_matrix type +void fillFunctionToJointMatrixValuesMap( + Function *F, + FunctionToJointMatrixValuesMapTy &FunctionToJointMatrixValues) { + // assume we have other sycl-joint-matrix-* attributes if + // sycl-joint-matrix-type is present + if (!F->hasFnAttribute("sycl-joint-matrix-type")) + return; + + JointMatrixValueStringTy Result; + // NB! The order of attributes must not change as it is used later in SYCL + // RT + // The order is: + // - sycl-joint-matrix-type + // - sycl-joint-matrix-use + // - sycl-joint-matrix-rows + // - sycl-joint-matrix-cols + // NB! Values must be separated with a comma + Result += F->getFnAttribute("sycl-joint-matrix-type").getValueAsString(); + Result += ","; + Result += F->getFnAttribute("sycl-joint-matrix-use").getValueAsString(); + Result += ","; + Result += F->getFnAttribute("sycl-joint-matrix-rows").getValueAsString(); + Result += ","; + Result += F->getFnAttribute("sycl-joint-matrix-cols").getValueAsString(); + FunctionToJointMatrixValues[F].insert(Result); +} + +/// Creates mapping between a function and an information about matrix types and +/// sizes of sycl::ext::oneapi::experimental::matrix::joint_matrix_mad() +/// function +void fillFunctionToJointMatrixMadValuesMap( + Function *F, + FunctionToJointMatrixValuesMapTy &FunctionToJointMatrixMapValues) { + // assume we have other sycl-joint-matrix-mad-* attributes if + // sycl-joint-matrix-mad-type-A is present + if (!F->hasFnAttribute("sycl-joint-matrix-mad-type-A")) + return; + + JointMatrixValueStringTy Result; + // NB! The order of attributes must not change as it is used later in SYCL + // RT + // The order is: + // - sycl-joint-matrix-mad-type-A + // - sycl-joint-matrix-mad-type-B + // - sycl-joint-matrix-mad-type-C + // - sycl-joint-matrix-mad-type-D + // - sycl-joint-matrix-mad-size-M + // - sycl-joint-matrix-mad-size-K + // - sycl-joint-matrix-mad-size-N + // NB! Values must be separated with a comma + Result += + F->getFnAttribute("sycl-joint-matrix-mad-type-A").getValueAsString(); + Result += ","; + Result += + F->getFnAttribute("sycl-joint-matrix-mad-type-B").getValueAsString(); + Result += ","; + Result += + F->getFnAttribute("sycl-joint-matrix-mad-type-C").getValueAsString(); + Result += ","; + Result += + F->getFnAttribute("sycl-joint-matrix-mad-type-D").getValueAsString(); + Result += ","; + Result += + F->getFnAttribute("sycl-joint-matrix-mad-size-M").getValueAsString(); + Result += ","; + Result += + F->getFnAttribute("sycl-joint-matrix-mad-size-K").getValueAsString(); + Result += ","; + Result += + F->getFnAttribute("sycl-joint-matrix-mad-size-N").getValueAsString(); + FunctionToJointMatrixMapValues[F].insert(Result); +} + +/// Propagates joint matrix values from leaves up to the top of call graph. +/// NB! Call graph corresponds to call graph of SYCL code which +/// can't contain recursive calls. So there can't be loops in +/// a call graph. But there can be path's intersections. +void propagateJointMatrixValuesThroughCG( + Function *F, CallGraphTy &CG, + FunctionToJointMatrixValuesMapTy &FunctionToJointMatrixValues, + FunctionToJointMatrixValuesMapTy &FunctionToJointMatrixMadValues, + SmallPtrSet &Visited) { + const auto It = CG.find(F); + if (It == CG.end()) + return; + + JointMatrixValuesSetTy LocalJointMatrixValues; + JointMatrixValuesSetTy LocalJointMatrixMadValues; + for (Function *Callee : It->second) { + if (Visited.insert(Callee).second) + propagateJointMatrixValuesThroughCG( + Callee, CG, FunctionToJointMatrixValues, + FunctionToJointMatrixMadValues, Visited); + + const auto &CalleeJointMatrixValues = FunctionToJointMatrixValues[Callee]; + LocalJointMatrixValues.insert(CalleeJointMatrixValues.begin(), + CalleeJointMatrixValues.end()); + const auto &CalleeJointMatrixMadValues = + FunctionToJointMatrixMadValues[Callee]; + LocalJointMatrixMadValues.insert(CalleeJointMatrixMadValues.begin(), + CalleeJointMatrixMadValues.end()); + } + FunctionToJointMatrixValues[F].insert(LocalJointMatrixValues.begin(), + LocalJointMatrixValues.end()); + FunctionToJointMatrixMadValues[F].insert(LocalJointMatrixMadValues.begin(), + LocalJointMatrixMadValues.end()); +} + +void setSyclJointMatrixMetadata(StringRef MetadataName, Module *M, Function *F, + FunctionToJointMatrixValuesMapTy ValuesMap) { + JointMatrixValuesSetTy Values = ValuesMap[F]; + SmallString<256> StringValue; + for (auto It = Values.begin(); It != Values.end(); It++) { + StringValue += *It; + // NB! Each information about joint_matrix type and joint_matrix_mad + // function should be separated by a semicolon + if (std::next(It) != Values.end()) + StringValue += ";"; + } + if (StringValue.empty()) + return; + + MDString *MDStringValue = MDString::get(M->getContext(), StringValue); + MDNode *MDN = MDNode::get(M->getContext(), MDStringValue); + F->setMetadata(MetadataName, MDN); +} + +} // anonymous namespace + +PreservedAnalyses +SYCLPropagateJointMatrixUsagePass::run(Module &M, ModuleAnalysisManager &MAM) { + FunctionToJointMatrixValuesMapTy FunctionToJointMatrixValues; + FunctionToJointMatrixValuesMapTy FunctionToJointMatrixMadValues; + SmallVector EntryPoints; + CallGraphTy CG; + for (Function &F : M.functions()) { + fillFunctionToJointMatrixValuesMap(&F, FunctionToJointMatrixValues); + fillFunctionToJointMatrixMadValuesMap(&F, FunctionToJointMatrixMadValues); + fillCallGraph(&F, CG); + + if (isEntryPoint(F)) + EntryPoints.push_back(&F); + } + + SmallPtrSet Visited; + for (const auto F : EntryPoints) { + propagateJointMatrixValuesThroughCG(F, CG, FunctionToJointMatrixValues, + FunctionToJointMatrixMadValues, + Visited); + } + + for (Function *F : EntryPoints) { + setSyclJointMatrixMetadata("sycl_joint_matrix", &M, F, + FunctionToJointMatrixValues); + setSyclJointMatrixMetadata("sycl_joint_matrix_mad", &M, F, + FunctionToJointMatrixMadValues); + } + + return PreservedAnalyses::all(); +} diff --git a/llvm/test/SYCLLowerIR/PropagateJointMatrixUsage/call-graph-joint-matrix.ll b/llvm/test/SYCLLowerIR/PropagateJointMatrixUsage/call-graph-joint-matrix.ll new file mode 100644 index 0000000000000..5f65c05587525 --- /dev/null +++ b/llvm/test/SYCLLowerIR/PropagateJointMatrixUsage/call-graph-joint-matrix.ll @@ -0,0 +1,118 @@ +; RUN: opt -passes=sycl-propagate-joint-matrix-usage < %s -S | FileCheck %s +; +; Test checks that the pass is able to propagate information about used joint_matrix +; through a call graph +; +; K1 K2 F5 +; | / \ / \ +; | F4 JMM1 JM5 JM6 +; | / \ +; F1 JM4 +; / | \ +; JM1 F2 F3 +; | \ +; JM2 JM3 +; +; +; K* - kernels +; F* - functions +; JM* - joint_matrix ctors +; JMM1 - joint_matrix_mad function + +; CHECK: define spir_kernel void @kernel1() !sycl_joint_matrix ![[#ID0:]] { +define spir_kernel void @kernel1() { + call spir_func void @func1() + ret void +} + +; CHECK: define spir_kernel void @kernel2() !sycl_joint_matrix ![[#ID1:]] !sycl_joint_matrix_mad ![[#ID2:]] { +define spir_kernel void @kernel2() { + call spir_func void @func4() + call spir_func void @joint_matrix_mad1() + ret void +} + +; CHECK: define spir_func void @func1() #0 !sycl_joint_matrix ![[#ID0:]] { +define spir_func void @func1() #0 { + call spir_func void @joint_matrix1() + call spir_func void @func2() + call spir_func void @func3() + ret void +} + +; CHECK: define spir_func void @joint_matrix1() #1 { +define spir_func void @joint_matrix1() #1 { + ret void +} + +; CHECK: define spir_func void @func2() #2 !sycl_joint_matrix ![[#ID3:]] { +define spir_func void @func2() #2 { + call spir_func void @joint_matrix2() + ret void +} + +; CHECK: define spir_func void @joint_matrix2() #3 { +define spir_func void @joint_matrix2() #3 { + ret void +} + +; CHECK: define spir_func void @func3() #4 !sycl_joint_matrix ![[#ID4:]] { +define spir_func void @func3() #4 { + call spir_func void @joint_matrix3() + ret void +} + +; CHECK: define spir_func void @joint_matrix3() #5 { +define spir_func void @joint_matrix3() #5 { + ret void +} + +; CHECK: define spir_func void @func4() #6 !sycl_joint_matrix ![[#ID1:]] { +define spir_func void @func4() #6 { + call spir_func void @joint_matrix4() + call spir_func void @func1() + ret void +} + +; CHECK: define spir_func void @joint_matrix4() #7 { +define spir_func void @joint_matrix4() #7 { + ret void +} + +define spir_func void @joint_matrix_mad1() #8 { + ret void +} + +; CHECK: define spir_func void @func5() #0 !sycl_joint_matrix ![[#ID5:]] { +define spir_func void @func5() #0 { + call spir_func void @joint_matrix5() + call spir_func void @joint_matrix6() + ret void +} + +; CHECK: define spir_func void @joint_matrix5() #1 { +define spir_func void @joint_matrix5() #1 { + ret void +} + +; CHECK: define spir_func void @joint_matrix6() #3 { +define spir_func void @joint_matrix6() #3 { + ret void +} + +attributes #0 = { "sycl-joint-matrix-cols"="48" "sycl-joint-matrix-rows"="12" "sycl-joint-matrix-type"="matrix_type::sint8" "sycl-joint-matrix-use"="use::a" "sycl-module-id"="test.cpp" } +attributes #1 = { "sycl-joint-matrix-cols"="48" "sycl-joint-matrix-rows"="12" "sycl-joint-matrix-type"="matrix_type::sint8" "sycl-joint-matrix-use"="use::a" } +attributes #2 = { "sycl-joint-matrix-cols"="12" "sycl-joint-matrix-rows"="48" "sycl-joint-matrix-type"="matrix_type::sint8" "sycl-joint-matrix-use"="use::b" "sycl-module-id"="test.cpp" } +attributes #3 = { "sycl-joint-matrix-cols"="12" "sycl-joint-matrix-rows"="48" "sycl-joint-matrix-type"="matrix_type::sint8" "sycl-joint-matrix-use"="use::b" } +attributes #4 = { "sycl-joint-matrix-cols"="12" "sycl-joint-matrix-rows"="12" "sycl-joint-matrix-type"="matrix_type::sint8" "sycl-joint-matrix-use"="use::a" "sycl-module-id"="test.cpp" } +attributes #5 = { "sycl-joint-matrix-cols"="12" "sycl-joint-matrix-rows"="12" "sycl-joint-matrix-type"="matrix_type::sint8" "sycl-joint-matrix-use"="use::a" } +attributes #6 = { "sycl-joint-matrix-cols"="12" "sycl-joint-matrix-rows"="12" "sycl-joint-matrix-type"="matrix_type::sint32" "sycl-joint-matrix-use"="use::accumulator" "sycl-module-id"="test.cpp" } +attributes #7 = { "sycl-joint-matrix-cols"="12" "sycl-joint-matrix-rows"="12" "sycl-joint-matrix-type"="matrix_type::sint32" "sycl-joint-matrix-use"="use::accumulator" } +attributes #8 = { "sycl-joint-matrix-mad-size-M"="12" "sycl-joint-matrix-mad-size-K"="48" "sycl-joint-matrix-mad-size-N"="12" "sycl-joint-matrix-mad-type-A"="matrix_type::sint8" "sycl-joint-matrix-mad-type-B"="matrix_type::sint8" "sycl-joint-matrix-mad-type-C"="matrix_type::sint32" "sycl-joint-matrix-mad-type-D"="matrix_type::sint32" } + +; CHECK: ![[#ID0]] = !{!"matrix_type::sint8,use::a,12,12;matrix_type::sint8,use::a,12,48;matrix_type::sint8,use::b,48,12"} +; CHECK: ![[#ID1]] = !{!"matrix_type::sint32,use::accumulator,12,12;matrix_type::sint8,use::a,12,12;matrix_type::sint8,use::a,12,48;matrix_type::sint8,use::b,48,12"} +; CHECK: ![[#ID2]] = !{!"matrix_type::sint8,matrix_type::sint8,matrix_type::sint32,matrix_type::sint32,12,48,12"} +; CHECK: ![[#ID3]] = !{!"matrix_type::sint8,use::b,48,12"} +; CHECK: ![[#ID4]] = !{!"matrix_type::sint8,use::a,12,12"} +; CHECK: ![[#ID5]] = !{!"matrix_type::sint8,use::a,12,48;matrix_type::sint8,use::b,48,12"} diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index 8a9dbc12df2ec..68b29fc07c22a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -33,5 +33,25 @@ class tf32 { } // namespace experimental } // namespace oneapi } // namespace ext + +namespace detail { +using UseToUseStringPair = + std::pair; + +constexpr const char * +convertMatrixUseToString(ext::oneapi::experimental::matrix::use Use) { + constexpr UseToUseStringPair UseToUseStringMap[] = { + {ext::oneapi::experimental::matrix::use::a, "use::a"}, + {ext::oneapi::experimental::matrix::use::b, "use::b"}, + {ext::oneapi::experimental::matrix::use::accumulator, "use::accumulator"}, + }; + + for (const auto &Item : UseToUseStringMap) { + if (Item.first == Use) + return Item.second; + } + return ""; +} +} // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 327e1e326f108..8f715b0ef392c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -18,7 +18,8 @@ #include // for __SYCL_ALWAYS_... #include // for PI_ERROR_INVAL... #include // for runtime_error -#include // for layout, use, tf32 +#include // for layout, use, tf32, convertMatrixUseToString +#include // for convertTypeToMatrixTypeString #include // for marray #include // for multi_ptr @@ -54,6 +55,13 @@ struct joint_matrix { #endif // defined(__NVPTX__) #endif // defined(__SYCL_DEVICE_ONLY__) +#if defined(__SYCL_DEVICE_ONLY__) + [[__sycl_detail__::add_ir_attributes_function( + "sycl-joint-matrix-type", "sycl-joint-matrix-use", + "sycl-joint-matrix-rows", "sycl-joint-matrix-cols", + sycl::detail::convertTypeToMatrixTypeString(), + sycl::detail::convertMatrixUseToString(Use), Rows, Cols)]] +#endif // defined(__SYCL_DEVICE_ONLY__) joint_matrix() { #ifndef __SYCL_DEVICE_ONLY__ throw runtime_error("joint matrix is not supported on host device.", @@ -361,7 +369,19 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( template -inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( +#if defined(__SYCL_DEVICE_ONLY__) +[[__sycl_detail__::add_ir_attributes_function( + "sycl-joint-matrix-mad-type-A", "sycl-joint-matrix-mad-type-B", + "sycl-joint-matrix-mad-type-C", "sycl-joint-matrix-mad-type-D", + "sycl-joint-matrix-mad-size-M", "sycl-joint-matrix-mad-size-K", + "sycl-joint-matrix-mad-size-N", + sycl::detail::convertTypeToMatrixTypeString(), + sycl::detail::convertTypeToMatrixTypeString(), + sycl::detail::convertTypeToMatrixTypeString(), + sycl::detail::convertTypeToMatrixTypeString(), M, K, N)]] +#endif // defined(__SYCL_DEVICE_ONLY__) +inline __SYCL_ALWAYS_INLINE void +joint_matrix_mad( Group, joint_matrix &D, diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index ab9cba1141725..77037885fc28b 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -22,6 +22,5 @@ #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION #if (SYCL_EXT_ONEAPI_MATRIX_VERSION == 4) #include -#include #include #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION diff --git a/sycl/include/sycl/ext/oneapi/matrix/query-types.hpp b/sycl/include/sycl/ext/oneapi/matrix/query-types.hpp index aa9539ec5cc66..66decfeb15e2c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/query-types.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/query-types.hpp @@ -8,6 +8,9 @@ #pragma once +#include // for bfloat16 +#include // for tf32 + namespace sycl { inline namespace _V1 { namespace ext::oneapi::experimental::matrix { @@ -42,5 +45,54 @@ struct combination { }; } // namespace ext::oneapi::experimental::matrix + +namespace detail { +template constexpr const char *convertTypeToMatrixTypeString() { + return ""; +} +template <> +constexpr const char * +convertTypeToMatrixTypeString() { + return "matrix_type::bf16"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::fp16"; +} +template <> +constexpr const char *convertTypeToMatrixTypeString< + sycl::ext::oneapi::experimental::matrix::precision::tf32>() { + return "matrix_type::tf32"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::fp32"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::fp64"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::sint8"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::sint16"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::sint32"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::sint64"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::uint8"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::uint16"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::uint32"; +} +template <> constexpr const char *convertTypeToMatrixTypeString() { + return "matrix_type::uint64"; +} +} // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/test/matrix/matrix-check-types-in-attributes.cpp b/sycl/test/matrix/matrix-check-types-in-attributes.cpp new file mode 100644 index 0000000000000..9b31885f79304 --- /dev/null +++ b/sycl/test/matrix/matrix-check-types-in-attributes.cpp @@ -0,0 +1,52 @@ +// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s + +// This test checks the correctness of matrix types converted into strings + +// "matrix_type,use,rows,cols" +// CHECK: !{!"matrix_type::bf16,use::a,12,12"} +// CHECK: !{!"matrix_type::fp16,use::a,12,12"} +// CHECK: !{!"matrix_type::tf32,use::a,12,12"} +// CHECK: !{!"matrix_type::fp32,use::a,12,12"} +// CHECK: !{!"matrix_type::fp64,use::a,12,12"} +// CHECK: !{!"matrix_type::sint8,use::a,12,12"} +// CHECK: !{!"matrix_type::sint16,use::a,12,12"} +// CHECK: !{!"matrix_type::sint32,use::a,12,12"} +// CHECK: !{!"matrix_type::sint64,use::a,12,12"} +// CHECK: !{!"matrix_type::uint8,use::a,12,12"} +// CHECK: !{!"matrix_type::uint16,use::a,12,12"} +// CHECK: !{!"matrix_type::uint32,use::a,12,12"} +// CHECK: !{!"matrix_type::uint64,use::a,12,12"} + +#include + +using namespace sycl::ext::oneapi::experimental::matrix; + +constexpr size_t Size = 12; + +template void test(sycl::queue &q) { + q.submit([&](sycl::handler &cgh) { + cgh.single_task([]() { + joint_matrix m; + }); + }); +} + +int main() { + sycl::queue q; + + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + test(q); + + return 0; +} diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index c4ab58c1deaec..b13cb23ae73b0 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -4,6 +4,9 @@ // CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) // CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) +// CHECK: !{!"matrix_type::sint32,use::accumulator,12,12;matrix_type::sint8,use::a,12,48;matrix_type::sint8,use::b,48,12"} +// CHECK: !{!"matrix_type::sint8,matrix_type::sint8,matrix_type::sint32,matrix_type::sint32,12,48,12"} + #include #include