Skip to content

Commit

Permalink
JIT: Added Sve.CreateBreakPropagateMask (#104704)
Browse files Browse the repository at this point in the history
* Added Sve.CreateBreakPropagateMask

* Added assert

* Fixed targetReg and maskReg using the same register

* Minor rename

* Formatting

* No need to use predMask

* More formatting

* Add additional comment

* Feedback

* fix lsra

* fix build error

---------

Co-authored-by: Kunal Pathak <[email protected]>
  • Loading branch information
TIHan and kunalspathak authored Jul 17, 2024
1 parent fa77959 commit 5677d92
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 38 deletions.
47 changes: 27 additions & 20 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2286,10 +2286,16 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}

#if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_ARM64)
// Ensures the operand referenced by `op` is mask-typed. An operand for which
// varTypeIsMask() already holds is left untouched; any other operand is replaced,
// through the reference, by a vector-to-mask conversion node built from the
// enclosing simdBaseJitType and simdSize.
auto convertToMaskIfNeeded = [&](GenTree*& op) {
    if (varTypeIsMask(op))
    {
        // Already a mask, nothing to convert.
        return;
    }

    op = gtNewSimdCvtVectorToMaskNode(TYP_MASK, op, simdBaseJitType, simdSize);
};

if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrinsic))
{
assert(numArgs > 0);
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);

switch (intrinsic)
{
Expand All @@ -2304,14 +2310,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case NI_Sve_TestFirstTrue:
case NI_Sve_TestLastTrue:
{
GenTree* op2 = retNode->AsHWIntrinsic()->Op(2);

// HWInstrinsic requires a mask for op2
if (!varTypeIsMask(op2))
{
retNode->AsHWIntrinsic()->Op(2) =
gtNewSimdCvtVectorToMaskNode(TYP_MASK, op2, simdBaseJitType, simdSize);
}
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(2));
break;
}

Expand All @@ -2324,26 +2324,17 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case NI_Sve_CreateBreakAfterPropagateMask:
case NI_Sve_CreateBreakBeforePropagateMask:
{
GenTree* op3 = retNode->AsHWIntrinsic()->Op(3);

// HWInstrinsic requires a mask for op3
if (!varTypeIsMask(op3))
{
retNode->AsHWIntrinsic()->Op(3) =
gtNewSimdCvtVectorToMaskNode(TYP_MASK, op3, simdBaseJitType, simdSize);
}
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(3));
break;
}

default:
break;
}

if (!varTypeIsMask(op1))
{
// Op1 input is a vector. HWInstrinsic requires a mask.
retNode->AsHWIntrinsic()->Op(1) = gtNewSimdCvtVectorToMaskNode(TYP_MASK, op1, simdBaseJitType, simdSize);
}
// HWIntrinsic requires a mask for op1
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(1));

if (HWIntrinsicInfo::IsMultiReg(intrinsic))
{
Expand All @@ -2354,6 +2345,22 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
}

if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsic))
{
switch (intrinsic)
{
case NI_Sve_CreateBreakPropagateMask:
{
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(1));
convertToMaskIfNeeded(retNode->AsHWIntrinsic()->Op(2));
break;
}

default:
break;
}
}

if (retType != nodeRetType)
{
// HWIntrinsic returns a mask, but all returns must be vectors, so convert mask to vector.
Expand Down
34 changes: 26 additions & 8 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;
bool hasShift = false;

insOpts embOpt = opt;
switch (intrinEmbMask.id)
{
case NI_Sve_ShiftLeftLogical:
Expand All @@ -689,6 +690,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
hasShift = true;
break;

case NI_Sve_CreateBreakPropagateMask:
embOpt = INS_OPTS_SCALABLE_B;
break;

default:
break;
}
Expand All @@ -699,13 +704,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op2, op2->AsHWIntrinsic());
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(), opt,
sopt);
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(),
embOpt, sopt);
}
}
else
{
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg3, opt, sopt);
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg3, embOpt, sopt);
}
};

Expand All @@ -714,12 +719,25 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
// If `falseReg` is zero, then move the first operand of `intrinEmbMask` into the
// destination using /Z.

assert(targetReg != embMaskOp2Reg);
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
switch (intrinEmbMask.id)
{
case NI_Sve_CreateBreakPropagateMask:
assert(targetReg != embMaskOp1Reg);
GetEmitter()->emitIns_Mov(INS_sve_mov, emitSize, targetReg, embMaskOp2Reg,
/* canSkip */ true);
emitInsHelper(targetReg, maskReg, embMaskOp1Reg);
break;

// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
// and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
default:
assert(targetReg != embMaskOp2Reg);
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg,
embMaskOp1Reg, opt);

// Finally, perform the actual "predicated" operation so that `targetReg` is the first
// operand and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
break;
}
}
else if (targetReg != falseReg)
{
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ HARDWARE_INTRINSIC(Sve, CreateBreakAfterMask,
HARDWARE_INTRINSIC(Sve, CreateBreakAfterPropagateMask, -1, 3, true, {INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_sve_brkpa, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, CreateBreakBeforeMask, -1, 2, true, {INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_sve_brkb, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, CreateBreakBeforePropagateMask, -1, 3, true, {INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_sve_brkpb, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, CreateBreakPropagateMask, -1, -1, false, {INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_sve_brkn, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskByte, -1, 0, false, {INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskDouble, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskInt16, -1, 0, false, {INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
Expand Down
40 changes: 31 additions & 9 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,8 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
}
else
{
SingleTypeRegSet predMask = RBM_ALLMASK.GetPredicateRegSet();
bool tgtPrefEmbOp2 = false;
SingleTypeRegSet predMask = RBM_ALLMASK.GetPredicateRegSet();
if (intrin.id == NI_Sve_ConditionalSelect)
{
// If this is conditional select, make sure to check the embedded
Expand All @@ -1658,16 +1659,26 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
{
predMask = RBM_LOWMASK.GetPredicateRegSet();
}

// Special-case, CreateBreakPropagateMask's op2 is the RMW node.
if (intrinEmb.id == NI_Sve_CreateBreakPropagateMask)
{
assert(embOp2Node->isRMWHWIntrinsic(compiler));
assert(!tgtPrefOp1);
assert(!tgtPrefOp2);
tgtPrefEmbOp2 = true;
}
}
}
else if (HWIntrinsicInfo::IsLowMaskedOperation(intrin.id))
{
predMask = RBM_LOWMASK.GetPredicateRegSet();
}

if (tgtPrefOp2)
if (tgtPrefOp2 || tgtPrefEmbOp2)
{
srcCount += BuildDelayFreeUses(intrin.op1, intrin.op2, predMask);
assert(!tgtPrefOp1);
srcCount += BuildDelayFreeUses(intrin.op1, nullptr, predMask);
}
else
{
Expand Down Expand Up @@ -1983,15 +1994,26 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
break;
}

tgtPrefUse = BuildUse(embOp2Node->Op(1));
srcCount += 1;

for (size_t argNum = 2; argNum <= numArgs; argNum++)
size_t prefUseOpNum = 1;
if (intrinEmb.id == NI_Sve_CreateBreakPropagateMask)
{
prefUseOpNum = 2;
}
GenTree* prefUseNode = embOp2Node->Op(prefUseOpNum);
for (size_t argNum = 1; argNum <= numArgs; argNum++)
{
srcCount += BuildDelayFreeUses(embOp2Node->Op(argNum), embOp2Node->Op(1));
if (argNum == prefUseOpNum)
{
tgtPrefUse = BuildUse(prefUseNode);
srcCount += 1;
}
else
{
srcCount += BuildDelayFreeUses(embOp2Node->Op(argNum), prefUseNode);
}
}

srcCount += BuildDelayFreeUses(intrin.op3, embOp2Node->Op(1));
srcCount += BuildDelayFreeUses(intrin.op3, prefUseNode);
}
}
else if (intrin.op2 != nullptr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,55 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateBreakBeforePropagateMask(Vector<ulong> mask, Vector<ulong> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }


/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<sbyte> CreateBreakPropagateMask(Vector<sbyte> totalMask, Vector<sbyte> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<short> CreateBreakPropagateMask(Vector<short> totalMask, Vector<short> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<int> CreateBreakPropagateMask(Vector<int> totalMask, Vector<int> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<long> CreateBreakPropagateMask(Vector<long> totalMask, Vector<long> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<byte> CreateBreakPropagateMask(Vector<byte> totalMask, Vector<byte> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<ushort> CreateBreakPropagateMask(Vector<ushort> totalMask, Vector<ushort> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<uint> CreateBreakPropagateMask(Vector<uint> totalMask, Vector<uint> fromMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown; this implementation does nothing else.</exception>
public static unsafe Vector<ulong> CreateBreakPropagateMask(Vector<ulong> totalMask, Vector<ulong> fromMask) { throw new PlatformNotSupportedException(); }


/// Set all predicate elements to false

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2346,6 +2346,55 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateBreakBeforePropagateMask(Vector<ulong> mask, Vector<ulong> left, Vector<ulong> right) => CreateBreakBeforePropagateMask(mask, left, right);


// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<sbyte> CreateBreakPropagateMask(Vector<sbyte> totalMask, Vector<sbyte> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<short> CreateBreakPropagateMask(Vector<short> totalMask, Vector<short> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<int> CreateBreakPropagateMask(Vector<int> totalMask, Vector<int> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<long> CreateBreakPropagateMask(Vector<long> totalMask, Vector<long> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<byte> CreateBreakPropagateMask(Vector<byte> totalMask, Vector<byte> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<ushort> CreateBreakPropagateMask(Vector<ushort> totalMask, Vector<ushort> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<uint> CreateBreakPropagateMask(Vector<uint> totalMask, Vector<uint> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);

// The body is a call to this same overload with both arguments passed through unchanged.
// NOTE(review): presumably the JIT replaces the call with the BRKN instruction named in
// the summary instead of recursing; confirm.
/// <summary>
/// svbool_t svbrkn[_b]_z(svbool_t pg, svbool_t op1, svbool_t op2)
/// BRKN Ptied2.B, Pg/Z, Pop1.B, Ptied2.B
/// </summary>
public static unsafe Vector<ulong> CreateBreakPropagateMask(Vector<ulong> totalMask, Vector<ulong> fromMask) => CreateBreakPropagateMask(totalMask, fromMask);


/// Set all predicate elements to false

/// <summary>
Expand Down
Loading

0 comments on commit 5677d92

Please sign in to comment.