Skip to content

Commit

Permalink
Fixed: Performance Regression with <16 Byte Strings in ToLower/ToUpper
Browse files Browse the repository at this point in the history
  • Loading branch information
Sewer56 committed Nov 16, 2023
1 parent 72aaaf6 commit 5956d49
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/Reloaded.Memory/Extensions/StringExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public static int Count(this string text, char c)
/// </remarks>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe void ToLowerInvariantFast(this ReadOnlySpan<char> text, Span<char> target) => TextInfo.ChangeCase<TextInfo.ToLowerConversion>(text, target);
public static void ToLowerInvariantFast(this ReadOnlySpan<char> text, Span<char> target) => TextInfo.ChangeCase<TextInfo.ToLowerConversion>(text, target);

/// <summary>
/// Converts the given string to upper case (invariant casing), using the fastest possible implementation.
Expand Down
36 changes: 11 additions & 25 deletions src/Reloaded.Memory/Internals/Algorithms/UnstableStringHashLower.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using System.Diagnostics.CodeAnalysis;
using Reloaded.Memory.Exceptions;
using Reloaded.Memory.Utilities;
using static Reloaded.Memory.Internals.Backports.System.Text.Unicode.Utf16Utility;
#if NET7_0_OR_GREATER
using Reloaded.Memory.Extensions;
using Reloaded.Memory.Internals.Backports.System.Text.Unicode;
using Reloaded.Memory.Internals.Backports.System.Globalization;
using System.Numerics;
using System.Runtime.Intrinsics;
Expand Down Expand Up @@ -131,28 +131,28 @@ internal static unsafe UIntPtr UnstableHashVec128Lower(this ReadOnlySpan<char> t
length -= (sizeof(Vector128<ulong>) / sizeof(char)) * 4;

Vector128<ulong> v0 = Vector128.Load((ulong*)ptr);
if (!Utf16Utility.AllCharsInVector128AreAscii(v0))
if (!AllCharsInVector128AreAscii(v0))
goto NotAscii;

hash1_128 = Vector128.Xor(hash1_128, Vector128.BitwiseOr(v0, toLower));
hash1_128 = Vector128.Multiply(hash1_128.AsUInt32(), prime.AsUInt32()).AsUInt64();

v0 = Vector128.Load((ulong*)ptr + 2);
if (!Utf16Utility.AllCharsInVector128AreAscii(v0))
if (!AllCharsInVector128AreAscii(v0))
goto NotAscii;

hash2_128 = Vector128.Xor(hash2_128, Vector128.BitwiseOr(v0, toLower));
hash2_128 = Vector128.Multiply(hash2_128.AsUInt32(), prime.AsUInt32()).AsUInt64();

v0 = Vector128.Load((ulong*)ptr + 4);
if (!Utf16Utility.AllCharsInVector128AreAscii(v0))
if (!AllCharsInVector128AreAscii(v0))
goto NotAscii;

hash1_128 = Vector128.Xor(hash1_128, Vector128.BitwiseOr(v0, toLower));
hash1_128 = Vector128.Multiply(hash1_128.AsUInt32(), prime.AsUInt32()).AsUInt64();

v0 = Vector128.Load((ulong*)ptr + 6);
if (!Utf16Utility.AllCharsInVector128AreAscii(v0))
if (!AllCharsInVector128AreAscii(v0))
goto NotAscii;

hash2_128 = Vector128.Xor(hash2_128, Vector128.BitwiseOr(v0, toLower));
Expand All @@ -165,7 +165,7 @@ internal static unsafe UIntPtr UnstableHashVec128Lower(this ReadOnlySpan<char> t
length -= sizeof(Vector128<ulong>) / sizeof(char);

Vector128<ulong> v0 = Vector128.Load((ulong*)ptr);
if (!Utf16Utility.AllCharsInVector128AreAscii(v0))
if (!AllCharsInVector128AreAscii(v0))
goto NotAscii;

hash1_128 = Vector128.Xor(hash1_128, Vector128.BitwiseOr(v0, toLower));
Expand Down Expand Up @@ -251,28 +251,28 @@ internal static unsafe UIntPtr UnstableHashAvx2Lower(this ReadOnlySpan<char> tex
length -= (sizeof(Vector256<ulong>) / sizeof(char)) * 4;

Vector256<ulong> v0 = Vector256.Load((ulong*)ptr);
if (!Utf16Utility.AllCharsInVector256AreAscii(v0))
if (!AllCharsInVector256AreAscii(v0))
goto NotAscii;

hash1Avx = Avx2.Xor(hash1Avx, Avx2.Or(v0, toLower));
hash1Avx = Avx2.Multiply(hash1Avx.AsUInt32(), prime.AsUInt32());

v0 = Vector256.Load((ulong*)ptr + 4);
if (!Utf16Utility.AllCharsInVector256AreAscii(v0))
if (!AllCharsInVector256AreAscii(v0))
goto NotAscii;

hash2Avx = Avx2.Xor(hash2Avx, Avx2.Or(v0, toLower));
hash2Avx = Avx2.Multiply(hash2Avx.AsUInt32(), prime.AsUInt32());

v0 = Vector256.Load((ulong*)ptr + 8);
if (!Utf16Utility.AllCharsInVector256AreAscii(v0))
if (!AllCharsInVector256AreAscii(v0))
goto NotAscii;

hash1Avx = Avx2.Xor(hash1Avx, Avx2.Or(v0, toLower));
hash1Avx = Avx2.Multiply(hash1Avx.AsUInt32(), prime.AsUInt32());

v0 = Vector256.Load((ulong*)ptr + 12);
if (!Utf16Utility.AllCharsInVector256AreAscii(v0))
if (!AllCharsInVector256AreAscii(v0))
goto NotAscii;

hash2Avx = Avx2.Xor(hash2Avx, Avx2.Or(v0, toLower));
Expand All @@ -285,7 +285,7 @@ internal static unsafe UIntPtr UnstableHashAvx2Lower(this ReadOnlySpan<char> tex
length -= sizeof(Vector256<ulong>) / sizeof(char);

Vector256<ulong> v0 = Vector256.Load((ulong*)ptr);
if (!Utf16Utility.AllCharsInVector256AreAscii(v0))
if (!AllCharsInVector256AreAscii(v0))
goto NotAscii;

hash1Avx = Avx2.Xor(hash1Avx, Avx2.Or(v0, toLower));
Expand Down Expand Up @@ -573,20 +573,6 @@ internal static unsafe UIntPtr UnstableHashNonVectorLower64(this ReadOnlySpan<ch
return GetHashCodeUnstableLowerSlow(text);
}

/// <summary>
/// Returns true iff the 64-bit nuint represents all ASCII UTF-16 characters in machine endianness.
/// </summary>
/// <param name="value">The value to assert.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe bool AllCharsInULongAreAscii(ulong value) => (value & ~0x007F_007F_007F_007Fu) == 0;

/// <summary>
/// Returns true iff the 32-bit nuint represents all ASCII UTF-16 characters in machine endianness.
/// </summary>
/// <param name="value">The value to assert.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe bool AllCharsInUIntAreAscii(uint value) => (value & ~0x007F_007F) == 0;

#if (NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_1_OR_GREATER) && !NET7_0_OR_GREATER
private unsafe struct ChangeCaseParams(char* first, int length)
{
Expand Down
192 changes: 167 additions & 25 deletions src/Reloaded.Memory/Internals/Backports/System/Globalization/TextInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.Intrinsics;
using Reloaded.Memory.Exceptions;
using Reloaded.Memory.Internals.Backports.System.Text.Unicode;
// ReSharper disable UnusedType.Global

Expand Down Expand Up @@ -53,29 +54,7 @@ public static void ChangeCase<TConversion>(ReadOnlySpan<char> source, Span<char>
else if (Vector128.IsHardwareAccelerated && source.Length >= Vector128<ushort>.Count)
ChangeCase_Vector128<TConversion>(ref MemoryMarshal.GetReference(source), ref MemoryMarshal.GetReference(destination), source.Length);
else
{
var toUpper = typeof(TConversion) == typeof(ToUpperConversion);
ChangeCase_Fallback(source, destination, toUpper);
}
}

private static void ChangeCase_Fallback(ReadOnlySpan<char> source, Span<char> destination, bool toUpper)
{
try
{
if (toUpper)
source.ToUpperInvariant(destination);
else
source.ToLowerInvariant(destination);
}
catch (InvalidOperationException)
{
// Overlapping buffers
if (toUpper)
source.ToString().AsSpan().ToUpperInvariant(destination);
else
source.ToString().AsSpan().ToLowerInvariant(destination);
}
ChangeCase_Under16B<TConversion>(source, destination);
}

/// <summary>
Expand Down Expand Up @@ -134,7 +113,7 @@ public static void ChangeCase_Vector256<TConversion>(ref char source, ref char d
var length = charCount - (int)i;
var srcSpan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref source, i), length);
var dstSpan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref destination, i), length);
ChangeCase_Fallback(srcSpan, dstSpan, toUpper);
ChangeCase_Fallback<TConversion>(srcSpan, dstSpan);
}

/// <summary>
Expand Down Expand Up @@ -193,7 +172,7 @@ public static void ChangeCase_Vector128<TConversion>(ref char source, ref char d
var length = charCount - (int)i;
var srcSpan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref source, i), length);
var dstSpan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref destination, i), length);
ChangeCase_Fallback(srcSpan, dstSpan, toUpper);
ChangeCase_Fallback<TConversion>(srcSpan, dstSpan);
}

// A dummy struct that is used for 'ToUpper' in generic parameters
Expand All @@ -207,5 +186,168 @@ private unsafe struct ChangeCaseParams(char* first, int length)
public readonly char* First = first;
public readonly int Length = length;
};

/// <summary>
/// An implementation of Change Case for inputs up to 16 bytes.
/// Custom, not taken from runtime.
/// </summary>
/// <param name="source">Source span.</param>
/// <param name="destination">Destination span.</param>
public static unsafe void ChangeCase_Under16B<TConversion>(ReadOnlySpan<char> source, Span<char> destination) where TConversion : struct
{
var length = source.Length;

// JIT will treat this as a constant in release builds
var toUpper = typeof(TConversion) == typeof(ToUpperConversion);

// 32 bit implementation
if (sizeof(nuint) == 4)
{
ref uint srcNuintPtr = ref Unsafe.As<char, uint>(ref MemoryMarshal.GetReference(source));
ref uint dstNuintPtr = ref Unsafe.As<char, uint>(ref MemoryMarshal.GetReference(destination));

// 32 bit implementation
// range: 0-7 chars (0-14 bytes)
// keep converting 4 bytes at once until we are left with 0-2
while (length >= 2)
{
length -= 2;
if (!Utf16Utility.AllCharsInUIntAreAscii(srcNuintPtr))
goto NotAscii;

dstNuintPtr = toUpper
? Utf8Utility.ConvertAllAsciiBytesInUInt32ToUppercase(srcNuintPtr)
: Utf8Utility.ConvertAllAsciiBytesInUInt32ToLowercase(srcNuintPtr);

srcNuintPtr = ref Unsafe.Add(ref srcNuintPtr, 1);
dstNuintPtr = ref Unsafe.Add(ref dstNuintPtr, 1);
}

ref char srcCharPtr = ref Unsafe.As<uint, char>(ref srcNuintPtr);
ref char dstCharPtr = ref Unsafe.As<uint, char>(ref dstNuintPtr);
if (length > 0)
{
if (toUpper)
{
if (UnicodeUtility.IsInRangeInclusive(srcCharPtr, 'a', 'z'))
{
dstCharPtr = (char)(srcCharPtr - (char)0x20u);
return;
}

goto NotAscii;
}
else
{
if (UnicodeUtility.IsInRangeInclusive(srcCharPtr, 'A', 'Z'))
{
dstCharPtr = (char)(srcCharPtr + (char)0x20u);
return;
}

goto NotAscii;
}
}

return;
}

// 64 bit implementation
if (sizeof(nuint) == 8)
{
ref nuint srcNuintPtr = ref Unsafe.As<char, nuint>(ref MemoryMarshal.GetReference(source));
ref nuint dstNuintPtr = ref Unsafe.As<char, nuint>(ref MemoryMarshal.GetReference(destination));

// range: 0-7 chars (0-14 bytes)
// -4 chars
if (length >= 4)
{
length -= sizeof(nuint) / sizeof(char);
if (!Utf16Utility.AllCharsInNuintAreAscii(srcNuintPtr))
goto NotAscii;

dstNuintPtr = toUpper
? (nuint)Utf8Utility.ConvertAllAsciiBytesInUInt64ToUppercase(srcNuintPtr)
: (nuint)Utf8Utility.ConvertAllAsciiBytesInUInt64ToLowercase(srcNuintPtr);

srcNuintPtr = ref Unsafe.Add(ref srcNuintPtr, 1);
dstNuintPtr = ref Unsafe.Add(ref dstNuintPtr, 1);
}

// -2 chars
ref uint srcUIntPtr = ref Unsafe.As<nuint, uint>(ref srcNuintPtr);
ref uint dstUIntPtr = ref Unsafe.As<nuint, uint>(ref dstNuintPtr);
if (length >= 2)
{
length -= 2;
if (!Utf16Utility.AllCharsInUIntAreAscii(srcUIntPtr))
goto NotAscii;

dstUIntPtr = toUpper
? Utf8Utility.ConvertAllAsciiBytesInUInt32ToUppercase(srcUIntPtr)
: Utf8Utility.ConvertAllAsciiBytesInUInt32ToLowercase(srcUIntPtr);

srcUIntPtr = ref Unsafe.Add(ref srcUIntPtr, 1);
dstUIntPtr = ref Unsafe.Add(ref dstUIntPtr, 1);
}

// -1 char
ref char srcCharPtr = ref Unsafe.As<uint, char>(ref srcUIntPtr);
ref char dstCharPtr = ref Unsafe.As<uint, char>(ref dstUIntPtr);
if (length >= 1)
{
if (toUpper)
{
if (UnicodeUtility.IsInRangeInclusive(srcCharPtr, 'a', 'z'))
{
dstCharPtr = (char)(srcCharPtr - (char)0x20u);
return;
}

goto NotAscii;
}
else
{
if (UnicodeUtility.IsInRangeInclusive(srcCharPtr, 'A', 'Z'))
{
dstCharPtr = (char)(srcCharPtr + (char)0x20u);
return;
}

goto NotAscii;
}
}

return;
}

ThrowHelpers.ThrowArchitectureNotSupportedException();
return;

NotAscii:
ChangeCase_Fallback<TConversion>(source, destination);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void ChangeCase_Fallback<TConversion>(ReadOnlySpan<char> source, Span<char> destination)
{
// JIT will treat this as a constant in release builds
var toUpper = typeof(TConversion) == typeof(ToUpperConversion);
try
{
if (toUpper)
source.ToUpperInvariant(destination);
else
source.ToLowerInvariant(destination);
}
catch (InvalidOperationException)
{
// Overlapping buffers
if (toUpper)
source.ToString().AsSpan().ToUpperInvariant(destination);
else
source.ToString().AsSpan().ToLowerInvariant(destination);
}
}
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using System.Diagnostics.CodeAnalysis;

[ExcludeFromCodeCoverage] // Taken from Runtime
internal static class UnicodeUtility
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool IsInRangeInclusive(uint value, uint lowerBound, uint upperBound) => (value - lowerBound) <= (upperBound - lowerBound);
}
Loading

0 comments on commit 5956d49

Please sign in to comment.