Ensure that NI_Vector_Dot is always handled as TYP_SIMD (#88447)
author Tanner Gooding <tagoo@outlook.com>
Mon, 10 Jul 2023 04:16:01 +0000 (21:16 -0700)
committer GitHub <noreply@github.com>
Mon, 10 Jul 2023 04:16:01 +0000 (21:16 -0700)
* Ensure that NI_Vector_Dot is always handled as TYP_SIMD

* Apply formatting patch

* Ensure simdType is passed in to gtNewSimdDotProdNode

* Apply suggestions from code review

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
---------

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
src/coreclr/jit/gentree.cpp
src/coreclr/jit/hwintrinsicarm64.cpp
src/coreclr/jit/hwintrinsicxarch.cpp
src/coreclr/jit/lowerarmarch.cpp
src/coreclr/jit/lowerxarch.cpp
src/coreclr/jit/morph.cpp
src/coreclr/jit/simdashwintrinsic.cpp

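In managed terms, the change means `Vector128.Dot` (and its Vector64/VectorT siblings) is no longer imported as a node that directly produces a scalar; it is always a `TYP_SIMD`-typed node, with the scalar read out of element 0 by a separate node. A minimal C# sketch of the two views (illustrative only; `DotAsImported` is a hypothetical name, not JIT output):

    using System.Runtime.Intrinsics;

    static float DotAsImported(Vector128<float> a, Vector128<float> b)
    {
        // NI_Vector128_Dot is now always TYP_SIMD16: conceptually it yields a
        // vector holding the dot product; Create(Dot(...)) here stands in for
        // that vector-typed node.
        Vector128<float> dotVec = Vector128.Create(Vector128.Dot(a, b));

        // The scalar result is an explicit GetElement(0)/ToScalar node that
        // later phases can remove when the value is only consumed as a vector.
        return dotVec.GetElement(0);
    }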
src/coreclr/jit/gentree.cpp
index ee9be97..66943d2 100644 (file)
@@ -22052,9 +22052,7 @@ GenTree* Compiler::gtNewSimdDotProdNode(
     assert(op2->TypeIs(simdType));
 
     var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
-
-    // We support the return type being a SIMD for floating-point as a special optimization
-    assert(varTypeIsArithmetic(type) || (varTypeIsSIMD(type) && varTypeIsFloating(simdBaseType)));
+    assert(varTypeIsSIMD(type));
 
     NamedIntrinsic intrinsic = NI_Illegal;
 
src/coreclr/jit/hwintrinsicarm64.cpp
index 370c346..f2a6fa7 100644 (file)
@@ -881,7 +881,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic        intrinsic,
                 op2 = impSIMDPopStack();
                 op1 = impSIMDPopStack();
 
-                retNode = gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
+                retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
+                retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
             }
             break;
         }
src/coreclr/jit/hwintrinsicxarch.cpp
index 270dffb..fc81d05 100644 (file)
@@ -1778,7 +1778,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic        intrinsic,
             op2 = impSIMDPopStack();
             op1 = impSIMDPopStack();
 
-            retNode = gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
+            retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
+            retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
             break;
         }
 
src/coreclr/jit/lowerarmarch.cpp
index 25d07c5..d4f2d8c 100644 (file)
@@ -1628,9 +1628,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
     assert(varTypeIsSIMD(simdType));
     assert(varTypeIsArithmetic(simdBaseType));
     assert(simdSize != 0);
-
-    // We support the return type being a SIMD for floating-point as a special optimization
-    assert(varTypeIsArithmetic(node) || (varTypeIsSIMD(node) && varTypeIsFloating(simdBaseType)));
+    assert(varTypeIsSIMD(node));
 
     GenTree* op1 = node->Op(1);
     GenTree* op2 = node->Op(2);
@@ -1647,7 +1645,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
 
         // For 12 byte SIMD, we need to clear the upper 4 bytes:
         //   idx  =    CNS_INT       int    0x03
-        //   tmp1 = *  CNS_DLB       float  0.0
+        //   tmp1 = *  CNS_DBL       float  0.0
         //          /--*  op1  simd16
         //          +--*  idx  int
         //          +--*  tmp1 simd16
@@ -1887,34 +1885,16 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
         }
     }
 
-    if (varTypeIsSIMD(node->gtType))
-    {
-        // We're producing a vector result, so just return the result directly
-
-        LIR::Use use;
-
-        if (BlockRange().TryGetUse(node, &use))
-        {
-            use.ReplaceWith(tmp2);
-        }
+    // We're producing a vector result, so just return the result directly
+    LIR::Use use;
 
-        BlockRange().Remove(node);
-        return tmp2->gtNext;
-    }
-    else
+    if (BlockRange().TryGetUse(node, &use))
     {
-        // We will be constructing the following parts:
-        //   ...
-        //          /--*  tmp2 simd16
-        //   node = *  HWINTRINSIC   simd16 T ToScalar
-
-        // This is roughly the following managed code:
-        //   ...
-        //   return tmp2.ToScalar();
-
-        node->ResetHWIntrinsicId((simdSize == 8) ? NI_Vector64_ToScalar : NI_Vector128_ToScalar, tmp2);
-        return LowerNode(node);
+        use.ReplaceWith(tmp2);
     }
+
+    BlockRange().Remove(node);
+    return tmp2->gtNext;
 }
 #endif // FEATURE_HW_INTRINSICS
 
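The 12-byte (Vector3-style) path above must zero the unused fourth lane before reducing. A rough managed equivalent of that fixup (a sketch assuming float elements; `Dot3` is a hypothetical helper):

    using System.Numerics;
    using System.Runtime.Intrinsics;

    static float Dot3(Vector3 a, Vector3 b)
    {
        // A 12-byte vector widened to simd16 leaves lane 3 undefined, so the
        // lowering above inserts 0.0f into element 3 (idx = 3) before the
        // multiply-and-reduce sequence.
        Vector128<float> va = Vector128.Create(a.X, a.Y, a.Z, 0.0f);
        Vector128<float> vb = Vector128.Create(b.X, b.Y, b.Z, 0.0f);
        return Vector128.Dot(va, vb);
    }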
src/coreclr/jit/lowerxarch.cpp
index 4cd240e..c5514ce 100644 (file)
@@ -4428,9 +4428,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
     assert(varTypeIsSIMD(simdType));
     assert(varTypeIsArithmetic(simdBaseType));
     assert(simdSize != 0);
-
-    // We support the return type being a SIMD for floating-point as a special optimization
-    assert(varTypeIsArithmetic(node) || (varTypeIsSIMD(node) && varTypeIsFloating(simdBaseType)));
+    assert(varTypeIsSIMD(node));
 
     GenTree* op1 = node->Op(1);
     GenTree* op2 = node->Op(2);
@@ -4479,15 +4477,12 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
                 //   tmp2 = *  HWINTRINSIC   simd16 T GetUpper
                 //          /--*  tmp1 simd16
                 //          +--*  tmp2 simd16
-                //   tmp3 = *  HWINTRINSIC   simd16 T Add
-                //          /--*  tmp3 simd16
-                //   node = *  HWINTRINSIC   simd16 T ToScalar
+                //   node = *  HWINTRINSIC   simd16 T Add
 
                 // This is roughly the following managed code:
                 //   var tmp1 = Avx.DotProduct(op1, op2, 0xFF);
                 //   var tmp2 = tmp1.GetUpper();
-                //   var tmp3 = Sse.Add(tmp1, tmp2);
-                //   return tmp3.ToScalar();
+                //   return Sse.Add(tmp1, tmp2);
 
                 idx = comp->gtNewIconNode(0xF1, TYP_INT);
                 BlockRange().InsertBefore(node, idx);
@@ -4515,13 +4510,17 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
 
                 tmp2 = comp->gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, tmp3, tmp1, simdBaseJitType, 16);
                 BlockRange().InsertAfter(tmp1, tmp2);
-                LowerNode(tmp2);
 
-                node->SetSimdSize(16);
+                // We're producing a vector result, so just return the result directly
+                LIR::Use use;
 
-                node->ResetHWIntrinsicId(NI_Vector128_ToScalar, tmp2);
+                if (BlockRange().TryGetUse(node, &use))
+                {
+                    use.ReplaceWith(tmp2);
+                }
 
-                return LowerNode(node);
+                BlockRange().Remove(node);
+                return LowerNode(tmp2);
             }
 
             case TYP_DOUBLE:
@@ -5043,34 +5042,16 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
         tmp1 = tmp2;
     }
 
-    if (varTypeIsSIMD(node->gtType))
-    {
-        // We're producing a vector result, so just return the result directly
-
-        LIR::Use use;
-
-        if (BlockRange().TryGetUse(node, &use))
-        {
-            use.ReplaceWith(tmp1);
-        }
+    // We're producing a vector result, so just return the result directly
+    LIR::Use use;
 
-        BlockRange().Remove(node);
-        return tmp1->gtNext;
-    }
-    else
+    if (BlockRange().TryGetUse(node, &use))
     {
-        // We will be constructing the following parts:
-        //   ...
-        //          /--*  tmp1 simd16
-        //   node = *  HWINTRINSIC   simd16 T ToScalar
-
-        // This is roughly the following managed code:
-        //   ...
-        //   return tmp1.ToScalar();
-
-        node->ResetHWIntrinsicId(NI_Vector128_ToScalar, tmp1);
-        return LowerNode(node);
+        use.ReplaceWith(tmp1);
     }
+
+    BlockRange().Remove(node);
+    return tmp1->gtNext;
 }
 
 //----------------------------------------------------------------------------------------------
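With both backends now producing a vector-typed result unconditionally, a scalar extraction only exists when the importer emitted one, and a vector consumer can use the dot product directly. A hedged sketch of the shape this enables (`ScaleByDot` is illustrative):

    using System.Runtime.Intrinsics;

    static Vector128<float> ScaleByDot(Vector128<float> v, Vector128<float> w)
    {
        // Create(Dot(...)) is one of the patterns morph recognizes (see the
        // fgOptimizeHWIntrinsic change below), so dotVec ends up being the raw
        // vector result with no scalar extract/re-broadcast pair in between.
        Vector128<float> dotVec = Vector128.Create(Vector128.Dot(v, w));
        return v * dotVec;
    }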
src/coreclr/jit/morph.cpp
index 0c90145..d2a2865 100644 (file)
@@ -10670,39 +10670,45 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
 #endif // TARGET_ARM64
         case NI_Vector128_Create:
         {
-            // The `Dot` API returns a scalar. However, many common usages require it to
-            // be then immediately broadcast back to a vector so that it can be used in
-            // a subsequent operation. One of the most common is normalizing a vector
+            // The managed `Dot` API returns a scalar. However, many common usages require
+            // it to then be immediately broadcast back to a vector so that it can be used
+            // in a subsequent operation. One of the most common is normalizing a vector
             // which is effectively `value / value.Length` where `Length` is
-            // `Sqrt(Dot(value, value))`
+            // `Sqrt(Dot(value, value))`. Because of this, and because of how a lot of
+            // hardware works, we treat `NI_Vector_Dot` as returning a SIMD type and then
+            // also wrap it in `ToScalar` where required.
             //
             // In order to ensure that developers can still utilize this efficiently, we
-            // will look for two common patterns:
+            // then look for four common patterns:
             //  * Create(Dot(..., ...))
             //  * Create(Sqrt(Dot(..., ...)))
+            //  * Create(ToScalar(Dot(..., ...)))
+            //  * Create(Sqrt(ToScalar(Dot(..., ...))))
             //
-            // When these exist, we'll avoid converting to a scalar at all and just
-            // keep everything as a vector. However, we only do this for Vector64/Vector128
-            // and only for float/double.
+            // When these exist, we'll avoid converting to a scalar, and hence avoid broadcasting
+            // the value back into a vector. Instead we'll just keep everything as a vector.
             //
-            // We don't do this for Vector256 since that is xarch only and doesn't trivially
-            // support operations which cross the upper and lower 128-bit lanes
+            // We only do this for Vector64/Vector128 today. We could expand this in
+            // the future, but that would require additional handling for Vector256
+            // (since a 256-bit result requires more work). We also handle integer base
+            // types when the value is trivially replicated to all elements without extra work.
 
             if (node->GetOperandCount() != 1)
             {
                 break;
             }
 
-            if (!varTypeIsFloating(node->GetSimdBaseType()))
-            {
-                break;
-            }
-
-            GenTree* op1  = node->Op(1);
-            GenTree* sqrt = nullptr;
+            GenTree* op1      = node->Op(1);
+            GenTree* sqrt     = nullptr;
+            GenTree* toScalar = nullptr;
 
             if (op1->OperIs(GT_INTRINSIC))
             {
+                if (!varTypeIsFloating(node->GetSimdBaseType()))
+                {
+                    break;
+                }
+
                 if (op1->AsIntrinsic()->gtIntrinsicName != NI_System_Math_Sqrt)
                 {
                     break;
@@ -10720,6 +10726,24 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
             GenTreeHWIntrinsic* hwop1 = op1->AsHWIntrinsic();
 
 #if defined(TARGET_ARM64)
+            if ((hwop1->GetHWIntrinsicId() == NI_Vector64_ToScalar) ||
+                (hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar))
+#else
+            if (hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar)
+#endif
+            {
+                op1 = hwop1->Op(1);
+
+                if (!op1->OperIs(GT_HWINTRINSIC))
+                {
+                    break;
+                }
+
+                toScalar = hwop1;
+                hwop1    = op1->AsHWIntrinsic();
+            }
+
+#if defined(TARGET_ARM64)
             if ((hwop1->GetHWIntrinsicId() != NI_Vector64_Dot) && (hwop1->GetHWIntrinsicId() != NI_Vector128_Dot))
 #else
             if (hwop1->GetHWIntrinsicId() != NI_Vector128_Dot)
@@ -10728,13 +10752,16 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
                 break;
             }
 
-            unsigned  simdSize = node->GetSimdSize();
-            var_types simdType = getSIMDTypeForSize(simdSize);
-
-            hwop1->gtType = simdType;
+            if (toScalar != nullptr)
+            {
+                DEBUG_DESTROY_NODE(toScalar);
+            }
 
             if (sqrt != nullptr)
             {
+                unsigned  simdSize = node->GetSimdSize();
+                var_types simdType = getSIMDTypeForSize(simdSize);
+
                 node = gtNewSimdSqrtNode(simdType, hwop1, node->GetSimdBaseJitType(), simdSize)->AsHWIntrinsic();
                 DEBUG_DESTROY_NODE(sqrt);
             }
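The canonical beneficiary is vector normalization, which matches the `Create(Sqrt(ToScalar(Dot(..., ...))))` shape. A sketch of source that now stays vectorized end to end (`Normalize` is illustrative):

    using System;
    using System.Runtime.Intrinsics;

    static Vector128<float> Normalize(Vector128<float> value)
    {
        // Dot imports as a vector-typed dot product plus a scalar extract;
        // the pattern match above removes both the extract and the Create
        // re-broadcast, and gtNewSimdSqrtNode keeps the Sqrt on the vector.
        float length = MathF.Sqrt(Vector128.Dot(value, value));
        return value / Vector128.Create(length);
    }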
src/coreclr/jit/simdashwintrinsic.cpp
index f0951c7..f7c69b3 100644 (file)
@@ -1065,8 +1065,7 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic       intrinsic,
                     vecCon->gtSimdVal.f32[3] = +1.0f;
 
                     GenTree* conjugate = gtNewSimdBinOpNode(GT_MUL, retType, op1, vecCon, simdBaseJitType, simdSize);
-
-                    op1 = gtNewSimdDotProdNode(retType, clonedOp1, clonedOp2, simdBaseJitType, simdSize);
+                    op1                = gtNewSimdDotProdNode(retType, clonedOp1, clonedOp2, simdBaseJitType, simdSize);
 
                     return gtNewSimdBinOpNode(GT_DIV, retType, conjugate, op1, simdBaseJitType, simdSize);
                 }
@@ -1095,7 +1094,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic       intrinsic,
                     op1 = impCloneExpr(op1, &clonedOp1, CHECK_SPILL_ALL,
                                        nullptr DEBUGARG("Clone op1 for vector length squared"));
 
-                    return gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
+                    op1 = gtNewSimdDotProdNode(simdType, op1, clonedOp1, simdBaseJitType, simdSize);
+                    return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
                 }
 
                 case NI_VectorT128_Load:
@@ -1174,7 +1174,6 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic       intrinsic,
                                              nullptr DEBUGARG("Clone op1 for vector normalize (2)"));
 
                     op1 = gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
-
                     op1 = gtNewSimdSqrtNode(retType, op1, simdBaseJitType, simdSize);
 
                     return gtNewSimdBinOpNode(GT_DIV, retType, clonedOp2, op1, simdBaseJitType, simdSize);
@@ -1462,7 +1461,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic       intrinsic,
                     op1 = impCloneExpr(op1, &clonedOp1, CHECK_SPILL_ALL,
                                        nullptr DEBUGARG("Clone diff for vector distance squared"));
 
-                    return gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
+                    op1 = gtNewSimdDotProdNode(simdType, op1, clonedOp1, simdBaseJitType, simdSize);
+                    return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
                 }
 
                 case NI_Quaternion_Divide:
@@ -1492,7 +1492,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic       intrinsic,
                 case NI_VectorT256_Dot:
 #endif // TARGET_XARCH
                 {
-                    return gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
+                    op1 = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
+                    return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
                 }
 
                 case NI_VectorT128_Equals:
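The System.Numerics paths above all follow the same recipe: build the dot product as `simdType`, then extract element 0 for the scalar-returning APIs. A sketch of the user-facing calls these cases import (`Example` is illustrative):

    using System.Numerics;

    static float Example(Vector3 v, Vector3 w)
    {
        // Each call imports as gtNewSimdDotProdNode(simdType, ...) followed by
        // gtNewSimdGetElementNode(retType, ..., 0, ...); the extract folds away
        // when the result immediately feeds a vector operation again.
        float dot    = Vector3.Dot(v, w);
        float lenSq  = v.LengthSquared();
        float distSq = Vector3.DistanceSquared(v, w);
        return dot + lenSq + distSq;
    }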