Skip to content

Commit

Permalink
Minor cleanup of the Vector64/128/256/512 implementations to improve fallbacks (#103095)
Browse files Browse the repository at this point in the history

* Minor cleanup of the Vector64/128/256/512 implementations to improve fallbacks

* Ensure gtNewSimdSumNode maintains consistency with the software fallback

* Ensure Vector128.Sum also does pairwise adds for floating-point

* Use the right type in the gtNewSimdBinOpNode call

* Don't regress fallback scenarios using AndNot
  • Loading branch information
tannergooding authored Jun 7, 2024
1 parent 2dba5a3 commit e012fd4
Show file tree
Hide file tree
Showing 10 changed files with 355 additions and 314 deletions.
46 changes: 38 additions & 8 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25510,20 +25510,48 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
{
assert(IsBaselineVector512IsaSupportedDebugOnly());
GenTree* op1Dup = fgMakeMultiUse(&op1);
op1 = gtNewSimdGetUpperNode(TYP_SIMD32, op1, simdBaseJitType, simdSize);
op1Dup = gtNewSimdGetLowerNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize);
simdSize = simdSize / 2;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1, op1Dup, simdBaseJitType, simdSize);

op1 = gtNewSimdGetLowerNode(TYP_SIMD32, op1, simdBaseJitType, simdSize);
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize);

if (varTypeIsFloating(simdBaseType))
{
// We need to ensure deterministic results, which requires
// consistently adding values together. Since many operations
// end up operating on 128-bit lanes, we break the sum up the same way.

op1 = gtNewSimdSumNode(type, op1, simdBaseJitType, 32);
op1Dup = gtNewSimdSumNode(type, op1Dup, simdBaseJitType, 32);

return gtNewOperNode(GT_ADD, type, op1, op1Dup);
}

simdSize = 32;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1, op1Dup, simdBaseJitType, 32);
}

if (simdSize == 32)
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));
GenTree* op1Dup = fgMakeMultiUse(&op1);
op1 = gtNewSimdGetUpperNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);
op1Dup = gtNewSimdGetLowerNode(TYP_SIMD16, op1Dup, simdBaseJitType, simdSize);
simdSize = simdSize / 2;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseJitType, simdSize);

op1 = gtNewSimdGetLowerNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD16, op1Dup, simdBaseJitType, simdSize);

if (varTypeIsFloating(simdBaseType))
{
// We need to ensure deterministic results, which requires
// consistently adding values together. Since many operations
// end up operating on 128-bit lanes, we break the sum up the same way.

op1 = gtNewSimdSumNode(type, op1, simdBaseJitType, 16);
op1Dup = gtNewSimdSumNode(type, op1Dup, simdBaseJitType, 16);

return gtNewOperNode(GT_ADD, type, op1, op1Dup);
}

simdSize = 16;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseJitType, 16);
}

assert(simdSize == 16);
Expand All @@ -25534,6 +25562,7 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
{
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
GenTree* op1Shuffled = fgMakeMultiUse(&op1);

if (compOpportunisticallyDependsOn(InstructionSet_AVX))
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
Expand Down Expand Up @@ -25571,6 +25600,7 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
{
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
GenTree* op1Shuffled = fgMakeMultiUse(&op1);

if (compOpportunisticallyDependsOn(InstructionSet_AVX))
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ public static Vector<float> Ceiling(Vector<float> value)
/// <returns>A vector whose bits come from <paramref name="left" /> or <paramref name="right" /> based on the value of <paramref name="condition" />.</returns>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector<T> ConditionalSelect<T>(Vector<T> condition, Vector<T> left, Vector<T> right) => (left & condition) | (right & ~condition);
public static Vector<T> ConditionalSelect<T>(Vector<T> condition, Vector<T> left, Vector<T> right) => (left & condition) | AndNot(right, condition);

/// <summary>Conditionally selects a value from two vectors on a bitwise basis.</summary>
/// <param name="condition">The mask that is used to select a value from <paramref name="left" /> or <paramref name="right" />.</param>
Expand Down Expand Up @@ -1186,7 +1186,7 @@ public static Vector<T> Min<T>(Vector<T> left, Vector<T> right)
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
/// <returns>The product of <paramref name="left" /> and <paramref name="right" />.</returns>
[Intrinsic]
public static Vector<T> Multiply<T>(T left, Vector<T> right) => left * right;
public static Vector<T> Multiply<T>(T left, Vector<T> right) => right * left;

/// <inheritdoc cref="Vector128.MultiplyAddEstimate(Vector128{double}, Vector128{double}, Vector128{double})" />
[Intrinsic]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,21 @@ public static bool IsHardwareAccelerated
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<T> Abs<T>(Vector128<T> vector)
{
return Create(
Vector64.Abs(vector._lower),
Vector64.Abs(vector._upper)
);
if ((typeof(T) == typeof(byte))
|| (typeof(T) == typeof(ushort))
|| (typeof(T) == typeof(uint))
|| (typeof(T) == typeof(ulong))
|| (typeof(T) == typeof(nuint)))
{
return vector;
}
else
{
return Create(
Vector64.Abs(vector._lower),
Vector64.Abs(vector._upper)
);
}
}

/// <summary>Adds two vectors to compute their sum.</summary>
Expand All @@ -80,13 +91,7 @@ public static Vector128<T> Abs<T>(Vector128<T> vector)
/// <returns>The bitwise-and of <paramref name="left" /> and the ones-complement of <paramref name="right" />.</returns>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<T> AndNot<T>(Vector128<T> left, Vector128<T> right)
{
return Create(
Vector64.AndNot(left._lower, right._lower),
Vector64.AndNot(left._upper, right._upper)
);
}
public static Vector128<T> AndNot<T>(Vector128<T> left, Vector128<T> right) => left & ~right;

/// <summary>Reinterprets a <see cref="Vector128{TFrom}" /> as a new <see cref="Vector128{TTo}" />.</summary>
/// <typeparam name="TFrom">The type of the elements in the input vector.</typeparam>
Expand Down Expand Up @@ -377,10 +382,26 @@ public static Vector<T> AsVector<T>(this Vector128<T> value)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static Vector128<T> Ceiling<T>(Vector128<T> vector)
{
return Create(
Vector64.Ceiling(vector._lower),
Vector64.Ceiling(vector._upper)
);
if ((typeof(T) == typeof(byte))
|| (typeof(T) == typeof(short))
|| (typeof(T) == typeof(int))
|| (typeof(T) == typeof(long))
|| (typeof(T) == typeof(nint))
|| (typeof(T) == typeof(nuint))
|| (typeof(T) == typeof(sbyte))
|| (typeof(T) == typeof(ushort))
|| (typeof(T) == typeof(uint))
|| (typeof(T) == typeof(ulong)))
{
return vector;
}
else
{
return Create(
Vector64.Ceiling(vector._lower),
Vector64.Ceiling(vector._upper)
);
}
}

/// <summary>Computes the ceiling of each element in a vector.</summary>
Expand All @@ -406,13 +427,7 @@ internal static Vector128<T> Ceiling<T>(Vector128<T> vector)
/// <exception cref="NotSupportedException">The type of <paramref name="condition" />, <paramref name="left" />, and <paramref name="right" /> (<typeparamref name="T" />) is not supported.</exception>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<T> ConditionalSelect<T>(Vector128<T> condition, Vector128<T> left, Vector128<T> right)
{
return Create(
Vector64.ConditionalSelect(condition._lower, left._lower, right._lower),
Vector64.ConditionalSelect(condition._upper, left._upper, right._upper)
);
}
public static Vector128<T> ConditionalSelect<T>(Vector128<T> condition, Vector128<T> left, Vector128<T> right) => (left & condition) | AndNot(right, condition);

/// <summary>Converts a <see cref="Vector128{Int64}" /> to a <see cref="Vector128{Double}" />.</summary>
/// <param name="vector">The vector to convert.</param>
Expand Down Expand Up @@ -1413,16 +1428,7 @@ public static Vector128<T> CreateScalarUnsafe<T>(T value)
/// <exception cref="NotSupportedException">The type of <paramref name="left" /> and <paramref name="right" /> (<typeparamref name="T" />) is not supported.</exception>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Dot<T>(Vector128<T> left, Vector128<T> right)
{
// Doing this as Dot(lower) + Dot(upper) is important for floating-point determinism
// This is because the underlying dpps instruction on x86/x64 will do this equivalently
// and otherwise the software vs accelerated implementations may differ in returned result.

T result = Vector64.Dot(left._lower, right._lower);
result = Scalar<T>.Add(result, Vector64.Dot(left._upper, right._upper));
return result;
}
public static T Dot<T>(Vector128<T> left, Vector128<T> right) => Sum(left * right);

/// <summary>Compares two vectors to determine if they are equal on a per-element basis.</summary>
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
Expand Down Expand Up @@ -1519,10 +1525,26 @@ public static uint ExtractMostSignificantBits<T>(this Vector128<T> vector)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static Vector128<T> Floor<T>(Vector128<T> vector)
{
return Create(
Vector64.Floor(vector._lower),
Vector64.Floor(vector._upper)
);
if ((typeof(T) == typeof(byte))
|| (typeof(T) == typeof(short))
|| (typeof(T) == typeof(int))
|| (typeof(T) == typeof(long))
|| (typeof(T) == typeof(nint))
|| (typeof(T) == typeof(nuint))
|| (typeof(T) == typeof(sbyte))
|| (typeof(T) == typeof(ushort))
|| (typeof(T) == typeof(uint))
|| (typeof(T) == typeof(ulong)))
{
return vector;
}
else
{
return Create(
Vector64.Floor(vector._lower),
Vector64.Floor(vector._upper)
);
}
}

/// <summary>Computes the floor of each element in a vector.</summary>
Expand Down Expand Up @@ -1989,7 +2011,7 @@ public static Vector128<T> Min<T>(Vector128<T> left, Vector128<T> right)
/// <returns>The product of <paramref name="left" /> and <paramref name="right" />.</returns>
/// <exception cref="NotSupportedException">The type of <paramref name="left" /> and <paramref name="right"/> (<typeparamref name="T" />) is not supported.</exception>
[Intrinsic]
public static Vector128<T> Multiply<T>(T left, Vector128<T> right) => left * right;
public static Vector128<T> Multiply<T>(T left, Vector128<T> right) => right * left;

/// <inheritdoc cref="Vector64.MultiplyAddEstimate(Vector64{double}, Vector64{double}, Vector64{double})" />
[Intrinsic]
Expand Down Expand Up @@ -2735,14 +2757,13 @@ public static void StoreUnsafe<T>(this Vector128<T> source, ref T destination, n
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Sum<T>(Vector128<T> vector)
{
T sum = default!;

for (int index = 0; index < Vector128<T>.Count; index++)
{
sum = Scalar<T>.Add(sum, vector.GetElementUnsafe(index));
}
// Doing this as Sum(lower) + Sum(upper) is important for floating-point determinism.
// Dot is computed as Sum(left * right), and the underlying dpps instruction on x86/x64
// adds the products together equivalently; otherwise the software vs accelerated
// implementations may differ in returned result.

return sum;
T result = Vector64.Sum(vector._lower);
result = Scalar<T>.Add(result, Vector64.Sum(vector._upper));
return result;
}

/// <summary>Converts the given vector to a scalar containing the value of the first element.</summary>
Expand Down
Loading

0 comments on commit e012fd4

Please sign in to comment.