Skip to content

[API Proposal]: ARM64 SVE: Add additional types for the mask APIs #108233

@a74nh

Description

@a74nh

Background and motivation

Consider this function, which multiplies two float arrays element-wise and then sums the products:

The three casts to Vector<float> are needed because there are no float versions of CreateWhileLessThanMaskX() or TestFirstTrue(). As a result, the loop mask has to be created and tested as a Vector<uint>.

    // Baseline with the existing API: sums a[i] * b[i] for i in [0, length).
    // The loop mask can only be created as a Vector<uint>, so it is cast to Vector<float>
    // each time it is used as a mask for the float loads and the select.
    public static unsafe float fmla(ref float* a, ref float* b, int length)
    {
      Vector<float> res = Vector<float>.Zero;
      Vector<uint> ptrue = Sve.CreateTrueMaskUInt32();
      Vector<uint> ploop;

      // Each iteration builds a 32-bit WHILELT mask starting at i and stops once its first element is inactive.
      for (int i = 0; Sve.TestFirstTrue(ptrue, ploop = Sve.CreateWhileLessThanMask32Bit(i, length)); i+= (int)Sve.Count32BitElements())
      {
        Vector<float> a_vec = Sve.LoadVector((Vector<float>)ploop, a+i);
        Vector<float> b_vec = Sve.LoadVector((Vector<float>)ploop, b+i);
        res = Sve.ConditionalSelect((Vector<float>)ploop, Sve.FusedMultiplyAdd(res, a_vec, b_vec), res);
      }

      return Sve.AddAcross(res).ToScalar();
    }

With float versions of those APIs, the same function needs no casts and is much easier to read:

    // Same computation using the proposed API: sums a[i] * b[i] for i in [0, length).
    // With Single-typed masks (CreateTrueMaskSingle / CreateWhileLessThanMaskSingle),
    // the loop mask is already a Vector<float> and needs no casts.
    public static unsafe float fmla(ref float* a, ref float* b, int length)
    {
      Vector<float> res = Vector<float>.Zero;
      Vector<float> ptrue = Sve.CreateTrueMaskSingle();
      Vector<float> ploop;

      for (int i = 0; Sve.TestFirstTrue(ptrue, ploop = Sve.CreateWhileLessThanMaskSingle(i, length)); i+= (int)Sve.Count32BitElements())
      {
        Vector<float> a_vec = Sve.LoadVector(ploop, a+i);
        Vector<float> b_vec = Sve.LoadVector(ploop, b+i);
        res = Sve.ConditionalSelect(ploop, Sve.FusedMultiplyAdd(res, a_vec, b_vec), res);
      }

      return Sve.AddAcross(res).ToScalar();
    }

The same would then apply to the other SVE APIs that operate on a mask.

Also consider:

// Sums a[i] * b[i] for i in [0, length) using the existing API.
// Per-lane accumulation happens in 16-bit (short) lanes; only the final AddAcross
// produces a long, so large inputs can wrap before the reduction.
public static unsafe long multiplyAdd(ref short* a, ref short* b, int length)
{
  Vector<short> res = Vector<short>.Zero;
  Vector<short> ptrue = Sve.CreateTrueMaskInt16();
  Vector<short> ploop;

  // CreateWhileLessThanMask16Bit() returns a Vector<ushort>, so this is the one place the
  // mask must be cast to Vector<short>; afterwards ploop can be used directly.
  for (int i = 0;
       Sve.TestFirstTrue(ptrue, ploop = (Vector<short>)Sve.CreateWhileLessThanMask16Bit(i, length));
       i+= (int)Sve.Count16BitElements())
  {
    Vector<short> a_vec = Sve.LoadVector(ploop, a+i);
    Vector<short> b_vec = Sve.LoadVector(ploop, b+i);
    res = Sve.ConditionalSelect(ploop, Sve.MultiplyAdd(res, a_vec, b_vec), res);
  }

  return Sve.AddAcross(res).ToScalar();
}

The loop needs a 16-bit WHILELT mask. The only way to create one is via CreateWhileLessThanMask16Bit(), but this returns a Vector<ushort>.
It needs to be a Vector<short> so that it can be used with LoadVector() and ConditionalSelect() on short data.

The cast to Vector<short> is a little confusing.

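I suggest we add signed versions of CreateWhileLessThanMask() and CreateWhileLessThanOrEqualMask(). I also suggest float and double versions of these and of the other mask APIs, with every variant named after its element type (e.g. CreateWhileLessThanMaskInt16).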

API Proposal

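This uses the same T syntax as other SVE proposals: a "/// T: ..." comment lists the element types for which a generic signature is added. Most of these extend existing API methods with new element types. The exceptions are the size-named CreateWhileLessThanMask*Bit and CreateWhileLessThanOrEqualMask*Bit methods, which are replaced by the type-named versions below (see "APIs to remove").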

APIs to add:

namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
  /// T: float, double
  public static unsafe Vector<T> CreateBreakAfterMask(Vector<T> totalMask, Vector<T> fromMask); // BRKA // predicated

  /// T: float, double
  public static unsafe Vector<T> CreateBreakAfterPropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPA

  /// T: float, double
  public static unsafe Vector<T> CreateBreakBeforeMask(Vector<T> totalMask, Vector<T> fromMask); // BRKB // predicated

  /// T: float, double
  public static unsafe Vector<T> CreateBreakBeforePropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPB

  /// T: float, double
  public static unsafe Vector<T> CreateBreakPropagateMask(Vector<T> totalMask, Vector<T> fromMask); // BRKN // predicated

  /// T: float, double
  public static unsafe Vector<T> CreateMaskForFirstActiveElement(Vector<T> totalMask, Vector<T> fromMask); // PFIRST

  /// T: float, double
  public static unsafe bool TestAnyTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

  /// T: float, double
  public static unsafe bool TestFirstTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

  /// T: float, double
  public static unsafe bool TestLastTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

  // While-less-than masks: one method per element type, replacing CreateWhileLessThanMask{8,16,32,64}Bit.
  // Signed operands map to WHILELT, unsigned operands to WHILELO.

  public static unsafe Vector<float> CreateWhileLessThanMaskSingle(int left, int right); // WHILELT
  public static unsafe Vector<float> CreateWhileLessThanMaskSingle(long left, long right); // WHILELT
  public static unsafe Vector<float> CreateWhileLessThanMaskSingle(uint left, uint right); // WHILELO
  public static unsafe Vector<float> CreateWhileLessThanMaskSingle(ulong left, ulong right); // WHILELO

  public static unsafe Vector<double> CreateWhileLessThanMaskDouble(int left, int right); // WHILELT
  public static unsafe Vector<double> CreateWhileLessThanMaskDouble(long left, long right); // WHILELT
  public static unsafe Vector<double> CreateWhileLessThanMaskDouble(uint left, uint right); // WHILELO
  public static unsafe Vector<double> CreateWhileLessThanMaskDouble(ulong left, ulong right); // WHILELO

  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(int left, int right); // WHILELT
  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(long left, long right); // WHILELT
  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(uint left, uint right); // WHILELO
  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(ulong left, ulong right); // WHILELO

  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(int left, int right); // WHILELT
  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(long left, long right); // WHILELT
  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(uint left, uint right); // WHILELO
  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(ulong left, ulong right); // WHILELO

  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(int left, int right); // WHILELT
  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(long left, long right); // WHILELT
  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(uint left, uint right); // WHILELO
  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(ulong left, ulong right); // WHILELO

  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(int left, int right); // WHILELT
  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(long left, long right); // WHILELT
  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(uint left, uint right); // WHILELO
  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(int left, int right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(long left, long right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(uint left, uint right); // WHILELO
  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(ulong left, ulong right); // WHILELO

  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(int left, int right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(long left, long right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(uint left, uint right); // WHILELO
  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(int left, int right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(long left, long right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(uint left, uint right); // WHILELO
  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(ulong left, ulong right); // WHILELO

  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(int left, int right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(long left, long right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(uint left, uint right); // WHILELO
  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(ulong left, ulong right); // WHILELO

  // While-less-than-or-equal masks: one method per element type, replacing CreateWhileLessThanOrEqualMask{8,16,32,64}Bit.
  // Signed operands map to WHILELE, unsigned operands to WHILELS.

  public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(int left, int right); // WHILELE
  public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(long left, long right); // WHILELE
  public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(uint left, uint right); // WHILELS
  public static unsafe Vector<float> CreateWhileLessThanOrEqualMaskSingle(ulong left, ulong right); // WHILELS

  public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(int left, int right); // WHILELE
  public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(long left, long right); // WHILELE
  public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(uint left, uint right); // WHILELS
  public static unsafe Vector<double> CreateWhileLessThanOrEqualMaskDouble(ulong left, ulong right); // WHILELS

  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(int left, int right); // WHILELE
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(long left, long right); // WHILELE
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(uint left, uint right); // WHILELS
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(ulong left, ulong right); // WHILELS

  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(int left, int right); // WHILELE
  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(long left, long right); // WHILELE
  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(uint left, uint right); // WHILELS
  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(ulong left, ulong right); // WHILELS

  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(int left, int right); // WHILELE
  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(long left, long right); // WHILELE
  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(uint left, uint right); // WHILELS
  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(ulong left, ulong right); // WHILELS

  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(int left, int right); // WHILELE
  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(long left, long right); // WHILELE
  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(uint left, uint right); // WHILELS
  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(ulong left, ulong right); // WHILELS

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(int left, int right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(long left, long right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(uint left, uint right); // WHILELS
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(ulong left, ulong right); // WHILELS

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(int left, int right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(long left, long right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(uint left, uint right); // WHILELS
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(ulong left, ulong right); // WHILELS

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(int left, int right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(long left, long right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(uint left, uint right); // WHILELS
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(ulong left, ulong right); // WHILELS

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(int left, int right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(long left, long right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(uint left, uint right); // WHILELS
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(ulong left, ulong right); // WHILELS

  // Concrete return types: these follow the existing GetFfrByte()/GetFfrInt16()/... naming, not the T syntax.
  public static unsafe Vector<float> GetFfrSingle(); // RDFFR // predicated
  public static unsafe Vector<double> GetFfrDouble(); // RDFFR // predicated

  /// T: float, double
  public static unsafe void SetFfr(Vector<T> value); // WRFFR
}

APIs to remove:

namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
  // These size-named methods are removed in favour of the element-type-named methods
  // listed under "APIs to add". For example, CreateWhileLessThanMask16Bit becomes
  // CreateWhileLessThanMaskUInt16, which keeps the Vector<ushort> return type.
  // The signed, Single and Double variants are new additions.

  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(int left, int right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(long left, long right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(int left, int right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(long left, long right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(int left, int right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(long left, long right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(int left, int right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(long left, long right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(int left, int right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(long left, long right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(int left, int right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(long left, long right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(int left, int right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(long left, long right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(ulong left, ulong right); // WHILELS
  
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(int left, int right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(long left, long right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(ulong left, ulong right); // WHILELS
}

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions