Fold const WithElement to CNS_VEC (#86212)
authorJasper <jasper-d@users.noreply.github.com>
Tue, 13 Jun 2023 11:53:59 +0000 (13:53 +0200)
committerGitHub <noreply@github.com>
Tue, 13 Jun 2023 11:53:59 +0000 (13:53 +0200)
Co-authored-by: Egor Bogatov <egorbo@gmail.com>
src/coreclr/jit/valuenum.cpp
src/coreclr/jit/valuenum.h

index da623ad9ae63f1bb2331e66a228f2786a9ba5796..9951fc1625dc506b81b0be940631534dce5ed975 100644 (file)
@@ -2648,7 +2648,8 @@ ValueNum ValueNumStore::VNForFunc(var_types typ, VNFunc func, ValueNum arg0VN, V
     assert(arg0VN != NoVN);
     assert(arg1VN != NoVN);
     assert(arg2VN != NoVN);
-    assert(VNFuncArity(func) == 3);
+    // Some SIMD functions with variable number of arguments are defined with zero arity
+    assert((VNFuncArity(func) == 0) || (VNFuncArity(func) == 3));
 
 #ifdef DEBUG
     // Function arguments carry no exceptions.
@@ -2664,7 +2665,6 @@ ValueNum ValueNumStore::VNForFunc(var_types typ, VNFunc func, ValueNum arg0VN, V
     }
     assert(arg2VN == VNNormalValue(arg2VN));
 #endif
-    assert(VNFuncArity(func) == 3);
 
     ValueNum resultVN;
 
@@ -7813,7 +7813,109 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(var_types      type,
     }
     return VNForFunc(type, func, arg0VN, arg1VN);
 }
+
+ValueNum EvaluateSimdFloatWithElement(ValueNumStore* vns, var_types type, ValueNum arg0VN, int index, float value)
+{
+    assert(vns->IsVNConstant(arg0VN));
+    assert(static_cast<unsigned>(index) < genTypeSize(type) / genTypeSize(TYP_FLOAT));
+
+    switch (type)
+    {
+        case TYP_SIMD8:
+        {
+            simd8_t cnsVec    = vns->GetConstantSimd8(arg0VN);
+            cnsVec.f32[index] = value;
+            return vns->VNForSimd8Con(cnsVec);
+        }
+        case TYP_SIMD12:
+        {
+            simd12_t cnsVec   = vns->GetConstantSimd12(arg0VN);
+            cnsVec.f32[index] = value;
+            return vns->VNForSimd12Con(cnsVec);
+        }
+        case TYP_SIMD16:
+        {
+            simd16_t cnsVec   = vns->GetConstantSimd16(arg0VN);
+            cnsVec.f32[index] = value;
+            return vns->VNForSimd16Con(cnsVec);
+        }
+#if defined TARGET_XARCH
+        case TYP_SIMD32:
+        {
+            simd32_t cnsVec   = vns->GetConstantSimd32(arg0VN);
+            cnsVec.f32[index] = value;
+            return vns->VNForSimd32Con(cnsVec);
+        }
+        case TYP_SIMD64:
+        {
+            simd64_t cnsVec   = vns->GetConstantSimd64(arg0VN);
+            cnsVec.f32[index] = value;
+            return vns->VNForSimd64Con(cnsVec);
+        }
+#endif // TARGET_XARCH
+        default:
+        {
+            unreached();
+        }
+    }
+}
+
+ValueNum ValueNumStore::EvalHWIntrinsicFunTernary(var_types      type,
+                                                  var_types      baseType,
+                                                  NamedIntrinsic ni,
+                                                  VNFunc         func,
+                                                  ValueNum       arg0VN,
+                                                  ValueNum       arg1VN,
+                                                  ValueNum       arg2VN,
+                                                  bool           encodeResultType,
+                                                  ValueNum       resultTypeVN)
+{
+    if (IsVNConstant(arg0VN) && IsVNConstant(arg1VN) && IsVNConstant(arg2VN))
+    {
+
+        switch (ni)
+        {
+            case NI_Vector128_WithElement:
+#ifdef TARGET_ARM64
+            case NI_Vector64_WithElement:
+#else
+            case NI_Vector256_WithElement:
+            case NI_Vector512_WithElement:
 #endif
+            {
+                int index = GetConstantInt32(arg1VN);
+
+                assert(varTypeIsSIMD(type));
+
+                // No meaningful diffs for other base-types.
+                if ((baseType != TYP_FLOAT) || (TypeOfVN(arg0VN) != type) ||
+                    (static_cast<unsigned>(index) >= (genTypeSize(type) / genTypeSize(baseType))))
+                {
+                    break;
+                }
+
+                float value = GetConstantSingle(arg2VN);
+
+                return EvaluateSimdFloatWithElement(this, type, arg0VN, index, value);
+            }
+            default:
+            {
+                break;
+            }
+        }
+    }
+
+    if (encodeResultType)
+    {
+        return VNForFunc(type, func, arg0VN, arg1VN, arg2VN, resultTypeVN);
+    }
+    else
+    {
+        return VNForFunc(type, func, arg0VN, arg1VN, arg2VN);
+    }
+}
+
+#endif // FEATURE_HW_INTRINSICS
 
 ValueNum ValueNumStore::EvalMathFuncUnary(var_types typ, NamedIntrinsic gtMathFN, ValueNum arg0VN)
 {
@@ -11475,9 +11577,11 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
     ValueNumPair excSetPair = ValueNumStore::VNPForEmptyExcSet();
     ValueNumPair normalPair = ValueNumPair();
 
-    if ((tree->GetOperandCount() > 2) || ((JitConfig.JitDisableSimdVN() & 2) == 2))
+    const size_t opCount = tree->GetOperandCount();
+
+    if ((opCount > 3) || (JitConfig.JitDisableSimdVN() & 2) == 2)
     {
-        // TODO-CQ: allow intrinsics with > 2 operands to be properly VN'ed.
+        // TODO-CQ: allow intrinsics with > 3 operands to be properly VN'ed.
         normalPair = vnStore->VNPairForExpr(compCurBB, tree->TypeGet());
 
         for (GenTree* operand : tree->Operands())
@@ -11525,7 +11629,7 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
         const bool isVariableNumArgs = HWIntrinsicInfo::lookupNumArgs(intrinsicId) == -1;
 
         // There are some HWINTRINSICS operations that have zero args, i.e.  NI_Vector128_Zero
-        if (tree->GetOperandCount() == 0)
+        if (opCount == 0)
         {
             // Currently we don't have intrinsics with variable number of args with a parameter-less option.
             assert(!isVariableNumArgs);
@@ -11542,13 +11646,13 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
                 assert(vnStore->VNFuncArity(func) == 0);
             }
         }
-        else // HWINTRINSIC unary or binary operator.
+        else // HWINTRINSIC unary or binary or ternary operator.
         {
             ValueNumPair op1vnp;
             ValueNumPair op1Xvnp;
             getOperandVNs(tree->Op(1), &op1vnp, &op1Xvnp);
 
-            if (tree->GetOperandCount() == 1)
+            if (opCount == 1)
             {
                 ValueNum normalLVN = vnStore->EvalHWIntrinsicFunUnary(tree->TypeGet(), tree->GetSimdBaseType(),
                                                                       intrinsicId, func, op1vnp.GetLiberal(),
@@ -11567,17 +11671,44 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
                 ValueNumPair op2Xvnp;
                 getOperandVNs(tree->Op(2), &op2vnp, &op2Xvnp);
 
-                ValueNum normalLVN =
-                    vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
-                                                      op1vnp.GetLiberal(), op2vnp.GetLiberal(), encodeResultType,
-                                                      resultTypeVNPair.GetLiberal());
-                ValueNum normalCVN =
-                    vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
-                                                      op1vnp.GetConservative(), op2vnp.GetConservative(),
-                                                      encodeResultType, resultTypeVNPair.GetConservative());
+                if (opCount == 2)
+                {
+                    ValueNum normalLVN =
+                        vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                          op1vnp.GetLiberal(), op2vnp.GetLiberal(), encodeResultType,
+                                                          resultTypeVNPair.GetLiberal());
+                    ValueNum normalCVN =
+                        vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                          op1vnp.GetConservative(), op2vnp.GetConservative(),
+                                                          encodeResultType, resultTypeVNPair.GetConservative());
 
-                normalPair = ValueNumPair(normalLVN, normalCVN);
-                excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
+                    normalPair = ValueNumPair(normalLVN, normalCVN);
+                    excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
+                }
+                else
+                {
+                    assert(opCount == 3);
+
+                    ValueNumPair op3vnp;
+                    ValueNumPair op3Xvnp;
+                    getOperandVNs(tree->Op(3), &op3vnp, &op3Xvnp);
+
+                    ValueNum normalLVN =
+                        vnStore->EvalHWIntrinsicFunTernary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                           op1vnp.GetLiberal(), op2vnp.GetLiberal(),
+                                                           op3vnp.GetLiberal(), encodeResultType,
+                                                           resultTypeVNPair.GetLiberal());
+                    ValueNum normalCVN =
+                        vnStore->EvalHWIntrinsicFunTernary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
+                                                           op1vnp.GetConservative(), op2vnp.GetConservative(),
+                                                           op3vnp.GetConservative(), encodeResultType,
+                                                           resultTypeVNPair.GetConservative());
+
+                    normalPair = ValueNumPair(normalLVN, normalCVN);
+
+                    excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
+                    excSetPair = vnStore->VNPExcSetUnion(excSetPair, op3Xvnp);
+                }
             }
         }
     }
index e63f4dad58fd4bd61677f20cd4be9b1c9c202315..cb7be434314180132b50f65926d7e7c08c7ef92d 100644 (file)
@@ -1184,6 +1184,16 @@ public:
                                       bool           encodeResultType,
                                       ValueNum       resultTypeVN);
 
+    ValueNum EvalHWIntrinsicFunTernary(var_types      type,
+                                       var_types      baseType,
+                                       NamedIntrinsic ni,
+                                       VNFunc         func,
+                                       ValueNum       arg0VN,
+                                       ValueNum       arg1VN,
+                                       ValueNum       arg2VN,
+                                       bool           encodeResultType,
+                                       ValueNum       resultTypeVN);
+
     // Returns "true" iff "vn" represents a function application.
     bool IsVNFunc(ValueNum vn);