Optimize Vector<int>.Dot on AVX.
author sivarv <sivarv@microsoft.com>
Thu, 20 Oct 2016 21:46:00 +0000 (14:46 -0700)
committer sivarv <sivarv@microsoft.com>
Fri, 21 Oct 2016 23:38:35 +0000 (16:38 -0700)
Commit migrated from https://github.com/dotnet/coreclr/commit/1c8f29584f5ddebc4ea412283ca3811a0932ed22

src/coreclr/src/jit/emitxarch.cpp
src/coreclr/src/jit/instrsxarch.h
src/coreclr/src/jit/lowerxarch.cpp
src/coreclr/src/jit/simd.cpp
src/coreclr/src/jit/simdcodegenxarch.cpp
src/coreclr/src/jit/simdintrinsiclist.h
src/coreclr/tests/src/JIT/SIMD/VectorDot.cs

index b7c2ae9..626ba41 100644 (file)
@@ -76,9 +76,7 @@ bool emitter::IsThreeOperandBinaryAVXInstruction(instruction ins)
             ins == INS_paddw || ins == INS_paddd || ins == INS_paddq || ins == INS_psubb || ins == INS_psubw ||
             ins == INS_psubd || ins == INS_psubq || ins == INS_pmuludq || ins == INS_pxor || ins == INS_pmaxub ||
             ins == INS_pminub || ins == INS_pmaxsw || ins == INS_pminsw || ins == INS_insertps ||
-            ins == INS_vinsertf128 || ins == INS_punpckldq
-
-            );
+            ins == INS_vinsertf128 || ins == INS_punpckldq || ins == INS_phaddd);
 }
 
 // Returns true if the AVX instruction is a move operator that requires 3 operands.
@@ -105,7 +103,7 @@ bool Is4ByteAVXInstruction(instruction ins)
     return (ins == INS_dpps || ins == INS_dppd || ins == INS_insertps || ins == INS_pcmpeqq || ins == INS_pcmpgtq ||
             ins == INS_vbroadcastss || ins == INS_vbroadcastsd || ins == INS_vpbroadcastb || ins == INS_vpbroadcastw ||
             ins == INS_vpbroadcastd || ins == INS_vpbroadcastq || ins == INS_vextractf128 || ins == INS_vinsertf128 ||
-            ins == INS_pmulld || ins == INS_ptest);
+            ins == INS_pmulld || ins == INS_ptest || ins == INS_phaddd);
 #else
     return false;
 #endif
index 986bf9f..4b32cd4 100644 (file)
@@ -318,6 +318,7 @@ INST3( pcmpeqq,      "pcmpeqq"     , 0, IUM_WR, 0, 0, BAD_CODE,     BAD_CODE, SS
 INST3( pcmpgtq,      "pcmpgtq"     , 0, IUM_WR, 0, 0, BAD_CODE,     BAD_CODE, SSE38(0x37))   // Packed compare 64-bit integers for equality
 INST3( pmulld,       "pmulld"      , 0, IUM_WR, 0, 0, BAD_CODE,     BAD_CODE, SSE38(0x40))   // Packed multiply 32 bit unsigned integers and store lower 32 bits of each result
 INST3( ptest,        "ptest"       , 0, IUM_WR, 0, 0, BAD_CODE,     BAD_CODE, SSE38(0x17))   // Packed logical compare
+INST3( phaddd,       "phaddd"      , 0, IUM_WR, 0, 0, BAD_CODE,     BAD_CODE, SSE38(0x02))   // Packed horizontal add
 INST3(LAST_SSE4_INSTRUCTION, "LAST_SSE4_INSTRUCTION",  0, IUM_WR, 0, 0, BAD_CODE, BAD_CODE, BAD_CODE)
 
 INST3(FIRST_AVX_INSTRUCTION, "FIRST_AVX_INSTRUCTION",  0, IUM_WR, 0, 0, BAD_CODE, BAD_CODE, BAD_CODE)
index f0258a2..11a50cc 100644 (file)
@@ -2734,17 +2734,38 @@ void Lowering::TreeNodeInfoInitSIMD(GenTree* tree)
             break;
 
         case SIMDIntrinsicDotProduct:
-            if ((comp->getSIMDInstructionSet() == InstructionSet_SSE2) ||
-                (simdTree->gtOp.gtOp1->TypeGet() == TYP_SIMD32))
+            // Float/Double vectors:
+            // For SSE, or AVX with 32-byte vectors, we also need an internal register
+            // as scratch. Further we need the targetReg and internal reg to be distinct
+            // registers. Note that if this is a TYP_SIMD16 or smaller on AVX, then we
+            // don't need a tmpReg.
+            //
+            // 32-byte integer vector on AVX:
+            // will take advantage of phaddd, which operates only on 128-bit xmm reg.
+            // This would need 2 internal registers since targetReg is an int type
+            // register.
+            //
+            // See genSIMDIntrinsicDotProduct() for details on code sequence generated
+            // and the need for scratch registers.
+            if (varTypeIsFloating(simdTree->gtSIMDBaseType))
             {
-                // For SSE, or AVX with 32-byte vectors, we also need an internal register as scratch.
-                // Further we need the targetReg and internal reg to be distinct registers.
-                // Note that if this is a TYP_SIMD16 or smaller on AVX, then we don't need a tmpReg.
-                //
-                // See genSIMDIntrinsicDotProduct() for details on code sequence generated and
-                // the need for scratch registers.
-                info->internalFloatCount     = 1;
-                info->isInternalRegDelayFree = true;
+                if ((comp->getSIMDInstructionSet() == InstructionSet_SSE2) ||
+                    (simdTree->gtOp.gtOp1->TypeGet() == TYP_SIMD32))
+                {
+                    info->internalFloatCount     = 1;
+                    info->isInternalRegDelayFree = true;
+                    info->setInternalCandidates(lsra, lsra->allSIMDRegs());
+                }
+                // else don't need scratch reg(s).
+            }
+            else
+            {
+                assert(simdTree->gtSIMDBaseType == TYP_INT && comp->canUseAVX());
+
+                // No need to set isInternalRegDelayFree since targetReg is
+                // an int type reg and guaranteed to be different from xmm/ymm
+                // regs.
+                info->internalFloatCount = 2;
                 info->setInternalCandidates(lsra, lsra->allSIMDRegs());
             }
             info->srcCount = 2;
index a9450b7..739b206 100644 (file)
@@ -2367,15 +2367,14 @@ GenTreePtr Compiler::impSIMDIntrinsic(OPCODE                opcode,
 
         case SIMDIntrinsicDotProduct:
         {
-#if defined(_TARGET_AMD64_) && defined(DEBUG)
-            // Right now dot product is supported only for float vectors.
-            // See SIMDIntrinsicList.h for supported base types for this intrinsic.
-            if (!varTypeIsFloating(baseType))
+#ifdef _TARGET_AMD64_
+            // Right now dot product is supported only for float/double vectors and
+            // int vectors on AVX.
+            if (!varTypeIsFloating(baseType) && !(baseType == TYP_INT && canUseAVX()))
             {
-                assert(!"Dot product on integer type vectors not supported");
                 return nullptr;
             }
-#endif //_TARGET_AMD64_ && DEBUG
+#endif // _TARGET_AMD64_
 
             // op1 is a SIMD variable that is the first source and also "this" arg.
             // op2 is a SIMD variable which is the second source.
index dcded21..94e6d45 100644 (file)
@@ -1303,27 +1303,52 @@ void CodeGen::genSIMDIntrinsicDotProduct(GenTreeSIMD* simdNode)
     regNumber targetReg    = simdNode->gtRegNum;
     assert(targetReg != REG_NA);
 
-    // DotProduct is only supported on floating point types.
     var_types targetType = simdNode->TypeGet();
     assert(targetType == baseType);
-    assert(varTypeIsFloating(baseType));
 
     genConsumeOperands(simdNode);
-    regNumber op1Reg = op1->gtRegNum;
-    regNumber op2Reg = op2->gtRegNum;
+    regNumber op1Reg  = op1->gtRegNum;
+    regNumber op2Reg  = op2->gtRegNum;
+    regNumber tmpReg1 = REG_NA;
+    regNumber tmpReg2 = REG_NA;
 
-    regNumber tmpReg = REG_NA;
-    // For SSE, or AVX with 32-byte vectors, we need an additional Xmm register as scratch.
-    // However, it must be distinct from targetReg, so we request two from the register allocator.
-    // Note that if this is a TYP_SIMD16 or smaller on AVX, then we don't need a tmpReg.
-    if ((compiler->getSIMDInstructionSet() == InstructionSet_SSE2) || (simdEvalType == TYP_SIMD32))
+    // Dot product intrinsic is supported only on float/double vectors
+    // and 32-byte int vectors on AVX.
+    //
+    // Float/Double Vectors:
+    // For SSE, or AVX with 32-byte vectors, we need an additional Xmm register
+    // different from targetReg as scratch. Note that if this is a TYP_SIMD16 or
+    // smaller on AVX, then we don't need a tmpReg.
+    //
+    // 32-byte integer vector on AVX: we need two additional Xmm registers
+    // different from targetReg as scratch.
+    if (varTypeIsFloating(baseType))
     {
-        assert(simdNode->gtRsvdRegs != RBM_NONE);
-        assert(genCountBits(simdNode->gtRsvdRegs) == 1);
+        if ((compiler->getSIMDInstructionSet() == InstructionSet_SSE2) || (simdEvalType == TYP_SIMD32))
+        {
+            assert(simdNode->gtRsvdRegs != RBM_NONE);
+            assert(genCountBits(simdNode->gtRsvdRegs) == 1);
 
-        tmpReg = genRegNumFromMask(simdNode->gtRsvdRegs);
-        assert(tmpReg != REG_NA);
-        assert(tmpReg != targetReg);
+            tmpReg1 = genRegNumFromMask(simdNode->gtRsvdRegs);
+            assert(tmpReg1 != REG_NA);
+            assert(tmpReg1 != targetReg);
+        }
+        else
+        {
+            assert(simdNode->gtRsvdRegs == RBM_NONE);
+        }
+    }
+    else
+    {
+        assert(baseType == TYP_INT);
+        assert(compiler->getSIMDInstructionSet() == InstructionSet_AVX);
+
+        // Must have reserved 2 scratch registers.
+        assert(simdNode->gtRsvdRegs != RBM_NONE);
+        assert(genCountBits(simdNode->gtRsvdRegs) == 2);
+        regMaskTP tmpRegMask = genFindLowestBit(simdNode->gtRsvdRegs);
+        tmpReg1              = genRegNumFromMask(tmpRegMask);
+        tmpReg2              = genRegNumFromMask(simdNode->gtRsvdRegs & ~tmpRegMask);
     }
 
     if (compiler->getSIMDInstructionSet() == InstructionSet_SSE2)
@@ -1344,7 +1369,7 @@ void CodeGen::genSIMDIntrinsicDotProduct(GenTreeSIMD* simdNode)
         }
 
         // DotProduct(v1, v2)
-        // Here v0 = targetReg, v1 = op1Reg, v2 = op2Reg and tmp = tmpReg
+        // Here v0 = targetReg, v1 = op1Reg, v2 = op2Reg and tmp = tmpReg1
         if (baseType == TYP_FLOAT)
         {
             // v0 = v1 * v2
@@ -1360,80 +1385,115 @@ void CodeGen::genSIMDIntrinsicDotProduct(GenTreeSIMD* simdNode)
             //                                                // HADDPS.
             //
             inst_RV_RV(INS_mulps, targetReg, op2Reg);
-            inst_RV_RV(INS_movaps, tmpReg, targetReg);
-            inst_RV_RV_IV(INS_shufps, EA_16BYTE, tmpReg, tmpReg, 0xb1);
-            inst_RV_RV(INS_addps, targetReg, tmpReg);
-            inst_RV_RV(INS_movaps, tmpReg, targetReg);
-            inst_RV_RV_IV(INS_shufps, EA_16BYTE, tmpReg, tmpReg, 0x1b);
-            inst_RV_RV(INS_addps, targetReg, tmpReg);
+            inst_RV_RV(INS_movaps, tmpReg1, targetReg);
+            inst_RV_RV_IV(INS_shufps, EA_16BYTE, tmpReg1, tmpReg1, 0xb1);
+            inst_RV_RV(INS_addps, targetReg, tmpReg1);
+            inst_RV_RV(INS_movaps, tmpReg1, targetReg);
+            inst_RV_RV_IV(INS_shufps, EA_16BYTE, tmpReg1, tmpReg1, 0x1b);
+            inst_RV_RV(INS_addps, targetReg, tmpReg1);
         }
-        else if (baseType == TYP_DOUBLE)
+        else
         {
+            assert(baseType == TYP_DOUBLE);
+
             // v0 = v1 * v2
             // tmp = v0                                       // v0  = (1, 0) - each element is given by its position
             // tmp = shuffle(tmp, tmp, Shuffle(0,1))          // tmp = (0, 1)
             // v0 = v0 + tmp                                  // v0  = (1+0, 0+1)
             inst_RV_RV(INS_mulpd, targetReg, op2Reg);
-            inst_RV_RV(INS_movaps, tmpReg, targetReg);
-            inst_RV_RV_IV(INS_shufpd, EA_16BYTE, tmpReg, tmpReg, 0x01);
-            inst_RV_RV(INS_addpd, targetReg, tmpReg);
-        }
-        else
-        {
-            unreached();
+            inst_RV_RV(INS_movaps, tmpReg1, targetReg);
+            inst_RV_RV_IV(INS_shufpd, EA_16BYTE, tmpReg1, tmpReg1, 0x01);
+            inst_RV_RV(INS_addpd, targetReg, tmpReg1);
         }
     }
     else
     {
-        // We avoid reg move if either op1Reg == targetReg or op2Reg == targetReg.
-        // Note that this is a duplicate of the code above for SSE, but in the AVX case we can eventually
-        // use the 3-op form, so that we can avoid these copies.
-        // TODO-CQ: Add inst_RV_RV_RV_IV().
-        if (op1Reg == targetReg)
-        {
-            // Best case
-            // nothing to do, we have registers in the right place
-        }
-        else if (op2Reg == targetReg)
-        {
-            op2Reg = op1Reg;
-        }
-        else
-        {
-            inst_RV_RV(ins_Copy(simdType), targetReg, op1Reg, simdEvalType, emitActualTypeSize(simdType));
-        }
+        assert(compiler->getSIMDInstructionSet() == InstructionSet_AVX);
 
-        emitAttr emitSize = emitActualTypeSize(simdEvalType);
-        if (baseType == TYP_FLOAT)
+        if (varTypeIsFloating(baseType))
         {
-            // dpps computes the dot product of the upper & lower halves of the 32-byte register.
-            // Notice that if this is a TYP_SIMD16 or smaller on AVX, then we don't need a tmpReg.
-            inst_RV_RV_IV(INS_dpps, emitSize, targetReg, op2Reg, 0xf1);
-            // If this is TYP_SIMD32, we need to combine the lower & upper results.
-            if (simdEvalType == TYP_SIMD32)
+            // We avoid reg move if either op1Reg == targetReg or op2Reg == targetReg.
+            // Note that this is a duplicate of the code above for SSE, but in the AVX case we can eventually
+            // use the 3-op form, so that we can avoid these copies.
+            // TODO-CQ: Add inst_RV_RV_RV_IV().
+            if (op1Reg == targetReg)
             {
-                getEmitter()->emitIns_R_R_I(INS_vextractf128, EA_32BYTE, tmpReg, targetReg, 0x01);
-                inst_RV_RV(INS_addps, targetReg, tmpReg, targetType, emitTypeSize(targetType));
+                // Best case
+                // nothing to do, we have registers in the right place
+            }
+            else if (op2Reg == targetReg)
+            {
+                op2Reg = op1Reg;
+            }
+            else
+            {
+                inst_RV_RV(ins_Copy(simdType), targetReg, op1Reg, simdEvalType, emitActualTypeSize(simdType));
+            }
+
+            emitAttr emitSize = emitActualTypeSize(simdEvalType);
+            if (baseType == TYP_FLOAT)
+            {
+                // dpps computes the dot product of the upper & lower halves of the 32-byte register.
+                // Notice that if this is a TYP_SIMD16 or smaller on AVX, then we don't need a tmpReg.
+                inst_RV_RV_IV(INS_dpps, emitSize, targetReg, op2Reg, 0xf1);
+                // If this is TYP_SIMD32, we need to combine the lower & upper results.
+                if (simdEvalType == TYP_SIMD32)
+                {
+                    getEmitter()->emitIns_R_R_I(INS_vextractf128, EA_32BYTE, tmpReg1, targetReg, 0x01);
+                    inst_RV_RV(INS_addps, targetReg, tmpReg1, targetType, emitTypeSize(targetType));
+                }
+            }
+            else if (baseType == TYP_DOUBLE)
+            {
+                // On AVX, we have no 16-byte vectors of double.  Note that, if we did, we could use
+                // dppd directly.
+                assert(simdType == TYP_SIMD32);
+
+                // targetReg = targetReg * op2Reg
+                // targetReg = vhaddpd(targetReg, targetReg) ; horizontal sum of lower & upper halves
+                // tmpReg1   = vextractf128(targetReg, 1)    ; Moves the upper sum into tmpReg1
+                // targetReg = targetReg + tmpReg1
+                inst_RV_RV(INS_mulpd, targetReg, op2Reg, simdEvalType, emitActualTypeSize(simdType));
+                inst_RV_RV(INS_haddpd, targetReg, targetReg, simdEvalType, emitActualTypeSize(simdType));
+                getEmitter()->emitIns_R_R_I(INS_vextractf128, EA_32BYTE, tmpReg1, targetReg, 0x01);
+                inst_RV_RV(INS_addpd, targetReg, tmpReg1, targetType, emitTypeSize(targetType));
             }
-        }
-        else if (baseType == TYP_DOUBLE)
-        {
-            // On AVX, we have no 16-byte vectors of double.  Note that, if we did, we could use
-            // dppd directly.
-            assert(simdType == TYP_SIMD32);
-
-            // targetReg = targetReg * op2Reg
-            // targetReg = vhaddpd(targetReg, targetReg) ; horizontal sum of lower & upper halves
-            // tmpReg    = vextractf128(targetReg, 1)    ; Moves the upper sum into tempReg
-            // targetReg = targetReg + tmpReg
-            inst_RV_RV(INS_mulpd, targetReg, op2Reg, simdEvalType, emitActualTypeSize(simdType));
-            inst_RV_RV(INS_haddpd, targetReg, targetReg, simdEvalType, emitActualTypeSize(simdType));
-            getEmitter()->emitIns_R_R_I(INS_vextractf128, EA_32BYTE, tmpReg, targetReg, 0x01);
-            inst_RV_RV(INS_addpd, targetReg, tmpReg, targetType, emitTypeSize(targetType));
         }
         else
         {
-            unreached();
+            // Dot product of 32-byte int vector on AVX.
+            assert(baseType == TYP_INT);
+            assert(simdEvalType == TYP_SIMD32);
+
+            // We need 2 scratch registers.
+            assert(tmpReg1 != REG_NA);
+            assert(tmpReg2 != REG_NA);
+
+            // tmpReg1 = op1 * op2
+            inst_RV_RV_RV(INS_pmulld, tmpReg1, op1Reg, op2Reg, EA_32BYTE);
+
+            // tmpReg2[127..0] = Upper 128-bits of tmpReg1
+            getEmitter()->emitIns_R_R_I(INS_vextractf128, EA_32BYTE, tmpReg2, tmpReg1, 0x01);
+
+            // tmpReg1[127..0] = tmpReg1[127..0] + tmpReg2[127..0]
+            // This will compute
+            //    tmpReg1[0] = op1[0]*op2[0] + op1[4]*op2[4]
+            //    tmpReg1[1] = op1[1]*op2[1] + op1[5]*op2[5]
+            //    tmpReg1[2] = op1[2]*op2[2] + op1[6]*op2[6]
+            //    tmpReg1[3] = op1[3]*op2[3] + op1[7]*op2[7]
+            inst_RV_RV(INS_paddd, tmpReg1, tmpReg2, TYP_SIMD16, EA_16BYTE);
+
+            // This horizontal add would compute
+            //   tmpReg1[0] = tmpReg1[2] = op1[0]*op2[0] + op1[4]*op2[4] + op1[1]*op2[1] + op1[5]*op2[5]
+            //   tmpReg1[1] = tmpReg1[3] = op1[2]*op2[2] + op1[6]*op2[6] + op1[3]*op2[3] + op1[7]*op2[7]
+            inst_RV_RV(INS_phaddd, tmpReg1, tmpReg1, TYP_SIMD16, EA_16BYTE);
+
+            // DotProduct(op1, op2) = tmpReg1[0] = tmpReg1[0] + tmpReg1[1]
+            inst_RV_RV(INS_phaddd, tmpReg1, tmpReg1, TYP_SIMD16, EA_16BYTE);
+
+            // TargetReg = integer result from tmpReg1
+            // (Note that for mov_xmm2i, the int register is always in the reg2 position)
+            inst_RV_RV(INS_mov_xmm2i, tmpReg1, targetReg, TYP_INT);
         }
     }
 
index a44fb9d..53dee0d 100644 (file)
@@ -111,7 +111,8 @@ SIMD_INTRINSIC("op_BitwiseOr",              false,       BitwiseOr,
 SIMD_INTRINSIC("op_ExclusiveOr",            false,       BitwiseXor,               "^",                      TYP_STRUCT,     2,      {TYP_STRUCT, TYP_STRUCT, TYP_UNDEF},   {TYP_INT, TYP_FLOAT, TYP_DOUBLE, TYP_LONG, TYP_CHAR, TYP_UBYTE, TYP_BYTE, TYP_SHORT, TYP_UINT, TYP_ULONG})
 
 // Dot Product
-SIMD_INTRINSIC("Dot",                       false,       DotProduct,               "Dot",                    TYP_UNKNOWN,    2,      {TYP_STRUCT, TYP_STRUCT, TYP_UNDEF},   {TYP_FLOAT, TYP_DOUBLE, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF})
+// For integer element types, Dot is supported only for Vector<int> on AVX.
+SIMD_INTRINSIC("Dot",                       false,       DotProduct,               "Dot",                    TYP_UNKNOWN,    2,      {TYP_STRUCT, TYP_STRUCT, TYP_UNDEF},   {TYP_INT, TYP_FLOAT, TYP_DOUBLE, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF, TYP_UNDEF})
 
 // Select
 SIMD_INTRINSIC("ConditionalSelect",         false,       Select,                   "Select",                 TYP_STRUCT,     3,      {TYP_STRUCT, TYP_STRUCT, TYP_STRUCT},  {TYP_INT, TYP_FLOAT, TYP_DOUBLE, TYP_LONG, TYP_CHAR, TYP_UBYTE, TYP_BYTE, TYP_SHORT, TYP_UINT, TYP_ULONG})
index 2efe79e..22c1493 100644 (file)
@@ -113,12 +113,17 @@ internal partial class VectorTest
         if (VectorDotTest<ulong>.VectorDot(3ul, 2ul, 6ul * (ulong)Vector<ulong>.Count) != Pass) returnVal = Fail;
 
         JitLog jitLog = new JitLog();
-        // Dot is only recognized as an intrinsic for floating point element types.
+        // Dot is only recognized as an intrinsic for floating point element types
+        // and Vector<int> on AVX.
         if (!jitLog.Check("Dot", "Single")) returnVal = Fail;
         if (!jitLog.Check("Dot", "Double")) returnVal = Fail;
         if (!jitLog.Check("System.Numerics.Vector4:Dot")) returnVal = Fail;
         if (!jitLog.Check("System.Numerics.Vector3:Dot")) returnVal = Fail;
         if (!jitLog.Check("System.Numerics.Vector2:Dot")) returnVal = Fail;
+        if (Vector<int>.Count == 8)
+        {
+            if (!jitLog.Check("Dot", "Int32")) returnVal = Fail;
+        }
         jitLog.Dispose();
 
         return returnVal;