Add basic support for folding SIMD intrinsics (#81547)
authorTanner Gooding <tagoo@outlook.com>
Sat, 11 Feb 2023 23:30:05 +0000 (15:30 -0800)
committerGitHub <noreply@github.com>
Sat, 11 Feb 2023 23:30:05 +0000 (15:30 -0800)
* Create a helper ValueNumStore::EvalHWIntrinsicFunBinary

* Adding some basic support for folding SIMD unary and binary operations

* Refactor SIMD constant folding logic to not depend on C++ 14

* Apply formatting patch

* Handle identity folding for simd add/sub

* Add some basic tests covering SIMD constant folding

* Move genTreeOps to its own header so simd.h can use it in Evaluate*Simd

* Applying formatting patch

src/coreclr/jit/CMakeLists.txt
src/coreclr/jit/gentree.h
src/coreclr/jit/gentreeopsdef.h [new file with mode: 0644]
src/coreclr/jit/simd.h
src/coreclr/jit/valuenum.cpp
src/coreclr/jit/valuenum.h
src/tests/JIT/HardwareIntrinsics/General/ConstantFolding/SimdConstantFoldings.cs [new file with mode: 0644]
src/tests/JIT/HardwareIntrinsics/General/ConstantFolding/SimdConstantFoldings.csproj [new file with mode: 0644]

index b39a731..a3d6568 100644 (file)
@@ -296,6 +296,7 @@ set( JIT_HEADERS
   emitpub.h
   error.h
   gentree.h
+  gentreeopsdef.h
   gtlist.h
   gtstructs.h
   hashbv.h
index 05c0d5a..7375e56 100644 (file)
@@ -24,6 +24,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 #include "valuenumtype.h"
 #include "jitstd.h"
 #include "jithashtable.h"
+#include "gentreeopsdef.h"
 #include "simd.h"
 #include "namedintrinsiclist.h"
 #include "layout.h"
@@ -64,24 +65,6 @@ enum SpecialCodeKind
 
 /*****************************************************************************/
 
-enum genTreeOps : BYTE
-{
-#define GTNODE(en, st, cm, ok) GT_##en,
-#include "gtlist.h"
-
-    GT_COUNT,
-
-#ifdef TARGET_64BIT
-    // GT_CNS_NATIVELONG is the gtOper symbol for GT_CNS_LNG or GT_CNS_INT, depending on the target.
-    // For the 64-bit targets we will only use GT_CNS_INT as it used to represent all the possible sizes
-    GT_CNS_NATIVELONG = GT_CNS_INT,
-#else
-    // For the 32-bit targets we use a GT_CNS_LNG to hold a 64-bit integer constant and GT_CNS_INT for all others.
-    // In the future when we retarget the JIT for x86 we should consider eliminating GT_CNS_LNG
-    GT_CNS_NATIVELONG = GT_CNS_LNG,
-#endif
-};
-
 // The following enum defines a set of bit flags that can be used
 // to classify expression tree nodes.
 //
diff --git a/src/coreclr/jit/gentreeopsdef.h b/src/coreclr/jit/gentreeopsdef.h
new file mode 100644 (file)
index 0000000..9e8e915
--- /dev/null
@@ -0,0 +1,29 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+/*****************************************************************************/
+#ifndef _GENTREEOPSDEF_H_
+#define _GENTREEOPSDEF_H_
+/*****************************************************************************/
+
+enum genTreeOps : BYTE
+{
+#define GTNODE(en, st, cm, ok) GT_##en,
+#include "gtlist.h"
+
+    GT_COUNT,
+
+#ifdef TARGET_64BIT
+    // GT_CNS_NATIVELONG is the gtOper symbol for GT_CNS_LNG or GT_CNS_INT, depending on the target.
+    // For the 64-bit targets we will only use GT_CNS_INT as it used to represent all the possible sizes
+    GT_CNS_NATIVELONG = GT_CNS_INT,
+#else
+    // For the 32-bit targets we use a GT_CNS_LNG to hold a 64-bit integer constant and GT_CNS_INT for all others.
+    // In the future when we retarget the JIT for x86 we should consider eliminating GT_CNS_LNG
+    GT_CNS_NATIVELONG = GT_CNS_LNG,
+#endif
+};
+
+/*****************************************************************************/
+#endif // _GENTREEOPSDEF_H_
+/*****************************************************************************/
index be227ac..526c032 100644 (file)
@@ -72,6 +72,13 @@ struct simd12_t
         uint8_t  u8[12];
         uint16_t u16[6];
         uint32_t u32[3];
+
+        // These three exist to simplify templatized code
+        // they won't actually be accessed for real scenarios
+
+        double   f64[1];
+        int64_t  i64[1];
+        uint64_t u64[1];
     };
 
     bool operator==(const simd12_t& other) const
@@ -142,6 +149,252 @@ struct simd32_t
     }
 };
 
+template <typename TBase>
+TBase EvaluateUnaryScalar(genTreeOps oper, TBase arg0)
+{
+    switch (oper)
+    {
+        case GT_NEG:
+        {
+            return static_cast<TBase>(0) - arg0;
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+template <typename TSimd, typename TBase>
+void EvaluateUnarySimd(genTreeOps oper, bool scalar, TSimd* result, TSimd arg0)
+{
+    uint32_t count = sizeof(TSimd) / sizeof(TBase);
+
+    if (scalar)
+    {
+        count = 1;
+
+#if defined(TARGET_XARCH)
+        // scalar operations on xarch copy the upper bits from arg0
+        *result = arg0;
+#elif defined(TARGET_ARM64)
+        // scalar operations on arm64 zero the upper bits
+        *result = {};
+#endif
+    }
+
+    for (uint32_t i = 0; i < count; i++)
+    {
+        // Safely execute `result[i] = oper(arg0[i])`
+
+        TBase input0;
+        memcpy(&input0, &arg0.u8[i * sizeof(TBase)], sizeof(TBase));
+
+        TBase output = EvaluateUnaryScalar<TBase>(oper, input0);
+        memcpy(&result->u8[i * sizeof(TBase)], &output, sizeof(TBase));
+    }
+}
+
+template <typename TSimd>
+void EvaluateUnarySimd(genTreeOps oper, bool scalar, var_types baseType, TSimd* result, TSimd arg0)
+{
+    switch (baseType)
+    {
+        case TYP_FLOAT:
+        {
+            EvaluateUnarySimd<TSimd, float>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_DOUBLE:
+        {
+            EvaluateUnarySimd<TSimd, double>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_BYTE:
+        {
+            EvaluateUnarySimd<TSimd, int8_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_SHORT:
+        {
+            EvaluateUnarySimd<TSimd, int16_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_INT:
+        {
+            EvaluateUnarySimd<TSimd, int32_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_LONG:
+        {
+            EvaluateUnarySimd<TSimd, int64_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_UBYTE:
+        {
+            EvaluateUnarySimd<TSimd, uint8_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_USHORT:
+        {
+            EvaluateUnarySimd<TSimd, uint16_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_UINT:
+        {
+            EvaluateUnarySimd<TSimd, uint32_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        case TYP_ULONG:
+        {
+            EvaluateUnarySimd<TSimd, uint64_t>(oper, scalar, result, arg0);
+            break;
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+template <typename TBase>
+TBase EvaluateBinaryScalar(genTreeOps oper, TBase arg0, TBase arg1)
+{
+    switch (oper)
+    {
+        case GT_ADD:
+        {
+            return arg0 + arg1;
+        }
+
+        case GT_SUB:
+        {
+            return arg0 - arg1;
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+template <typename TSimd, typename TBase>
+void EvaluateBinarySimd(genTreeOps oper, bool scalar, TSimd* result, TSimd arg0, TSimd arg1)
+{
+    uint32_t count = sizeof(TSimd) / sizeof(TBase);
+
+    if (scalar)
+    {
+        count = 1;
+
+#if defined(TARGET_XARCH)
+        // scalar operations on xarch copy the upper bits from arg0
+        *result = arg0;
+#elif defined(TARGET_ARM64)
+        // scalar operations on arm64 zero the upper bits
+        *result = {};
+#endif
+    }
+
+    for (uint32_t i = 0; i < count; i++)
+    {
+        // Safely execute `result[i] = oper(arg0[i], arg1[i])`
+
+        TBase input0;
+        memcpy(&input0, &arg0.u8[i * sizeof(TBase)], sizeof(TBase));
+
+        TBase input1;
+        memcpy(&input1, &arg1.u8[i * sizeof(TBase)], sizeof(TBase));
+
+        TBase output = EvaluateBinaryScalar<TBase>(oper, input0, input1);
+        memcpy(&result->u8[i * sizeof(TBase)], &output, sizeof(TBase));
+    }
+}
+
+template <typename TSimd>
+void EvaluateBinarySimd(genTreeOps oper, bool scalar, var_types baseType, TSimd* result, TSimd arg0, TSimd arg1)
+{
+    switch (baseType)
+    {
+        case TYP_FLOAT:
+        {
+            EvaluateBinarySimd<TSimd, float>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_DOUBLE:
+        {
+            EvaluateBinarySimd<TSimd, double>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_BYTE:
+        {
+            EvaluateBinarySimd<TSimd, int8_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_SHORT:
+        {
+            EvaluateBinarySimd<TSimd, int16_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_INT:
+        {
+            EvaluateBinarySimd<TSimd, int32_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_LONG:
+        {
+            EvaluateBinarySimd<TSimd, int64_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_UBYTE:
+        {
+            EvaluateBinarySimd<TSimd, uint8_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_USHORT:
+        {
+            EvaluateBinarySimd<TSimd, uint16_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_UINT:
+        {
+            EvaluateBinarySimd<TSimd, uint32_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        case TYP_ULONG:
+        {
+            EvaluateBinarySimd<TSimd, uint64_t>(oper, scalar, result, arg0, arg1);
+            break;
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
 #ifdef FEATURE_SIMD
 
 #ifdef TARGET_XARCH
index 62c6735..c43bb47 100644 (file)
@@ -6012,8 +6012,174 @@ void ValueNumStore::SetVNIsCheckedBound(ValueNum vn)
 }
 
 #ifdef FEATURE_HW_INTRINSICS
-ValueNum ValueNumStore::EvalHWIntrinsicFunUnary(
-    var_types type, NamedIntrinsic ni, VNFunc func, ValueNum arg0VN, bool encodeResultType, ValueNum resultTypeVN)
+ValueNum EvaluateUnarySimd(
+    ValueNumStore* vns, genTreeOps oper, bool scalar, var_types simdType, var_types baseType, ValueNum arg0VN)
+{
+    switch (simdType)
+    {
+        case TYP_SIMD8:
+        {
+            simd8_t result = {};
+            EvaluateUnarySimd<simd8_t>(oper, scalar, baseType, &result, vns->GetConstantSimd8(arg0VN));
+            return vns->VNForSimd8Con(result);
+        }
+
+        case TYP_SIMD12:
+        {
+            simd12_t result = {};
+            EvaluateUnarySimd<simd12_t>(oper, scalar, baseType, &result, vns->GetConstantSimd12(arg0VN));
+            return vns->VNForSimd12Con(result);
+        }
+
+        case TYP_SIMD16:
+        {
+            simd16_t result = {};
+            EvaluateUnarySimd<simd16_t>(oper, scalar, baseType, &result, vns->GetConstantSimd16(arg0VN));
+            return vns->VNForSimd16Con(result);
+        }
+
+        case TYP_SIMD32:
+        {
+            simd32_t result = {};
+            EvaluateUnarySimd<simd32_t>(oper, scalar, baseType, &result, vns->GetConstantSimd32(arg0VN));
+            return vns->VNForSimd32Con(result);
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+ValueNum EvaluateBinarySimd(ValueNumStore* vns,
+                            genTreeOps     oper,
+                            bool           scalar,
+                            var_types      simdType,
+                            var_types      baseType,
+                            ValueNum       arg0VN,
+                            ValueNum       arg1VN)
+{
+    switch (simdType)
+    {
+        case TYP_SIMD8:
+        {
+            simd8_t result = {};
+            EvaluateBinarySimd<simd8_t>(oper, scalar, baseType, &result, vns->GetConstantSimd8(arg0VN),
+                                        vns->GetConstantSimd8(arg1VN));
+            return vns->VNForSimd8Con(result);
+        }
+
+        case TYP_SIMD12:
+        {
+            simd12_t result = {};
+            EvaluateBinarySimd<simd12_t>(oper, scalar, baseType, &result, vns->GetConstantSimd12(arg0VN),
+                                         vns->GetConstantSimd12(arg1VN));
+            return vns->VNForSimd12Con(result);
+        }
+
+        case TYP_SIMD16:
+        {
+            simd16_t result = {};
+            EvaluateBinarySimd<simd16_t>(oper, scalar, baseType, &result, vns->GetConstantSimd16(arg0VN),
+                                         vns->GetConstantSimd16(arg1VN));
+            return vns->VNForSimd16Con(result);
+        }
+
+        case TYP_SIMD32:
+        {
+            simd32_t result = {};
+            EvaluateBinarySimd<simd32_t>(oper, scalar, baseType, &result, vns->GetConstantSimd32(arg0VN),
+                                         vns->GetConstantSimd32(arg1VN));
+            return vns->VNForSimd32Con(result);
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+template <typename TSimd>
+ValueNum EvaluateSimdGetElement(ValueNumStore* vns, var_types baseType, TSimd arg0, int arg1)
+{
+    switch (baseType)
+    {
+        case TYP_FLOAT:
+        {
+            float result = arg0.f32[arg1];
+            return vns->VNForFloatCon(static_cast<float>(result));
+        }
+
+        case TYP_DOUBLE:
+        {
+            double result = arg0.f64[arg1];
+            return vns->VNForDoubleCon(static_cast<double>(result));
+        }
+
+        case TYP_BYTE:
+        {
+            int8_t result = arg0.i8[arg1];
+            return vns->VNForIntCon(static_cast<int32_t>(result));
+        }
+
+        case TYP_SHORT:
+        {
+            int16_t result = arg0.i16[arg1];
+            return vns->VNForIntCon(static_cast<int32_t>(result));
+        }
+
+        case TYP_INT:
+        {
+            int32_t result = arg0.i32[arg1];
+            return vns->VNForIntCon(static_cast<int32_t>(result));
+        }
+
+        case TYP_LONG:
+        {
+            int64_t result = arg0.i64[arg1];
+            return vns->VNForLongCon(static_cast<int64_t>(result));
+        }
+
+        case TYP_UBYTE:
+        {
+            uint8_t result = arg0.u8[arg1];
+            return vns->VNForIntCon(static_cast<int32_t>(result));
+        }
+
+        case TYP_USHORT:
+        {
+            uint16_t result = arg0.u16[arg1];
+            return vns->VNForIntCon(static_cast<int32_t>(result));
+        }
+
+        case TYP_UINT:
+        {
+            uint32_t result = arg0.u32[arg1];
+            return vns->VNForIntCon(static_cast<int32_t>(result));
+        }
+
+        case TYP_ULONG:
+        {
+            uint64_t result = arg0.u64[arg1];
+            return vns->VNForLongCon(static_cast<int64_t>(result));
+        }
+
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+ValueNum ValueNumStore::EvalHWIntrinsicFunUnary(var_types      type,
+                                                var_types      baseType,
+                                                NamedIntrinsic ni,
+                                                VNFunc         func,
+                                                ValueNum       arg0VN,
+                                                bool           encodeResultType,
+                                                ValueNum       resultTypeVN)
 {
     if (IsVNConstant(arg0VN))
     {
@@ -6075,6 +6241,18 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunUnary(
 
                 return VNForLongCon(static_cast<int64_t>(result));
             }
+
+            case NI_AdvSimd_Negate:
+            case NI_AdvSimd_Arm64_Negate:
+            {
+                return EvaluateUnarySimd(this, GT_NEG, /* scalar */ false, type, baseType, arg0VN);
+            }
+
+            case NI_AdvSimd_NegateScalar:
+            case NI_AdvSimd_Arm64_NegateScalar:
+            {
+                return EvaluateUnarySimd(this, GT_NEG, /* scalar */ true, type, baseType, arg0VN);
+            }
 #endif // TARGET_ARM64
 
 #if defined(TARGET_XARCH)
@@ -6190,6 +6368,198 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunUnary(
     }
     return VNForFunc(type, func, arg0VN);
 }
+
+ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(var_types      type,
+                                                 var_types      baseType,
+                                                 NamedIntrinsic ni,
+                                                 VNFunc         func,
+                                                 ValueNum       arg0VN,
+                                                 ValueNum       arg1VN,
+                                                 bool           encodeResultType,
+                                                 ValueNum       resultTypeVN)
+{
+    ValueNum cnsVN = NoVN;
+    ValueNum argVN = NoVN;
+
+    if (IsVNConstant(arg0VN))
+    {
+        cnsVN = arg0VN;
+
+        if (!IsVNConstant(arg1VN))
+        {
+            argVN = arg1VN;
+        }
+    }
+    else
+    {
+        argVN = arg0VN;
+
+        if (IsVNConstant(arg1VN))
+        {
+            cnsVN = arg1VN;
+        }
+    }
+
+    if (argVN == NoVN)
+    {
+        assert(IsVNConstant(arg0VN) && IsVNConstant(arg1VN));
+
+        switch (ni)
+        {
+#ifdef TARGET_ARM64
+            case NI_AdvSimd_Add:
+            case NI_AdvSimd_Arm64_Add:
+#else
+            case NI_SSE_Add:
+            case NI_SSE2_Add:
+            case NI_AVX_Add:
+            case NI_AVX2_Add:
+#endif
+            {
+                return EvaluateBinarySimd(this, GT_ADD, /* scalar */ false, type, baseType, arg0VN, arg1VN);
+            }
+
+#ifdef TARGET_ARM64
+            case NI_AdvSimd_AddScalar:
+#else
+            case NI_SSE_AddScalar:
+            case NI_SSE2_AddScalar:
+#endif
+            {
+                return EvaluateBinarySimd(this, GT_ADD, /* scalar */ true, type, baseType, arg0VN, arg1VN);
+            }
+
+#ifdef TARGET_ARM64
+            case NI_Vector64_GetElement:
+#endif
+            case NI_Vector128_GetElement:
+#ifdef TARGET_XARCH
+            case NI_Vector256_GetElement:
+#endif
+            {
+                switch (TypeOfVN(arg0VN))
+                {
+                    case TYP_SIMD8:
+                    {
+                        return EvaluateSimdGetElement<simd8_t>(this, baseType, GetConstantSimd8(arg0VN),
+                                                               GetConstantInt32(arg1VN));
+                    }
+
+                    case TYP_SIMD12:
+                    {
+                        return EvaluateSimdGetElement<simd12_t>(this, baseType, GetConstantSimd12(arg0VN),
+                                                                GetConstantInt32(arg1VN));
+                    }
+
+                    case TYP_SIMD16:
+                    {
+                        return EvaluateSimdGetElement<simd16_t>(this, baseType, GetConstantSimd16(arg0VN),
+                                                                GetConstantInt32(arg1VN));
+                    }
+
+                    case TYP_SIMD32:
+                    {
+                        return EvaluateSimdGetElement<simd32_t>(this, baseType, GetConstantSimd32(arg0VN),
+                                                                GetConstantInt32(arg1VN));
+                    }
+
+                    default:
+                    {
+                        unreached();
+                    }
+                }
+            }
+
+#ifdef TARGET_ARM64
+            case NI_AdvSimd_Subtract:
+            case NI_AdvSimd_Arm64_Subtract:
+#else
+            case NI_SSE_Subtract:
+            case NI_SSE2_Subtract:
+            case NI_AVX_Subtract:
+            case NI_AVX2_Subtract:
+#endif
+            {
+                return EvaluateBinarySimd(this, GT_SUB, /* scalar */ false, type, baseType, arg0VN, arg1VN);
+            }
+
+#ifdef TARGET_ARM64
+            case NI_AdvSimd_SubtractScalar:
+#else
+            case NI_SSE_SubtractScalar:
+            case NI_SSE2_SubtractScalar:
+#endif
+            {
+                return EvaluateBinarySimd(this, GT_SUB, /* scalar */ true, type, baseType, arg0VN, arg1VN);
+            }
+
+            default:
+                break;
+        }
+    }
+    else if (cnsVN != NoVN)
+    {
+        switch (ni)
+        {
+#ifdef TARGET_ARM64
+            case NI_AdvSimd_Add:
+            case NI_AdvSimd_Arm64_Add:
+#else
+            case NI_SSE_Add:
+            case NI_SSE2_Add:
+            case NI_AVX_Add:
+            case NI_AVX2_Add:
+#endif
+            {
+                // Handle `x + 0` and `0 + x`
+
+                ValueNum zeroVN = VNZeroForType(type);
+
+                if (cnsVN == zeroVN)
+                {
+                    return argVN;
+                }
+                break;
+            }
+
+#ifdef TARGET_ARM64
+            case NI_AdvSimd_Subtract:
+            case NI_AdvSimd_Arm64_Subtract:
+#else
+            case NI_SSE_Subtract:
+            case NI_SSE2_Subtract:
+            case NI_AVX_Subtract:
+            case NI_AVX2_Subtract:
+#endif
+            {
+                // Handle `x - 0`
+
+                if (cnsVN != arg1VN)
+                {
+                    // This is `0 - x` which is `NEG(x)`
+                    break;
+                }
+
+                ValueNum zeroVN = VNZeroForType(type);
+
+                if (cnsVN == zeroVN)
+                {
+                    return argVN;
+                }
+                break;
+            }
+
+            default:
+                break;
+        }
+    }
+
+    if (encodeResultType)
+    {
+        return VNForFunc(type, func, arg0VN, arg1VN, resultTypeVN);
+    }
+    return VNForFunc(type, func, arg0VN, arg1VN);
+}
 #endif
 
 ValueNum ValueNumStore::EvalMathFuncUnary(var_types typ, NamedIntrinsic gtMathFN, ValueNum arg0VN)
@@ -9851,12 +10221,13 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
 
             if (tree->GetOperandCount() == 1)
             {
-                ValueNum normalLVN =
-                    vnStore->EvalHWIntrinsicFunUnary(tree->TypeGet(), intrinsicId, func, op1vnp.GetLiberal(),
-                                                     encodeResultType, resultTypeVNPair.GetLiberal());
+                ValueNum normalLVN = vnStore->EvalHWIntrinsicFunUnary(tree->TypeGet(), tree->GetSimdBaseType(),
+                                                                      intrinsicId, func, op1vnp.GetLiberal(),
+                                                                      encodeResultType, resultTypeVNPair.GetLiberal());
                 ValueNum normalCVN =
-                    vnStore->EvalHWIntrinsicFunUnary(tree->TypeGet(), intrinsicId, func, op1vnp.GetConservative(),
-                                                     encodeResultType, resultTypeVNPair.GetConservative());
+                    vnStore->EvalHWIntrinsicFunUnary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                     op1vnp.GetConservative(), encodeResultType,
+                                                     resultTypeVNPair.GetConservative());
 
                 normalPair = ValueNumPair(normalLVN, normalCVN);
                 excSetPair = op1Xvnp;
@@ -9866,17 +10237,18 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
                 ValueNumPair op2vnp;
                 ValueNumPair op2Xvnp;
                 getOperandVNs(tree->Op(2), &op2vnp, &op2Xvnp);
+
+                ValueNum normalLVN =
+                    vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                      op1vnp.GetLiberal(), op2vnp.GetLiberal(), encodeResultType,
+                                                      resultTypeVNPair.GetLiberal());
+                ValueNum normalCVN =
+                    vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                      op1vnp.GetConservative(), op2vnp.GetConservative(),
+                                                      encodeResultType, resultTypeVNPair.GetConservative());
+
+                normalPair = ValueNumPair(normalLVN, normalCVN);
                 excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
-                if (encodeResultType)
-                {
-                    normalPair = vnStore->VNPairForFunc(tree->TypeGet(), func, op1vnp, op2vnp, resultTypeVNPair);
-                    assert((vnStore->VNFuncArity(func) == 3) || isVariableNumArgs);
-                }
-                else
-                {
-                    normalPair = vnStore->VNPairForFunc(tree->TypeGet(), func, op1vnp, op2vnp);
-                    assert((vnStore->VNFuncArity(func) == 2) || isVariableNumArgs);
-                }
             }
         }
     }
index f08d6ba..5a12207 100644 (file)
@@ -348,10 +348,13 @@ private:
     float GetConstantSingle(ValueNum argVN);
 
 #if defined(FEATURE_SIMD)
+public:
     simd8_t GetConstantSimd8(ValueNum argVN);
     simd12_t GetConstantSimd12(ValueNum argVN);
     simd16_t GetConstantSimd16(ValueNum argVN);
     simd32_t GetConstantSimd32(ValueNum argVN);
+
+private:
 #endif // FEATURE_SIMD
 
     // Assumes that all the ValueNum arguments of each of these functions have been shown to represent constants.
@@ -1123,8 +1126,22 @@ public:
                             EvalMathFuncBinary(typ, mthFunc, arg0VNP.GetConservative(), arg1VNP.GetConservative()));
     }
 
-    ValueNum EvalHWIntrinsicFunUnary(
-        var_types type, NamedIntrinsic ni, VNFunc func, ValueNum arg0VN, bool encodeResultType, ValueNum resultTypeVN);
+    ValueNum EvalHWIntrinsicFunUnary(var_types      type,
+                                     var_types      baseType,
+                                     NamedIntrinsic ni,
+                                     VNFunc         func,
+                                     ValueNum       arg0VN,
+                                     bool           encodeResultType,
+                                     ValueNum       resultTypeVN);
+
+    ValueNum EvalHWIntrinsicFunBinary(var_types      type,
+                                      var_types      baseType,
+                                      NamedIntrinsic ni,
+                                      VNFunc         func,
+                                      ValueNum       arg0VN,
+                                      ValueNum       arg1VN,
+                                      bool           encodeResultType,
+                                      ValueNum       resultTypeVN);
 
     // Returns "true" iff "vn" represents a function application.
     bool IsVNFunc(ValueNum vn);
diff --git a/src/tests/JIT/HardwareIntrinsics/General/ConstantFolding/SimdConstantFoldings.cs b/src/tests/JIT/HardwareIntrinsics/General/ConstantFolding/SimdConstantFoldings.cs
new file mode 100644 (file)
index 0000000..56aaec0
--- /dev/null
@@ -0,0 +1,244 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.Intrinsics;
+using Xunit;
+
+public class SimdConstantFoldings
+{
+    [Fact]
+    public static void NegateTests()
+    {
+        Assert.Equal(
+            Vector128.Create((byte)(0xFF), 0x02, 0xFD, 0x04, 0xFB, 0x06, 0xF9, 0x08, 0xF7, 0x0A, 0xF5, 0x0C, 0xF3, 0x0E, 0xF1, 0x10),
+           -Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0)
+        );
+
+        Assert.Equal(
+            Vector128.Create((ushort)(0xFFFF), 0x0002, 0xFFFD, 0x0004, 0xFFFB, 0x0006, 0xFFF9, 0x0008),
+           -Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)
+        );
+
+        Assert.Equal(
+            Vector128.Create((uint)(0xFFFF_FFFF), 0x0000_0002, 0xFFFF_FFFD, 0x0000_0004),
+           -Vector128.Create((uint)(0x0000_0001), 0xFFFF_FFFE, 0x0000_0003, 0xFFFF_FFFC)
+        );
+
+        Assert.Equal(
+            Vector128.Create((ulong)(0xFFFF_FFFF_FFFF_FFFF), 0x0000_0000_0000_0002),
+           -Vector128.Create((ulong)(0x0000_0000_0000_0001), 0xFFFF_FFFF_FFFF_FFFE)
+        );
+
+        Assert.Equal(
+            Vector128.Create((sbyte)(-1), +2, -3, +4, -5, +6, -7, +8, -9, +10, -11, +12, -13, +14, -15, +16),
+           -Vector128.Create((sbyte)(+1), -2, +3, -4, +5, -6, +7, -8, +9, -10, +11, -12, +13, -14, +15, -16)
+        );
+
+        Assert.Equal(
+            Vector128.Create((short)(-1), +2, -3, +4, -5, +6, -7, +8),
+           -Vector128.Create((short)(+1), -2, +3, -4, +5, -6, +7, -8)
+        );
+
+        Assert.Equal(
+            Vector128.Create((int)(-1), +2, -3, +4),
+           -Vector128.Create((int)(+1), -2, +3, -4)
+        );
+
+        Assert.Equal(
+            Vector128.Create((long)(-1), +2),
+           -Vector128.Create((long)(+1), -2)
+        );
+
+        Assert.Equal(
+            Vector128.Create((float)(-1), +2, -3, +4),
+           -Vector128.Create((float)(+1), -2, +3, -4)
+        );
+
+        Assert.Equal(
+            Vector128.Create((double)(-1), +2),
+           -Vector128.Create((double)(+1), -2)
+        );
+    }
+
+    [Fact]
+    public static void AddTests()
+    {
+        Assert.Equal(
+            Vector128.Create((byte)(0x02), 0xFF, 0x06, 0xF8, 0x0A, 0xF4, 0x0E, 0xF0, 0x12, 0xEC, 0x16, 0xE8, 0x1A, 0xE4, 0x1E, 0xE0),
+            Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0)
+          + Vector128.Create((byte)(0x01), 0x01, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0)
+        );
+
+        Assert.Equal(
+            Vector128.Create((ushort)(0x0002), 0xFFFF, 0x0006, 0xFFF8, 0x000A, 0xFFF4, 0x000E, 0xFFF0),
+            Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)
+          + Vector128.Create((ushort)(0x0001), 0x0001, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)
+        );
+
+        Assert.Equal(
+            Vector128.Create((uint)(0x0000_0002), 0xFFFF_FFFF, 0x0000_0006, 0xFFFF_FFF8),
+            Vector128.Create((uint)(0x0000_0001), 0xFFFF_FFFE, 0x0000_0003, 0xFFFF_FFFC)
+          + Vector128.Create((uint)(0x0000_0001), 0x0000_0001, 0x0000_0003, 0xFFFF_FFFC)
+        );
+
+        Assert.Equal(
+            Vector128.Create((ulong)(0x0000_0000_0000_0002), 0xFFFF_FFFF_FFFF_FFFF),
+            Vector128.Create((ulong)(0x0000_0000_0000_0001), 0xFFFF_FFFF_FFFF_FFFE)
+          + Vector128.Create((ulong)(0x0000_0000_0000_0001), 0x0000_0000_0000_0001)
+        );
+
+        Assert.Equal(
+            Vector128.Create((sbyte)(+2), -1, +6, -8, +10, -12, +14, -16, +18, -20, +22, -24, +26, -28, +30, -32),
+            Vector128.Create((sbyte)(+1), -2, +3, -4, +05, -06, +07, -08, +09, -10, +11, -12, +13, -14, +15, -16)
+          + Vector128.Create((sbyte)(+1), +1, +3, -4, +05, -06, +07, -08, +09, -10, +11, -12, +13, -14, +15, -16)
+        );
+
+        Assert.Equal(
+            Vector128.Create((short)(+2), -1, +6, -8, +10, -12, +14, -16),
+            Vector128.Create((short)(+1), -2, +3, -4, +05, -06, +07, -08)
+          + Vector128.Create((short)(+1), +1, +3, -4, +05, -06, +07, -08)
+        );
+
+        Assert.Equal(
+            Vector128.Create((int)(+2), -1, +6, -8),
+            Vector128.Create((int)(+1), -2, +3, -4)
+          + Vector128.Create((int)(+1), +1, +3, -4)
+        );
+
+        Assert.Equal(
+            Vector128.Create((long)(+2), -1),
+            Vector128.Create((long)(+1), -2)
+          + Vector128.Create((long)(+1), +1)
+        );
+
+        Assert.Equal(
+            Vector128.Create((float)(+2), -1, +6, -8),
+            Vector128.Create((float)(+1), -2, +3, -4)
+          + Vector128.Create((float)(+1), +1, +3, -4)
+        );
+
+        Assert.Equal(
+            Vector128.Create((double)(+2), -1),
+            Vector128.Create((double)(+1), -2)
+          + Vector128.Create((double)(+1), +1)
+        );
+    }
+
+    [Fact]
+    public static void SubtractTests()
+    {
+        Assert.Equal(
+            Vector128.Create((byte)(0x00), 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00),
+            Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0)
+          - Vector128.Create((byte)(0x01), 0x01, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0)
+        );
+
+        Assert.Equal(
+            Vector128.Create((ushort)(0x0000), 0xFFFD, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000),
+            Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)
+          - Vector128.Create((ushort)(0x0001), 0x0001, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)
+        );
+
+        Assert.Equal(
+            Vector128.Create((uint)(0x0000_0000), 0xFFFF_FFFD, 0x0000_0000, 0x0000_0000),
+            Vector128.Create((uint)(0x0000_0001), 0xFFFF_FFFE, 0x0000_0003, 0xFFFF_FFFC)
+          - Vector128.Create((uint)(0x0000_0001), 0x0000_0001, 0x0000_0003, 0xFFFF_FFFC)
+        );
+
+        Assert.Equal(
+            Vector128.Create((ulong)(0x0000_0000_0000_0000), 0xFFFF_FFFF_FFFF_FFFD),
+            Vector128.Create((ulong)(0x0000_0000_0000_0001), 0xFFFF_FFFF_FFFF_FFFE)
+          - Vector128.Create((ulong)(0x0000_0000_0000_0001), 0x0000_0000_0000_0001)
+        );
+
+        Assert.Equal(
+            Vector128.Create((sbyte)(+0), -3, +0, +0, +00, +00, +00, +00, +00, +00, +00, +00, +00, +00, +00, +00),
+            Vector128.Create((sbyte)(+1), -2, +3, -4, +05, -06, +07, -08, +09, -10, +11, -12, +13, -14, +15, -16)
+          - Vector128.Create((sbyte)(+1), +1, +3, -4, +05, -06, +07, -08, +09, -10, +11, -12, +13, -14, +15, -16)
+        );
+
+        Assert.Equal(
+            Vector128.Create((short)(+0), -3, +0, +0, +00, +00, +00, +00),
+            Vector128.Create((short)(+1), -2, +3, -4, +05, -06, +07, -08)
+          - Vector128.Create((short)(+1), +1, +3, -4, +05, -06, +07, -08)
+        );
+
+        Assert.Equal(
+            Vector128.Create((int)(+0), -3, +0, +0),
+            Vector128.Create((int)(+1), -2, +3, -4)
+          - Vector128.Create((int)(+1), +1, +3, -4)
+        );
+
+        Assert.Equal(
+            Vector128.Create((long)(+0), -3),
+            Vector128.Create((long)(+1), -2)
+          - Vector128.Create((long)(+1), +1)
+        );
+
+        Assert.Equal(
+            Vector128.Create((float)(+0), -3, +0, +0),
+            Vector128.Create((float)(+1), -2, +3, -4)
+          - Vector128.Create((float)(+1), +1, +3, -4)
+        );
+
+        Assert.Equal(
+            Vector128.Create((double)(+0), -3),
+            Vector128.Create((double)(+1), -2)
+          - Vector128.Create((double)(+1), +1)
+        );
+    }
+
+    [Fact]
+    public static void GetElementTests()
+    {
+        Assert.Equal(
+            (byte)(0xFE),
+            Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0).GetElement(1)
+        );
+
+        Assert.Equal(
+            (ushort)(0xFFFE),
+            Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8).GetElement(1)
+        );
+
+        Assert.Equal(
+            (uint)(0xFFFF_FFFE),
+            Vector128.Create((uint)(0x0000_0001), 0xFFFF_FFFE, 0x0000_0003, 0xFFFF_FFFC).GetElement(1)
+        );
+
+        Assert.Equal(
+            (ulong)(0xFFFF_FFFF_FFFF_FFFE),
+            Vector128.Create((ulong)(0x0000_0000_0000_0001), 0xFFFF_FFFF_FFFF_FFFE).GetElement(1)
+        );
+
+        Assert.Equal(
+            (sbyte)(-2),
+            Vector128.Create((sbyte)(+1), -2, +3, -4, +5, -6, +7, -8, +9, -10, +11, -12, +13, -14, +15, -16).GetElement(1)
+        );
+
+        Assert.Equal(
+            (short)(-2),
+            Vector128.Create((short)(+1), -2, +3, -4, +5, -6, +7, -8).GetElement(1)
+        );
+
+        Assert.Equal(
+            (int)(-2),
+            Vector128.Create((int)(+1), -2, +3, -4).GetElement(1)
+        );
+
+        Assert.Equal(
+            (long)(-2),
+            Vector128.Create((long)(+1), -2).GetElement(1)
+        );
+
+        Assert.Equal(
+            (float)(-2),
+            Vector128.Create((float)(+1), -2, +3, -4).GetElement(1)
+        );
+
+        Assert.Equal(
+            (double)(-2),
+            Vector128.Create((double)(+1), -2).GetElement(1)
+        );
+    }
+}
diff --git a/src/tests/JIT/HardwareIntrinsics/General/ConstantFolding/SimdConstantFoldings.csproj b/src/tests/JIT/HardwareIntrinsics/General/ConstantFolding/SimdConstantFoldings.csproj
new file mode 100644 (file)
index 0000000..66d5848
--- /dev/null
@@ -0,0 +1,9 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <BuildAsStandalone>false</BuildAsStandalone>
+    <Optimize>True</Optimize>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+  </ItemGroup>
+</Project>