Skip to content

Commit

Permalink
[SYCL][Fusion] Refine remapping of GEP instruction during internaliza…
Browse files Browse the repository at this point in the history
…tion (intel#12128)

So far, we distinguished GEP instructions that select an element of the
internalized buffer (must be remapped) and GEPs that address into an
aggregate type (must _not_ be remapped) by looking at the number of
indices.

However, we can also encounter single-index GEP instructions that use a
byte-offset to address into padded structures, as well as multi-index
GEPs with a base pointer offset that address into an aggregate _and_
need to be remapped.

This PR adds uses the newly added element size information (intel#12108) to
correctly distinguish the required action for these kinds of GEP
instructions. In addition to the E2E tests, target-specific lit tests
derived from it are also added to demonstrate the subtle differences
among SPIR-V, CUDA and HIP.

---------

Signed-off-by: Julian Oppermann <[email protected]>
  • Loading branch information
jopperm authored Jan 3, 2024
1 parent 8ab4907 commit a90aaa7
Show file tree
Hide file tree
Showing 5 changed files with 498 additions and 24 deletions.
82 changes: 58 additions & 24 deletions sycl-fusion/passes/internalization/Internalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <llvm/ADT/BitVector.h>
#include <llvm/ADT/TypeSwitch.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/PatternMatch.h>
#include <llvm/Support/WithColor.h>
#include <llvm/Transforms/Utils/Cloning.h>

Expand All @@ -25,6 +26,7 @@
#define DEBUG_TYPE "sycl-fusion"

using namespace llvm;
using namespace PatternMatch;

constexpr static StringLiteral PrivatePromotion{"private"};
constexpr static StringLiteral LocalPromotion{"local"};
Expand Down Expand Up @@ -191,22 +193,10 @@ static void updateInternalizationMD(Function *F, StringRef Kind,
/// address space has changed from N to N / LocalSize.
static void remap(GetElementPtrInst *GEPI, const PromotionInfo &PromInfo) {
IRBuilder<> Builder{GEPI};
Value *C0 = Builder.getInt64(0);

auto NIdx = GEPI->getNumIndices();
if (NIdx > 1) {
// `GEPI` indexes into an aggregate. If the first index is 0, the base
// pointer is used as-is and we do not need to perform remapping. This is
// the common case.
// TODO: Support non-zero pointer offset, too. If the pointer operand is
// a GEP as well, we must check if the source element types match.
assert(GEPI->idx_begin()->get() == C0);
return;
}

if (PromInfo.LocalSize == 1) {
// Squash the index and let instcombine clean-up afterwards.
GEPI->idx_begin()->set(C0);
GEPI->idx_begin()->set(Builder.getInt64(0));
return;
}

Expand Down Expand Up @@ -290,6 +280,43 @@ Error SYCLInternalizerImpl::canPromoteCall(CallBase *C, const Value *Val,
return Error::success();
}

enum GEPKind { INVALID = 0, NEEDS_REMAPPING, ADDRESSES_INTO_AGGREGATE };

static int getGEPKind(GetElementPtrInst *GEPI, const PromotionInfo &PromInfo) {
assert(GEPI->getNumIndices() >= 1 && "No-op GEP encountered");

// Inspect the GEP's source element type.
auto &DL = GEPI->getModule()->getDataLayout();
auto SrcElemTySz = DL.getTypeAllocSize(GEPI->getSourceElementType());

// `GEPI`'s first index is selecting elements. Unless it is constant zero, we
// have to remap. If there are more indices, we start to address into an
// aggregate type.
if (SrcElemTySz == PromInfo.ElemSize) {
int Kind = INVALID;
if (!match(GEPI->idx_begin()->get(), m_ZeroInt()))
Kind |= NEEDS_REMAPPING;
if (GEPI->getNumIndices() >= 2)
Kind |= ADDRESSES_INTO_AGGREGATE;
assert(Kind != INVALID && "No-op GEP encountered");
return Kind;
}

// Check whether `GEPI` adds a constant offset, e.g. a byte offset to address
// into a padded structure, smaller than the element size.
MapVector<Value *, APInt> VariableOffsets;
auto IW = DL.getIndexSizeInBits(GEPI->getPointerAddressSpace());
APInt ConstantOffset = APInt::getZero(IW);
if (GEPI->collectOffset(DL, IW, VariableOffsets, ConstantOffset) &&
VariableOffsets.empty() &&
ConstantOffset.getZExtValue() < PromInfo.ElemSize) {
return ADDRESSES_INTO_AGGREGATE;
}

// We don't know what `GEPI` addresses; bail out.
return INVALID;
}

Error SYCLInternalizerImpl::canPromoteGEP(GetElementPtrInst *GEPI,
const Value *Val,
const PromotionInfo &PromInfo,
Expand All @@ -299,12 +326,17 @@ Error SYCLInternalizerImpl::canPromoteGEP(GetElementPtrInst *GEPI,
// required.
return Error::success();
}
// Recurse to check all users of the GEP. We are either already in
// `InAggregate` mode, or inspect the current instruction. Recall that a GEP's
// first index is used to step through the base pointer, whereas any
// additional indices represent addressing into an aggregrate type.

// Inspect the current instruction.
auto Kind = getGEPKind(GEPI, PromInfo);
if (Kind == INVALID) {
return createStringError(inconvertibleErrorCode(),
"Unsupported pointer arithmetic");
}

// Recurse to check all users of the GEP.
return canPromoteValue(GEPI, PromInfo,
InAggregate || GEPI->getNumIndices() >= 2);
InAggregate || (Kind & ADDRESSES_INTO_AGGREGATE));
}

Error SYCLInternalizerImpl::canPromoteValue(Value *Val,
Expand Down Expand Up @@ -423,15 +455,17 @@ void SYCLInternalizerImpl::promoteGEPI(GetElementPtrInst *GEPI,
bool InAggregate) const {
// Not PointerType is unreachable. Other case is caught in caller.
if (cast<PointerType>(GEPI->getType())->getAddressSpace() != AS) {
if (!InAggregate)
auto Kind = getGEPKind(GEPI, PromInfo);
assert(Kind != INVALID);

if (!InAggregate && (Kind & NEEDS_REMAPPING)) {
remap(GEPI, PromInfo);
}
GEPI->mutateType(PointerType::get(GEPI->getContext(), AS));
// Recurse to promote to all users of the GEP. We are either already in
// `InAggregate` mode, or inspect the current instruction. Recall that a
// GEP's first index is used to step through the base pointer, whereas any
// additional indices represent addressing into an aggregrate type.

// Recurse to promote to all users of the GEP.
return promoteValue(GEPI, PromInfo,
InAggregate || GEPI->getNumIndices() >= 2);
InAggregate || (Kind & ADDRESSES_INTO_AGGREGATE));
}
}

Expand Down
116 changes: 116 additions & 0 deletions sycl-fusion/test/internalization/promote-private-non-unit-cuda.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
; REQUIRES: cuda
; RUN: opt -load-pass-plugin %shlibdir/SYCLKernelFusion%shlibext \
; RUN: -passes=sycl-internalization -S %s | FileCheck %s

; This test is a reduced IR version of
; sycl/test-e2e/KernelFusion/internalize_non_unit_localsize.cpp for CUDA

target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
%"class.sycl::_V1::detail::array" = type { [1 x i64] }
%struct.MyStruct = type { i32, %"class.sycl::_V1::vec" }
%"class.sycl::_V1::vec" = type { <3 x i32> }

declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #0
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
declare ptr @llvm.nvvm.implicit.offset() #1

define void @fused_0(ptr addrspace(1) nocapture noundef align 16 %KernelOne__arg_accTmp,
ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelOne__arg_accTmp3,
ptr addrspace(1) nocapture noundef readonly align 4 %KernelOne__arg_accIn,
ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelOne__arg_accIn6,
ptr addrspace(1) nocapture noundef align 1 %KernelOne__arg_accTmp27,
ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelOne__arg_accTmp210,
ptr addrspace(1) nocapture noundef writeonly align 4 %KernelTwo__arg_accOut,
ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 %KernelTwo__arg_accOut3)
local_unnamed_addr #3 !sycl.kernel.promote !13 !sycl.kernel.promote.localsize !14 !sycl.kernel.promote.elemsize !15 {
; CHECK-LABEL: define void @fused_0(
; CHECK-SAME: ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 [[KERNELONE__ARG_ACCTMP3:%[^,]*accTmp3]],
; CHECK-SAME: ptr nocapture noundef readonly byval(%"class.sycl::_V1::range") align 8 [[KERNELONE__ARG_ACCTMP210:%[^,]*accTmp210]]
; CHECK: entry:
; CHECK: [[TMP0:%.*]] = alloca i8, i64 3, align 1
; CHECK: [[TMP1:%.*]] = alloca i8, i64 96, align 16
; CHECK: [[KERNELONE__ARG_ACCTMP2103_SROA_0_0_COPYLOAD:%.*]] = load i64, ptr [[KERNELONE__ARG_ACCTMP210]], align 8
; CHECK: [[KERNELONE__ARG_ACCTMP31_SROA_0_0_COPYLOAD:%.*]] = load i64, ptr [[KERNELONE__ARG_ACCTMP3]], align 8
; CHECK: [[TMP2:%.*]] = urem i64 [[KERNELONE__ARG_ACCTMP31_SROA_0_0_COPYLOAD]], 3
; CHECK: [[TMP3:%.*]] = urem i64 [[KERNELONE__ARG_ACCTMP2103_SROA_0_0_COPYLOAD]], 3
; CHECK: [[MUL:%.*]] = mul nuw nsw i64 [[GLOBAL_ID:.*]], 3
; CHECK: [[ADD:%.*]] = add nuw nsw i64 [[MUL]], 1
; CHECK: [[TMP10:%.*]] = add i64 [[TMP2]], [[ADD]]
; CHECK: [[TMP11:%.*]] = urem i64 [[TMP10]], 3
; CHECK: [[ARRAYIDX_1:%.*]] = getelementptr inbounds %struct.MyStruct, ptr [[TMP1]], i64 [[TMP11]]

; COM: This i8-GEP _was_ not remapped because it addresses into a single MyStruct element
; CHECK: [[ARRAYIDX_2:%.*]] = getelementptr inbounds i8, ptr [[ARRAYIDX_1]], i64 20
; CHECK: store i32 {{.*}}, ptr [[ARRAYIDX_2]], align 4
; CHECK: [[TMP12:%.*]] = add i64 [[TMP3]], [[ADD]]
; CHECK: [[TMP13:%.*]] = urem i64 [[TMP12]], 3

; COM: This i8-GEP was remapped because it selects an element of the underlying i8-buffer
; CHECK: [[ARRAYIDX_3:%.*]] = getelementptr inbounds i8, ptr [[TMP0]], i64 [[TMP13]]

; CHECK: store i8 {{.*}}, ptr [[ARRAYIDX_3]], align 1
; CHECK: store i32 {{.*}}, ptr addrspace(1)
; CHECK: ret void
;
entry:
%KernelOne__arg_accTmp2103.sroa.0.0.copyload = load i64, ptr %KernelOne__arg_accTmp210, align 8
%KernelOne__arg_accIn62.sroa.0.0.copyload = load i64, ptr %KernelOne__arg_accIn6, align 8
%KernelOne__arg_accTmp31.sroa.0.0.copyload = load i64, ptr %KernelOne__arg_accTmp3, align 8
%add.ptr.j2 = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %KernelOne__arg_accTmp, i64 %KernelOne__arg_accTmp31.sroa.0.0.copyload
%add.ptr.i37.i = getelementptr inbounds i32, ptr addrspace(1) %KernelOne__arg_accIn, i64 %KernelOne__arg_accIn62.sroa.0.0.copyload
%add.ptr.i43.i = getelementptr inbounds i8, ptr addrspace(1) %KernelOne__arg_accTmp27, i64 %KernelOne__arg_accTmp2103.sroa.0.0.copyload
%0 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%conv.i1.j7 = sext i32 %0 to i64
%1 = tail call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%conv.i3.j7 = sext i32 %1 to i64
%mul.j7 = mul nsw i64 %conv.i3.j7, %conv.i1.j7
%2 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%conv.i2.j7 = sext i32 %2 to i64
%add.j7 = add nsw i64 %mul.j7, %conv.i2.j7
%3 = tail call ptr @llvm.nvvm.implicit.offset()
%4 = load i32, ptr %3, align 4
%conv.j8 = zext i32 %4 to i64
%add4.j7 = add nsw i64 %add.j7, %conv.j8
%mul.j2 = mul nuw nsw i64 %add4.j7, 3
%add.j2 = add nuw nsw i64 %mul.j2, 1
%arrayidx.j2 = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i37.i, i64 %add.j2
%5 = load i32, ptr addrspace(1) %arrayidx.j2, align 4
%arrayidx.i55.i = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %add.ptr.j2, i64 %add.j2
%arrayidx.j3 = getelementptr inbounds i8, ptr addrspace(1) %arrayidx.i55.i, i64 20
store i32 %5, ptr addrspace(1) %arrayidx.j3, align 4
%conv.j2 = trunc i32 %5 to i8
%arrayidx.i73.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i43.i, i64 %add.j2
store i8 %conv.j2, ptr addrspace(1) %arrayidx.i73.i, align 1
%KernelTwo__arg_accOut34.sroa.0.0.copyload = load i64, ptr %KernelTwo__arg_accOut3, align 8
%add.ptr.i.i7 = getelementptr inbounds i32, ptr addrspace(1) %KernelTwo__arg_accOut, i64 %KernelTwo__arg_accOut34.sroa.0.0.copyload
%6 = load i32, ptr %3, align 4
%conv.j7.i13 = zext i32 %6 to i64
%add4.j6.i14 = add nsw i64 %add.j7, %conv.j7.i13
%mul.i.i16 = mul nuw nsw i64 %add4.j6.i14, 3
%add.i45.i = add nuw nsw i64 %mul.i.i16, 1
%arrayidx.i.i17 = getelementptr inbounds %struct.MyStruct, ptr addrspace(1) %add.ptr.j2, i64 %add.i45.i
%arrayidx.j2.i19 = getelementptr inbounds i8, ptr addrspace(1) %arrayidx.i.i17, i64 20
%7 = load i32, ptr addrspace(1) %arrayidx.j2.i19, align 4
%arrayidx.i55.i20 = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i43.i, i64 %add.i45.i
%8 = load i8, ptr addrspace(1) %arrayidx.i55.i20, align 1
%conv.i.i22 = sext i8 %8 to i32
%add.i.i23 = add nsw i32 %7, %conv.i.i22
%arrayidx.i64.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i7, i64 %add.i45.i
store i32 %add.i.i23, ptr addrspace(1) %arrayidx.i64.i, align 4
ret void
}

attributes #0 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #1 = { nofree nosync nounwind speculatable memory(none) }
attributes #3 = { nofree nosync nounwind memory(read, argmem: readwrite, inaccessiblemem: write) "frame-pointer"="all" "target-cpu"="sm_80" "target-features"="+ptx82,+sm_80" "uniform-work-group-size"="true" }

!nvvm.annotations = !{!10}

!10 = !{ptr @fused_0, !"kernel", i32 1}
!13 = !{!"private", !"none", !"none", !"none", !"private", !"none", !"none", !"none"}
!14 = !{i64 3, !"", !"", !"", i64 3, !"", !"", !""}
!15 = !{i64 32, !"", !"", !"", i64 1, !"", !"", !""}
Loading

0 comments on commit a90aaa7

Please sign in to comment.