-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Description
Background and motivation
Consider this function, which multiplies two float arrays element-wise and then sums the result.
The three casts to Vector<float> exist because there are no float versions of CreateWhileLessThanMaskX() and TestFirstTrue(); the masks have to be created as Vector<uint> and then reinterpreted as Vector<float>.
// Current-API version: sums a[k] * b[k] for k in [0, length) using SVE.
// The mask helpers used here only produce Vector<uint>, so the while-mask must be
// reinterpreted as Vector<float> (the three casts below) for every float-typed SVE call.
public static unsafe float fmla(ref float* a, ref float* b, int length)
{
// Running per-lane accumulator.
Vector<float> res = Vector<float>.Zero;
// All-lanes-active 32-bit mask; only needed as the governing mask for TestFirstTrue.
Vector<uint> ptrue = Sve.CreateTrueMaskUInt32();
// Per-iteration while-mask covering the lanes still inside [i, length).
Vector<uint> ploop;
// Loop while the first lane of the new while-mask is active; advance by one vector's worth of 32-bit lanes.
for (int i = 0; Sve.TestFirstTrue(ptrue, ploop = Sve.CreateWhileLessThanMask32Bit(i, length)); i+= (int)Sve.Count32BitElements())
{
// Masked loads: the uint mask has to be cast to Vector<float> to match the float overloads.
Vector<float> a_vec = Sve.LoadVector((Vector<float>)ploop, a+i);
Vector<float> b_vec = Sve.LoadVector((Vector<float>)ploop, b+i);
// Only lanes active in the mask take the multiply-add result; inactive lanes keep their previous value.
res = Sve.ConditionalSelect((Vector<float>)ploop, Sve.FusedMultiplyAdd(res, a_vec, b_vec), res);
}
// Horizontal sum of all accumulator lanes.
return Sve.AddAcross(res).ToScalar();
}The following code looks much nicer and is easier to logically parse:
// Proposed-API version of the same kernel: sums a[k] * b[k] for k in [0, length).
// With float-typed mask APIs, the mask flows directly into the float SVE calls and no casts are needed.
// Method names follow the proposal below (Single, not Float, matching .NET naming).
public static unsafe float fmla(ref float* a, ref float* b, int length)
{
    // Running per-lane accumulator.
    Vector<float> res = Vector<float>.Zero;
    // All-lanes-active mask, used as the governing mask for TestFirstTrue.
    Vector<float> ptrue = Sve.CreateTrueMaskSingle();
    // Per-iteration while-mask covering the lanes still inside [i, length).
    Vector<float> ploop;
    for (int i = 0; Sve.TestFirstTrue(ptrue, ploop = Sve.CreateWhileLessThanMaskSingle(i, length)); i += (int)Sve.Count32BitElements())
    {
        Vector<float> a_vec = Sve.LoadVector(ploop, a + i);
        Vector<float> b_vec = Sve.LoadVector(ploop, b + i);
        // Only lanes active in the mask take the multiply-add result.
        res = Sve.ConditionalSelect(ploop, Sve.FusedMultiplyAdd(res, a_vec, b_vec), res);
    }
    // Horizontal sum of all accumulator lanes.
    return Sve.AddAcross(res).ToScalar();
}
The same would then apply for other SVE APIs operating on a mask.
Also consider:
// Current-API version: multiply-accumulates a[k] * b[k] for k in [0, length) over 16-bit lanes.
// The only 16-bit while-mask API returns Vector<ushort>, so it must be reinterpreted as
// Vector<short> once, where it is created. Everything downstream is already short-typed.
public static unsafe long multiplyAdd(ref short* a, ref short* b, int length)
{
    Vector<short> res = Vector<short>.Zero;
    // Loop-invariant all-lanes-active mask: build it once instead of on every condition check.
    Vector<short> ptrue = Sve.CreateTrueMaskInt16();
    Vector<short> ploop;
    for (int i = 0;
        // This cast is the one the proposal wants to eliminate.
        Sve.TestFirstTrue(ptrue, ploop = (Vector<short>)Sve.CreateWhileLessThanMask16Bit(i, length));
        i += (int)Sve.Count16BitElements())
    {
        // ploop is already Vector<short>; no further casts are needed.
        Vector<short> a_vec = Sve.LoadVector(ploop, a + i);
        Vector<short> b_vec = Sve.LoadVector(ploop, b + i);
        // NOTE(review): the multiply-add accumulates in 16-bit lanes, so large inputs can wrap
        // before AddAcross widens the final sum to long.
        res = Sve.ConditionalSelect(ploop, Sve.MultiplyAdd(res, a_vec, b_vec), res);
    }
    // Horizontal sum of all accumulator lanes, returned as a long.
    return Sve.AddAcross(res).ToScalar();
}
For the loop condition we need a 16-bit WHILELT mask. The only way to create one is via CreateWhileLessThanMask16Bit(), but it returns a Vector<ushort>.
It needs to be a Vector<short> so that it can be used with LoadVector() and ConditionalSelect().
The cast to Vector<short> is a little confusing.
I suggest we add signed (and floating-point) versions of CreateWhileLessThanMask().
API Proposal
This uses the same T syntax as the other SVE proposals: a "/// T: float, double" line means the method is added for each listed element type. These are all extensions of existing API methods.
APIs to add:
namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
    // Mask-manipulation and mask-test APIs, extended to floating-point element types.

    /// T: float, double
    public static unsafe Vector<T> CreateBreakAfterMask(Vector<T> totalMask, Vector<T> fromMask); // BRKA // predicated
    /// T: float, double
    public static unsafe Vector<T> CreateBreakAfterPropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPA
    /// T: float, double
    public static unsafe Vector<T> CreateBreakBeforeMask(Vector<T> totalMask, Vector<T> fromMask); // BRKB // predicated
    /// T: float, double
    public static unsafe Vector<T> CreateBreakBeforePropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPB
    /// T: float, double
    public static unsafe Vector<T> CreateBreakPropagateMask(Vector<T> totalMask, Vector<T> fromMask); // BRKN // predicated
    /// T: float, double
    public static unsafe Vector<T> CreateMaskForFirstActiveElement(Vector<T> totalMask, Vector<T> fromMask); // PFIRST
    /// T: float, double
    public static unsafe bool TestAnyTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST
    /// T: float, double
    public static unsafe bool TestFirstTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST
    /// T: float, double
    public static unsafe bool TestLastTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

    // While-less-than masks, named by element type. The UInt8/16/32/64 variants replace the
    // existing CreateWhileLessThanMask{8,16,32,64}Bit APIs listed under "APIs to remove".

    public static unsafe Vector<float> CreateWhileLessThanMaskSingle(int left, int right); // WHILELT
    public static unsafe Vector<float> CreateWhileLessThanMaskSingle(long left, long right); // WHILELT
    public static unsafe Vector<float> CreateWhileLessThanMaskSingle(uint left, uint right); // WHILELO
    public static unsafe Vector<float> CreateWhileLessThanMaskSingle(ulong left, ulong right); // WHILELO
    public static unsafe Vector<double> CreateWhileLessThanMaskDouble(int left, int right); // WHILELT
    public static unsafe Vector<double> CreateWhileLessThanMaskDouble(long left, long right); // WHILELT
    public static unsafe Vector<double> CreateWhileLessThanMaskDouble(uint left, uint right); // WHILELO
    public static unsafe Vector<double> CreateWhileLessThanMaskDouble(ulong left, ulong right); // WHILELO
    public static unsafe Vector<short> CreateWhileLessThanMaskInt16(int left, int right); // WHILELT
    public static unsafe Vector<short> CreateWhileLessThanMaskInt16(long left, long right); // WHILELT
    public static unsafe Vector<short> CreateWhileLessThanMaskInt16(uint left, uint right); // WHILELO
    public static unsafe Vector<short> CreateWhileLessThanMaskInt16(ulong left, ulong right); // WHILELO
    public static unsafe Vector<int> CreateWhileLessThanMaskInt32(int left, int right); // WHILELT
    public static unsafe Vector<int> CreateWhileLessThanMaskInt32(long left, long right); // WHILELT
    public static unsafe Vector<int> CreateWhileLessThanMaskInt32(uint left, uint right); // WHILELO
    public static unsafe Vector<int> CreateWhileLessThanMaskInt32(ulong left, ulong right); // WHILELO
    public static unsafe Vector<long> CreateWhileLessThanMaskInt64(int left, int right); // WHILELT
    public static unsafe Vector<long> CreateWhileLessThanMaskInt64(long left, long right); // WHILELT
    public static unsafe Vector<long> CreateWhileLessThanMaskInt64(uint left, uint right); // WHILELO
    public static unsafe Vector<long> CreateWhileLessThanMaskInt64(ulong left, ulong right); // WHILELO
    public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(int left, int right); // WHILELT
    public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(long left, long right); // WHILELT
    public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(uint left, uint right); // WHILELO
    public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(ulong left, ulong right); // WHILELO
    public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(int left, int right); // WHILELT
    public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(long left, long right); // WHILELT
    public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(uint left, uint right); // WHILELO
    public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(ulong left, ulong right); // WHILELO
    public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(int left, int right); // WHILELT
    public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(long left, long right); // WHILELT
    public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(uint left, uint right); // WHILELO
    public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(ulong left, ulong right); // WHILELO
    public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(int left, int right); // WHILELT
    public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(long left, long right); // WHILELT
    public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(uint left, uint right); // WHILELO
    public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(ulong left, ulong right); // WHILELO
    public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(int left, int right); // WHILELT
    public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(long left, long right); // WHILELT
    public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(uint left, uint right); // WHILELO
    public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(ulong left, ulong right); // WHILELO

    // While-less-than-or-equal masks, named by element type. The UInt8/16/32/64 variants replace
    // the existing CreateWhileLessThanOrEqualMask{8,16,32,64}Bit APIs listed under "APIs to remove".

    public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(int left, int right); // WHILELE
    public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(long left, long right); // WHILELE
    public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(uint left, uint right); // WHILELS
    public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(ulong left, ulong right); // WHILELS
    public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(int left, int right); // WHILELE
    public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(long left, long right); // WHILELE
    public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(uint left, uint right); // WHILELS
    public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(ulong left, ulong right); // WHILELS
    public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(int left, int right); // WHILELE
    public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(long left, long right); // WHILELE
    public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(uint left, uint right); // WHILELS
    public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(ulong left, ulong right); // WHILELS
    public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(int left, int right); // WHILELE
    public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(long left, long right); // WHILELE
    public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(uint left, uint right); // WHILELS
    public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(ulong left, ulong right); // WHILELS
    public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(int left, int right); // WHILELE
    public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(long left, long right); // WHILELE
    public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(uint left, uint right); // WHILELS
    public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(ulong left, ulong right); // WHILELS
    public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(int left, int right); // WHILELE
    public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(long left, long right); // WHILELE
    public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(uint left, uint right); // WHILELS
    public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(ulong left, ulong right); // WHILELS
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(int left, int right); // WHILELE
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(long left, long right); // WHILELE
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(uint left, uint right); // WHILELS
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(ulong left, ulong right); // WHILELS
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(int left, int right); // WHILELE
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(long left, long right); // WHILELE
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(uint left, uint right); // WHILELS
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(ulong left, ulong right); // WHILELS
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(int left, int right); // WHILELE
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(long left, long right); // WHILELE
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(uint left, uint right); // WHILELS
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(ulong left, ulong right); // WHILELS
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(int left, int right); // WHILELE
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(long left, long right); // WHILELE
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(uint left, uint right); // WHILELS
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(ulong left, ulong right); // WHILELS

    // First-fault register access. The element type is fixed by the method name, so the
    // return type is concrete rather than Vector<T>.
    public static unsafe Vector<float> GetFfrSingle(); // RDFFR // predicated
    public static unsafe Vector<double> GetFfrDouble(); // RDFFR // predicated

    /// T: float, double
    public static unsafe void SetFfr(Vector<T> value); // WRFFR
}
APIs to remove:
namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
    // Existing size-named while-mask APIs. These are removed and superseded by the
    // element-type-named CreateWhileLessThan[OrEqual]Mask{Int,UInt}{8,16,32,64} overloads:
    // each *NBit method becomes the matching *UIntN method (same return type and instruction),
    // and a signed *IntN counterpart is added alongside it.
    public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(int left, int right); // WHILELT
    public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(long left, long right); // WHILELT
    public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(uint left, uint right); // WHILELO
    public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(ulong left, ulong right); // WHILELO
    public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(int left, int right); // WHILELT
    public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(long left, long right); // WHILELT
    public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(uint left, uint right); // WHILELO
    public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(ulong left, ulong right); // WHILELO
    public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(int left, int right); // WHILELT
    public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(long left, long right); // WHILELT
    public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(uint left, uint right); // WHILELO
    public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(ulong left, ulong right); // WHILELO
    public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(int left, int right); // WHILELT
    public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(long left, long right); // WHILELT
    public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(uint left, uint right); // WHILELO
    public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(ulong left, ulong right); // WHILELO
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(int left, int right); // WHILELE
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(long left, long right); // WHILELE
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(uint left, uint right); // WHILELS
    public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(ulong left, ulong right); // WHILELS
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(int left, int right); // WHILELE
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(long left, long right); // WHILELE
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(uint left, uint right); // WHILELS
    public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(ulong left, ulong right); // WHILELS
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(int left, int right); // WHILELE
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(long left, long right); // WHILELE
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(uint left, uint right); // WHILELS
    public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(ulong left, ulong right); // WHILELS
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(int left, int right); // WHILELE
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(long left, long right); // WHILELE
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(uint left, uint right); // WHILELS
    public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(ulong left, ulong right); // WHILELS
}