Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,7 @@ struct GenTree
inline bool IsVectorZero() const;
inline bool IsVectorCreate() const;
inline bool IsVectorAllBitsSet() const;
inline bool IsMaskAllBitsSet() const;
inline bool IsVectorConst();

inline uint64_t GetIntegralVectorConstElement(size_t index, var_types simdBaseType);
Expand Down Expand Up @@ -9238,6 +9239,32 @@ inline bool GenTree::IsVectorAllBitsSet() const
return false;
}

inline bool GenTree::IsMaskAllBitsSet() const
{
#ifdef TARGET_ARM64
static_assert_no_msg(AreContiguous(NI_Sve_CreateTrueMaskByte, NI_Sve_CreateTrueMaskDouble,
NI_Sve_CreateTrueMaskInt16, NI_Sve_CreateTrueMaskInt32,
NI_Sve_CreateTrueMaskInt64, NI_Sve_CreateTrueMaskSByte,
NI_Sve_CreateTrueMaskSingle, NI_Sve_CreateTrueMaskUInt16,
NI_Sve_CreateTrueMaskUInt32, NI_Sve_CreateTrueMaskUInt64));

if (OperIsHWIntrinsic())
{
NamedIntrinsic id = AsHWIntrinsic()->GetHWIntrinsicId();
if (id == NI_Sve_ConvertMaskToVector)
{
GenTree* op1 = AsHWIntrinsic()->Op(1);
assert(op1->OperIsHWIntrinsic());
id = op1->AsHWIntrinsic()->GetHWIntrinsicId();
}
return ((id == NI_Sve_CreateTrueMaskAll) ||
((id >= NI_Sve_CreateTrueMaskByte) && (id <= NI_Sve_CreateTrueMaskUInt64)));
}

#endif
return false;
}

//-------------------------------------------------------------------
// IsVectorConst: returns true if this node is a HWIntrinsic that represents a constant.
//
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
if (intrinsic == NI_Sve_ConditionalSelect)
{
if (op1->IsVectorAllBitsSet())
if (op1->IsVectorAllBitsSet() || op1->IsMaskAllBitsSet())
{
return retNode->AsHWIntrinsic()->Op(2);
}
Expand Down
48 changes: 46 additions & 2 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
// Handle case where op2 is operation that needs embedded mask
GenTree* op2 = intrin.op2;
assert(intrin.id == NI_Sve_ConditionalSelect);
assert(op2->isContained());
assert(op2->OperIsHWIntrinsic());
assert(op2->isContained());

// Get the registers and intrinsics that needs embedded mask
const HWIntrinsic intrinEmbMask(op2->AsHWIntrinsic());
Expand Down Expand Up @@ -439,10 +439,54 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
case 1:
assert(!instrIsRMW);

if (targetReg != falseReg)
{
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg);
// If targetReg is not the same as `falseReg` then need to move
// the `falseReg` to `targetReg`.

if (intrin.op3->isContained())
{
assert(intrin.op3->IsVectorZero());
if (intrin.op1->isContained())
{
// We already skip importing ConditionalSelect if op1 == trueAll, however
// if we still see it here, it is because we wrapped the predicated instruction
// inside ConditionalSelect.
// As such, no need to move the `falseReg` to `targetReg`
// because the predicated instruction will eventually set it.
assert(intrin.op1->IsMaskAllBitsSet());
}
else
{
// If falseValue is zero, just zero out those lanes of targetReg using `movprfx`
// and /Z
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg, targetReg,
opt);
}
}
else if (targetReg == embMaskOp1Reg)
{
// target != falseValue, but we do not want to overwrite target with `embMaskOp1Reg`.
// We will first do the predicate operation and then do conditionalSelect inactive
// elements from falseValue

// We cannot use use `movprfx` here to move falseReg to targetReg because that will
// overwrite the value of embMaskOp1Reg which is present in targetReg.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);

GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg,
falseReg, opt, INS_SCALABLE_OPTS_UNPREDICATED);
break;
}
else
{
// At this point, target != embMaskOp1Reg != falseReg, so just go ahead
// and move the falseReg unpredicated into targetReg.
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg);
}
}

GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
break;

Expand Down