JIT: Added SVE APIs CreateMaskForFirstActiveElement and CreateMaskForNextActiveElement #104002

Merged (13 commits) on Jun 29, 2024
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
@@ -1916,6 +1916,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_CreateMaskForFirstActiveElement:
case NI_Sve_CreateMaskForNextActiveElement:
case NI_Sve_GetActiveElementCount:
case NI_Sve_TestAnyTrue:
case NI_Sve_TestFirstTrue:
33 changes: 32 additions & 1 deletion src/coreclr/jit/hwintrinsiccodegenarm64.cpp
@@ -843,7 +843,23 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
assert(!node->IsEmbMaskOp());
if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id))
{
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
if (isRMW)
Member

Can you confirm in which cases we reach this code path vs. the one at the end of this file?

Contributor Author

@TIHan, Jun 27, 2024

CreateMaskForNextActiveElement will reach this path, as will any table-driven HW intrinsic that has two parameters and an explicit masked operation. But CreateMaskForNextActiveElement is the first RMW intrinsic to reach this path.
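
For readers following the thread, the branch under discussion can be sketched outside the JIT as below. This is only an illustrative model, assuming a hypothetical `MockEmitter` that records simplified text instead of encoding instructions; `emitExplicitMasked`, the register numbering, and the `someop` mnemonic are made up for the example and are not JIT APIs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the JIT emitter: it records simplified
// instruction text instead of encoding anything.
struct MockEmitter
{
    std::vector<std::string> log;

    static std::string preg(int r) { return "p" + std::to_string(r); }

    void mov(int dst, int src)
    {
        log.push_back("mov " + preg(dst) + ", " + preg(src));
    }

    // Two-register form: the destination is also the tied source (read-modify-write).
    void ins_R_R(const std::string& ins, int dst, int op1)
    {
        log.push_back(ins + " " + preg(dst) + ", " + preg(op1) + ", " + preg(dst));
    }

    // Three-register form: the destination is independent of both sources.
    void ins_R_R_R(const std::string& ins, int dst, int op1, int op2)
    {
        log.push_back(ins + " " + preg(dst) + ", " + preg(op1) + ", " + preg(op2));
    }
};

// Mirrors the shape of the explicit-masked-operation branch in the diff.
void emitExplicitMasked(MockEmitter& e, const std::string& ins, bool isRMW,
                        int targetReg, int op1Reg, int op2Reg)
{
    if (isRMW)
    {
        if (targetReg != op2Reg)
        {
            // The move below would clobber op1 if it shared the target register.
            assert(targetReg != op1Reg);
            e.mov(targetReg, op2Reg);
        }
        e.ins_R_R(ins, targetReg, op1Reg);
    }
    else
    {
        e.ins_R_R_R(ins, targetReg, op1Reg, op2Reg);
    }
}
```

With `isRMW` set and the target different from `op2Reg`, the sketch records a move followed by the tied two-register form; when the target already equals `op2Reg`, the move is skipped.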

Member

Why can't NI_Sve_CreateMaskForFirstActiveElement be handled here as well, once you remove the SpecialCodeGen flag for it? They both need INS_OPTS_SCALABLE_B as opt. So if this code works for NI_Sve_CreateMaskForNextActiveElement, why can't it work for NI_Sve_CreateMaskForFirstActiveElement?

Contributor Author

FirstActiveElement will as well; sorry I didn't include it in my comment.

Contributor Author

Oh, I see now. Yes, we shouldn't need the SpecialCodeGen for FirstActiveElement.

Contributor Author

Ok, I actually tried it out. So 'pnext' is PNEXT <Pdn>.<T>, <Pv>, <Pdn>.<T> while 'pfirst' is PFIRST <Pdn>.B, <Pg>, <Pdn>.B. So, yea, we actually need to do SpecialCodeGen for FirstActiveElement.

Member

Hmm, I'm not sure I follow yet. The FirstActiveElement code that you have is:

                assert(isRMW);
                assert(HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id));

                if (targetReg != op2Reg)
                {
                    assert(targetReg != op1Reg);
                    GetEmitter()->emitIns_Mov(INS_sve_mov, emitTypeSize(node), targetReg, op2Reg, /* canSkip */ true);
                }

                GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, INS_OPTS_SCALABLE_B);
                break;

If we delete that, it will reach line 846, see that isRMW is true, enter the if block, and call emitIns_R_R(). Unless I am missing something major here.

Contributor Author

At the line you are talking about, the opt that gets passed to emitIns_R_R is not guaranteed to be INS_OPTS_SCALABLE_B.

Member

Ok, that's where the confusion was on my part. I was looking at the unit tests and thought they both take just INS_OPTS_SCALABLE_B, but looking again at https://docsmirror.github.io/A64/2023-06/pnext_p_p_p.html, it can get other values as well. Makes sense now. Thanks!
 

theEmitter->emitIns_R_R(INS_sve_pnext, EA_SCALABLE, REG_P0, REG_P15,
INS_OPTS_SCALABLE_B); // PNEXT <Pdn>.<T>, <Pv>, <Pdn>.<T>

theEmitter->emitIns_R_R(INS_sve_pfirst, EA_SCALABLE, REG_P0, REG_P15,
INS_OPTS_SCALABLE_B); // PFIRST <Pdn>.B, <Pg>, <Pdn>.B
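
The point resolved above is that PNEXT takes an element arrangement `<T>` while PFIRST is fixed to `.B`. A minimal sketch of that distinction follows; it assumes the arrangement is derived from the element size in bytes, and the names `Arrangement`, `PredIns`, and `arrangementFor` are hypothetical illustrations rather than JIT identifiers.

```cpp
#include <cassert>
#include <cstddef>

// Element arrangements usable on an SVE predicate operand.
enum class Arrangement { B, H, S, D };

// The two predicate instructions discussed in the thread.
enum class PredIns { Pfirst, Pnext };

// Map an element size in bytes to an arrangement (hypothetical helper).
Arrangement arrangementForElementSize(std::size_t bytes)
{
    switch (bytes)
    {
        case 1: return Arrangement::B;
        case 2: return Arrangement::H;
        case 4: return Arrangement::S;
        case 8: return Arrangement::D;
        default:
            assert(false && "unsupported element size");
            return Arrangement::B;
    }
}

// PNEXT follows the element type; PFIRST only exists with the .B arrangement,
// so a type-derived option cannot be passed through to it unchanged.
Arrangement arrangementFor(PredIns ins, std::size_t elementBytes)
{
    Arrangement fromType = arrangementForElementSize(elementBytes);
    return (ins == PredIns::Pfirst) ? Arrangement::B : fromType;
}
```

Under these assumptions, a `uint` overload would ask for `.S` on PNEXT but must still use `.B` on PFIRST, which is the reason given in the thread for keeping the special-case codegen.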

Contributor Author

I was looking at the unit tests and thought they both take just INS_OPTS_SCALABLE_B

I did the same thing. :)

{
if (targetReg != op2Reg)
{
assert(targetReg != op1Reg);

GetEmitter()->emitIns_Mov(ins_Move_Extend(intrin.op2->TypeGet(), false),
emitTypeSize(node), targetReg, op2Reg,
/* canSkip */ true);
}

GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
}
else
{
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
}
}
else
{
@@ -2187,6 +2203,21 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_CreateMaskForFirstActiveElement:
{
assert(isRMW);
assert(HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id));

if (targetReg != op2Reg)
{
assert(targetReg != op1Reg);
GetEmitter()->emitIns_Mov(INS_sve_mov, emitTypeSize(node), targetReg, op2Reg, /* canSkip */ true);
}

GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, INS_OPTS_SCALABLE_B);
break;
}

default:
unreached();
}
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
@@ -47,6 +47,8 @@ HARDWARE_INTRINSIC(Sve, CreateFalseMaskSingle,
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt16, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt32, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt64, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateMaskForFirstActiveElement, -1, 2, true, {INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
Member

I don't think you need SpecialCodeGen for this. It should be handled the same way CreateMaskForNextActiveElement is handled.

Contributor Author

I need it in order to pass INS_OPTS_SCALABLE_B when emitting the instruction.

HARDWARE_INTRINSIC(Sve, CreateMaskForNextActiveElement, -1, 2, true, {INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskByte, -1, 1, false, {INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskDouble, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskInt16, -1, 1, false, {INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
1 change: 0 additions & 1 deletion src/coreclr/jit/instr.cpp
@@ -1721,7 +1721,6 @@ instruction CodeGen::ins_Move_Extend(var_types srcType, bool srcInReg)
#if defined(TARGET_XARCH)
return INS_kmovq_msk;
#elif defined(TARGET_ARM64)
unreached(); // TODO-SVE: This needs testing
return INS_sve_mov;
#endif
}
9 changes: 8 additions & 1 deletion src/coreclr/jit/lsraarm64.cpp
@@ -1627,7 +1627,14 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
predMask = RBM_LOWMASK.GetPredicateRegSet();
}

srcCount += BuildOperandUses(intrin.op1, predMask);
if (tgtPrefOp2)
{
srcCount += BuildDelayFreeUses(intrin.op1, intrin.op2, predMask);
}
else
{
srcCount += BuildOperandUses(intrin.op1, predMask);
}
}
}
else if (intrinsicTree->OperIsMemoryLoadOrStore())
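
The switch to `BuildDelayFreeUses` for `op1` appears to guard against the hazard sketched below: with the target preferenced to `op2`, codegen first moves `op2` into the target and only then reads `op1`, so `op1` must not share the target register. The model is a toy under stated assumptions: `PredFile`, `emitRmw`, and `expectedResult` are hypothetical, the bitwise AND merely stands in for a tied-operand instruction, and the codegen assert is deliberately left out so the bad allocation can be observed.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Four toy predicate registers, 16 element bits each.
using PredFile = std::array<std::uint16_t, 4>;

// Toy RMW sequence: "mov target, op2" followed by "target = target & op1".
// PFIRST/PNEXT compute something else; only the read-modify-write shape matters.
void emitRmw(PredFile& regs, int targetReg, int op1Reg, int op2Reg)
{
    if (targetReg != op2Reg)
    {
        regs[targetReg] = regs[op2Reg]; // overwrites op1 if it lives in targetReg
    }
    regs[targetReg] = static_cast<std::uint16_t>(regs[targetReg] & regs[op1Reg]);
}

// Reference result computed from the inputs before any register is written.
std::uint16_t expectedResult(const PredFile& regs, int op1Reg, int op2Reg)
{
    return static_cast<std::uint16_t>(regs[op2Reg] & regs[op1Reg]);
}
```

Giving the target its own register, or reusing `op2`'s register, produces the reference result; reusing `op1`'s register does not, which is the allocation a delay-free use of `op1` is meant to rule out.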
@@ -1016,6 +1016,79 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateFalseMaskUInt64() { throw new PlatformNotSupportedException(); }


/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForFirstActiveElement(Vector<byte> mask, Vector<byte> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<short> CreateMaskForFirstActiveElement(Vector<short> mask, Vector<short> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<int> CreateMaskForFirstActiveElement(Vector<int> mask, Vector<int> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<long> CreateMaskForFirstActiveElement(Vector<long> mask, Vector<long> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<sbyte> CreateMaskForFirstActiveElement(Vector<sbyte> mask, Vector<sbyte> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ushort> CreateMaskForFirstActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<uint> CreateMaskForFirstActiveElement(Vector<uint> mask, Vector<uint> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ulong> CreateMaskForFirstActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b8(svbool_t pg, svbool_t op)
/// PNEXT Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForNextActiveElement(Vector<byte> mask, Vector<byte> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b16(svbool_t pg, svbool_t op)
/// PNEXT Ptied.H, Pg, Ptied.H
/// </summary>
public static unsafe Vector<ushort> CreateMaskForNextActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b32(svbool_t pg, svbool_t op)
/// PNEXT Ptied.S, Pg, Ptied.S
/// </summary>
public static unsafe Vector<uint> CreateMaskForNextActiveElement(Vector<uint> mask, Vector<uint> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b64(svbool_t pg, svbool_t op)
/// PNEXT Ptied.D, Pg, Ptied.D
/// </summary>
public static unsafe Vector<ulong> CreateMaskForNextActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) { throw new PlatformNotSupportedException(); }


/// CreateTrueMaskByte : Set predicate elements to true

/// <summary>
@@ -1073,6 +1073,79 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateFalseMaskUInt64() => CreateFalseMaskUInt64();


/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForFirstActiveElement(Vector<byte> mask, Vector<byte> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<short> CreateMaskForFirstActiveElement(Vector<short> mask, Vector<short> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<int> CreateMaskForFirstActiveElement(Vector<int> mask, Vector<int> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<long> CreateMaskForFirstActiveElement(Vector<long> mask, Vector<long> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<sbyte> CreateMaskForFirstActiveElement(Vector<sbyte> mask, Vector<sbyte> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ushort> CreateMaskForFirstActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<uint> CreateMaskForFirstActiveElement(Vector<uint> mask, Vector<uint> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ulong> CreateMaskForFirstActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b8(svbool_t pg, svbool_t op)
/// PNEXT Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForNextActiveElement(Vector<byte> mask, Vector<byte> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b16(svbool_t pg, svbool_t op)
/// PNEXT Ptied.H, Pg, Ptied.H
/// </summary>
public static unsafe Vector<ushort> CreateMaskForNextActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b32(svbool_t pg, svbool_t op)
/// PNEXT Ptied.S, Pg, Ptied.S
/// </summary>
public static unsafe Vector<uint> CreateMaskForNextActiveElement(Vector<uint> mask, Vector<uint> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b64(svbool_t pg, svbool_t op)
/// PNEXT Ptied.D, Pg, Ptied.D
/// </summary>
public static unsafe Vector<ulong> CreateMaskForNextActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);


/// CreateTrueMaskByte : Set predicate elements to true

/// <summary>
@@ -4335,6 +4335,20 @@ internal Arm64() { }
public static System.Numerics.Vector<ushort> CreateFalseMaskUInt16() { throw null; }
public static System.Numerics.Vector<uint> CreateFalseMaskUInt32() { throw null; }
public static System.Numerics.Vector<ulong> CreateFalseMaskUInt64() { throw null; }

public static unsafe System.Numerics.Vector<byte> CreateMaskForFirstActiveElement(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<short> CreateMaskForFirstActiveElement(System.Numerics.Vector<short> mask, System.Numerics.Vector<short> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<int> CreateMaskForFirstActiveElement(System.Numerics.Vector<int> mask, System.Numerics.Vector<int> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<long> CreateMaskForFirstActiveElement(System.Numerics.Vector<long> mask, System.Numerics.Vector<long> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<sbyte> CreateMaskForFirstActiveElement(System.Numerics.Vector<sbyte> mask, System.Numerics.Vector<sbyte> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ushort> CreateMaskForFirstActiveElement(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<uint> CreateMaskForFirstActiveElement(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ulong> CreateMaskForFirstActiveElement(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<byte> CreateMaskForNextActiveElement(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ushort> CreateMaskForNextActiveElement(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<uint> CreateMaskForNextActiveElement(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ulong> CreateMaskForNextActiveElement(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> srcMask) { throw null; }

public static System.Numerics.Vector<byte> CreateTrueMaskByte([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static System.Numerics.Vector<double> CreateTrueMaskDouble([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static System.Numerics.Vector<short> CreateTrueMaskInt16([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
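
To make the behaviour behind the new `CreateMaskForFirstActiveElement` and `CreateMaskForNextActiveElement` signatures concrete, here is an element-level model of PFIRST and PNEXT. It is a sketch based on my reading of Arm's instruction descriptions, not the runtime's implementation: `Mask`, `firstActive`, and `nextActive` are hypothetical names, one `bool` stands for one element, and the per-element bit layout of real predicate registers is ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One bool per element; real predicates have one bit per byte of the vector.
using Mask = std::vector<bool>;

// PFIRST-like: pass srcMask through, additionally setting the element at the
// first true position of the governing mask (if there is one).
Mask firstActive(const Mask& mask, const Mask& srcMask)
{
    assert(mask.size() == srcMask.size());
    Mask result = srcMask;
    for (std::size_t i = 0; i < mask.size(); ++i)
    {
        if (mask[i])
        {
            result[i] = true;
            break;
        }
    }
    return result;
}

// PNEXT-like: find the first true element of the governing mask that comes
// after the last true element of srcMask; the result has only that element set.
Mask nextActive(const Mask& mask, const Mask& srcMask)
{
    assert(mask.size() == srcMask.size());
    std::size_t next = 0;
    for (std::size_t i = 0; i < srcMask.size(); ++i)
    {
        if (srcMask[i])
        {
            next = i + 1; // one past the last true element seen so far
        }
    }
    while (next < mask.size() && !mask[next])
    {
        ++next;
    }
    Mask result(mask.size(), false);
    if (next < mask.size())
    {
        result[next] = true;
    }
    return result;
}
```

Starting from an all-false `srcMask` and feeding each result back in as the next `srcMask` visits the true elements of `mask` one at a time, ending with an all-false result once they are exhausted.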