Skip to content

Commit

Permalink
JIT: Preserve range check for HW intrinsics with non-const/out-of-ran…
Browse files Browse the repository at this point in the history
…ge immediates (#106765)
  • Loading branch information
amanasifkhalid authored Aug 26, 2024
1 parent fb8e078 commit 9a31a5b
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 41 deletions.
3 changes: 1 addition & 2 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3446,7 +3446,6 @@ class Compiler
GenTreeHWIntrinsic* gtNewScalarHWIntrinsicNode(
var_types type, GenTree* op1, GenTree* op2, GenTree* op3, NamedIntrinsic hwIntrinsicID);
CorInfoType getBaseJitTypeFromArgIfNeeded(NamedIntrinsic intrinsic,
CORINFO_CLASS_HANDLE clsHnd,
CORINFO_SIG_INFO* sig,
CorInfoType simdBaseJitType);

Expand Down Expand Up @@ -4718,7 +4717,7 @@ class Compiler
GenTree* getArgForHWIntrinsic(var_types argType, CORINFO_CLASS_HANDLE argClass);
GenTree* impNonConstFallback(NamedIntrinsic intrinsic, var_types simdType, CorInfoType simdBaseJitType);
GenTree* addRangeCheckIfNeeded(
NamedIntrinsic intrinsic, GenTree* immOp, bool mustExpand, int immLowerBound, int immUpperBound);
NamedIntrinsic intrinsic, GenTree* immOp, int immLowerBound, int immUpperBound);
GenTree* addRangeCheckForHWIntrinsic(GenTree* immOp, int immLowerBound, int immUpperBound);

void getHWIntrinsicImmOps(NamedIntrinsic intrinsic,
Expand Down
115 changes: 85 additions & 30 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,18 +718,16 @@ uint8_t TernaryLogicInfo::GetTernaryControlByte(const TernaryLogicInfo& info, ui
//
// Arguments:
// intrinsic -- id of the intrinsic function.
// clsHnd -- class handle containing the intrinsic function.
// method -- method handle of the intrinsic function.
// sig -- signature of the intrinsic call.
// simdBaseJitType -- Predetermined simdBaseJitType, could be CORINFO_TYPE_UNDEF
//
// Return Value:
// The basetype of intrinsic of it can be fetched from 1st or 2nd argument, else return baseType unmodified.
//
CorInfoType Compiler::getBaseJitTypeFromArgIfNeeded(NamedIntrinsic intrinsic,
CORINFO_CLASS_HANDLE clsHnd,
CORINFO_SIG_INFO* sig,
CorInfoType simdBaseJitType)
CorInfoType Compiler::getBaseJitTypeFromArgIfNeeded(NamedIntrinsic intrinsic,
CORINFO_SIG_INFO* sig,
CorInfoType simdBaseJitType)
{
if (HWIntrinsicInfo::BaseTypeFromSecondArg(intrinsic) || HWIntrinsicInfo::BaseTypeFromFirstArg(intrinsic))
{
Expand Down Expand Up @@ -1332,29 +1330,26 @@ GenTree* Compiler::getArgForHWIntrinsic(var_types argType, CORINFO_CLASS_HANDLE
// Arguments:
// intrinsic -- intrinsic ID
// immOp -- the immediate operand of the intrinsic
// mustExpand -- true if the compiler is compiling the fallback(GT_CALL) of this intrinsics
// immLowerBound -- lower incl. bound for a value of the immediate operand (for a non-full-range imm-intrinsic)
// immUpperBound -- upper incl. bound for a value of the immediate operand (for a non-full-range imm-intrinsic)
//
// Return Value:
// add a GT_BOUNDS_CHECK node for non-full-range imm-intrinsic, which would throw ArgumentOutOfRangeException
// when the imm-argument is not in the valid range
//
GenTree* Compiler::addRangeCheckIfNeeded(
NamedIntrinsic intrinsic, GenTree* immOp, bool mustExpand, int immLowerBound, int immUpperBound)
GenTree* Compiler::addRangeCheckIfNeeded(NamedIntrinsic intrinsic, GenTree* immOp, int immLowerBound, int immUpperBound)
{
assert(immOp != nullptr);
// Full-range imm-intrinsics do not need the range-check
// because the imm-parameter of the intrinsic method is a byte.
// AVX2 Gather intrinsics no not need the range-check
// because their imm-parameter have discrete valid values that are handle by managed code
if (mustExpand && HWIntrinsicInfo::isImmOp(intrinsic, immOp)
if (!immOp->IsCnsIntOrI() && HWIntrinsicInfo::isImmOp(intrinsic, immOp)
#ifdef TARGET_XARCH
&& !HWIntrinsicInfo::isAVX2GatherIntrinsic(intrinsic) && !HWIntrinsicInfo::HasFullRangeImm(intrinsic)
#endif
)
{
assert(!immOp->IsCnsIntOrI());
assert(varTypeIsIntegral(immOp));

return addRangeCheckForHWIntrinsic(immOp, immLowerBound, immUpperBound);
Expand Down Expand Up @@ -1596,7 +1591,6 @@ bool Compiler::CheckHWIntrinsicImmRange(NamedIntrinsic intrinsic,

if (immOutOfRange)
{
assert(!mustExpand);
// The imm-HWintrinsics that do not accept all imm8 values may throw
// ArgumentOutOfRangeException when the imm argument is not in the valid range,
// unless the intrinsic can be transformed into one that does accept all imm8 values
Expand Down Expand Up @@ -1764,7 +1758,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
}

simdBaseJitType = getBaseJitTypeFromArgIfNeeded(intrinsic, clsHnd, sig, simdBaseJitType);
simdBaseJitType = getBaseJitTypeFromArgIfNeeded(intrinsic, sig, simdBaseJitType);
unsigned simdSize = 0;

if (simdBaseJitType == CORINFO_TYPE_UNDEF)
{
Expand All @@ -1783,7 +1778,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

simdBaseJitType = getBaseJitTypeAndSizeOfSIMDType(clsHnd, &sizeBytes);

#if defined(TARGET_ARM64)
#ifdef TARGET_ARM64
if (simdBaseJitType == CORINFO_TYPE_UNDEF && HWIntrinsicInfo::HasScalarInputVariant(intrinsic))
{
// Did not find a valid vector type. The intrinsic has alternate scalar version. Switch to that.
Expand All @@ -1799,12 +1794,40 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
assert(simdBaseJitType != CORINFO_TYPE_VALUECLASS);
}
else
#endif
#endif // TARGET_ARM64
{
assert((category == HW_Category_Special) || (category == HW_Category_Helper) || (sizeBytes != 0));
}
}
}
#ifdef TARGET_ARM64
else if ((simdBaseJitType == CORINFO_TYPE_VALUECLASS) && (HWIntrinsicInfo::BaseTypeFromValueTupleArg(intrinsic)))
{
// If HW_Flag_BaseTypeFromValueTupleArg is set, one of the base type position flags must be set.
// There is no point to using this flag if the SIMD size is known at compile-time.
assert(HWIntrinsicInfo::BaseTypeFromFirstArg(intrinsic) || HWIntrinsicInfo::BaseTypeFromSecondArg(intrinsic));
assert(!HWIntrinsicInfo::tryLookupSimdSize(intrinsic, &simdSize));

CORINFO_ARG_LIST_HANDLE arg = sig->args;

if (HWIntrinsicInfo::BaseTypeFromSecondArg(intrinsic))
{
arg = info.compCompHnd->getArgNext(arg);
}

CORINFO_CLASS_HANDLE argClass = info.compCompHnd->getArgClass(sig, arg);
INDEBUG(unsigned fieldCount = info.compCompHnd->getClassNumInstanceFields(argClass));
assert(fieldCount > 1);

CORINFO_CLASS_HANDLE classHnd;
CORINFO_FIELD_HANDLE fieldHandle = info.compCompHnd->getFieldInClass(argClass, 0);
CorInfoType fieldType = info.compCompHnd->getFieldType(fieldHandle, &classHnd);
assert(isIntrinsicType(classHnd));

simdBaseJitType = getBaseJitTypeAndSizeOfSIMDType(classHnd, &simdSize);
assert(simdSize > 0);
}
#endif // TARGET_ARM64

// Immediately return if the category is other than scalar/special and this is not a supported base type.
if ((category != HW_Category_Special) && (category != HW_Category_Scalar) && !HWIntrinsicInfo::isScalarIsa(isa) &&
Expand All @@ -1827,7 +1850,9 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
#endif // TARGET_XARCH
}

const unsigned simdSize = HWIntrinsicInfo::lookupSimdSize(this, intrinsic, sig);
// We may have already determined simdSize for intrinsics that require special handling.
// If so, skip the lookup.
simdSize = (simdSize == 0) ? HWIntrinsicInfo::lookupSimdSize(this, intrinsic, sig) : simdSize;

HWIntrinsicSignatureReader sigReader;
sigReader.Read(info.compCompHnd, sig);
Expand Down Expand Up @@ -1859,15 +1884,30 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
{
return impNonConstFallback(intrinsic, retType, simdBaseJitType);
}
else if (!opts.OptimizationEnabled())
else if (immOp2->IsCnsIntOrI())
{
// Only enable late stage rewriting if optimizations are enabled
// as we won't otherwise encounter a constant at the later point
return nullptr;
// If we know the immediate is out-of-range,
// convert the intrinsic into a user call (or throw if we must expand)
return impUnsupportedNamedIntrinsic(CORINFO_HELP_THROW_ARGUMENTOUTOFRANGEEXCEPTION, method, sig,
mustExpand);
}
else
{
setMethodHandle = true;
// The immediate is unknown, and we aren't using a fallback intrinsic.
// In this case, CheckHWIntrinsicImmRange should not return false for intrinsics that must expand.
assert(!mustExpand);

if (opts.OptimizationEnabled())
{
// Only enable late stage rewriting if optimizations are enabled
// as we won't otherwise encounter a constant at the later point
setMethodHandle = true;
}
else
{
// Just convert to a user call
return nullptr;
}
}
}
}
Expand Down Expand Up @@ -1896,15 +1936,30 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
{
return impNonConstFallback(intrinsic, retType, simdBaseJitType);
}
else if (!opts.OptimizationEnabled())
else if (immOp1->IsCnsIntOrI())
{
// Only enable late stage rewriting if optimizations are enabled
// as we won't otherwise encounter a constant at the later point
return nullptr;
// If we know the immediate is out-of-range,
// convert the intrinsic into a user call (or throw if we must expand)
return impUnsupportedNamedIntrinsic(CORINFO_HELP_THROW_ARGUMENTOUTOFRANGEEXCEPTION, method, sig,
mustExpand);
}
else
{
setMethodHandle = true;
// The immediate is unknown, and we aren't using a fallback intrinsic.
// In this case, CheckHWIntrinsicImmRange should not return false for intrinsics that must expand.
assert(!mustExpand);

if (opts.OptimizationEnabled())
{
// Only enable late stage rewriting if optimizations are enabled
// as we won't otherwise encounter a constant at the later point
setMethodHandle = true;
}
else
{
// Just convert to a user call
return nullptr;
}
}
}
}
Expand Down Expand Up @@ -1960,7 +2015,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
{
case 4:
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op4 = addRangeCheckIfNeeded(intrinsic, op4, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
Expand All @@ -1974,7 +2029,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 2:
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

Expand Down Expand Up @@ -2144,7 +2199,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
#ifdef TARGET_ARM64
if (intrinsic == NI_AdvSimd_LoadAndInsertScalar)
{
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, immLowerBound, immUpperBound);

if (op1->OperIs(GT_CAST))
{
Expand All @@ -2158,12 +2213,12 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
else if ((intrinsic == NI_AdvSimd_Insert) || (intrinsic == NI_AdvSimd_InsertScalar))
{
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, immLowerBound, immUpperBound);
}
else
#endif
{
op3 = addRangeCheckIfNeeded(intrinsic, op3, mustExpand, immLowerBound, immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, immLowerBound, immUpperBound);
}

retNode = isScalar ? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, op3, intrinsic)
Expand Down
10 changes: 10 additions & 0 deletions src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ enum HWIntrinsicFlag : unsigned int
// (instead of merging).
HW_Flag_ZeroingMaskedOperation = 0x800000,

// The intrinsic has an overload where the base type is extracted from a ValueTuple of SIMD types
// (HW_Flag_BaseTypeFrom{First, Second}Arg must also be set to denote the position of the ValueTuple)
HW_Flag_BaseTypeFromValueTupleArg = 0x1000000,

#else
#error Unsupported platform
#endif
Expand Down Expand Up @@ -988,6 +992,12 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_ZeroingMaskedOperation) != 0;
}

static bool BaseTypeFromValueTupleArg(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_BaseTypeFromValueTupleArg) != 0;
}

static NamedIntrinsic GetScalarInputVariant(NamedIntrinsic id)
{
assert(HasScalarInputVariant(id));
Expand Down
14 changes: 7 additions & 7 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2466,7 +2466,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,

assert(HWIntrinsicInfo::isImmOp(intrinsic, op3));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 1, &immLowerBound, &immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, (!op3->IsCnsIntOrI()), immLowerBound, immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, immLowerBound, immUpperBound);
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg1, &argClass)));
op1 = getArgForHWIntrinsic(argType, argClass);

Expand Down Expand Up @@ -2939,11 +2939,11 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,

assert(HWIntrinsicInfo::isImmOp(intrinsic, op2));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 1, &immLowerBound, &immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, (!op2->IsCnsIntOrI()), immLowerBound, immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, immLowerBound, immUpperBound);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op3));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 2, &immLowerBound, &immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, (!op3->IsCnsIntOrI()), immLowerBound, immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, immLowerBound, immUpperBound);

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
Expand Down Expand Up @@ -3010,7 +3010,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
op3 = getArgForHWIntrinsic(argType, argClass);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op3));
op3 = addRangeCheckIfNeeded(intrinsic, op3, mustExpand, immLowerBound, immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, immLowerBound, immUpperBound);

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg2, &argClass)));
op2 = getArgForHWIntrinsic(argType, argClass);
Expand Down Expand Up @@ -3040,7 +3040,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
op4 = getArgForHWIntrinsic(argType, argClass);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op4));
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op4 = addRangeCheckIfNeeded(intrinsic, op4, immLowerBound, immUpperBound);

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg3, &argClass)));
op3 = getArgForHWIntrinsic(argType, argClass);
Expand Down Expand Up @@ -3108,12 +3108,12 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg5, &argClass)));
GenTree* op5 = getArgForHWIntrinsic(argType, argClass);
assert(HWIntrinsicInfo::isImmOp(intrinsic, op5));
op5 = addRangeCheckIfNeeded(intrinsic, op5, mustExpand, imm1LowerBound, imm1UpperBound);
op5 = addRangeCheckIfNeeded(intrinsic, op5, imm1LowerBound, imm1UpperBound);

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg4, &argClass)));
op4 = getArgForHWIntrinsic(argType, argClass);
assert(HWIntrinsicInfo::isImmOp(intrinsic, op4));
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, imm2LowerBound, imm2UpperBound);
op4 = addRangeCheckIfNeeded(intrinsic, op4, imm2LowerBound, imm2UpperBound);

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg3, &argClass)));
op3 = getArgForHWIntrinsic(argType, argClass);
Expand Down
Loading

0 comments on commit 9a31a5b

Please sign in to comment.