Optimize FMA codegen base on the overwritten (#58196)
authorweilinwa <weilin.wang@intel.com>
Wed, 1 Dec 2021 14:47:13 +0000 (06:47 -0800)
committerGitHub <noreply@github.com>
Wed, 1 Dec 2021 14:47:13 +0000 (06:47 -0800)
* Optimize FMA codegen base on the overwritten

* Improve function/var names

* Add assertions

* Get use of FMA with TryGetUse

* Decide FMA form with two conditions, OverwrittenOpNum and isContained

* Fix op reg error in codegen

* Decide form using lastUse and isContained in no overwritten case

* Clean up code

* Separate default case overwrittenOpNum==0

* Apply format patch

* Change variable and function names

* Update regOptional for op1 and resolve some other comments

* Optimize FMA codegen base on the overwritten

* Improve function/var names

* Add assertions

* Get use of FMA with TryGetUse

* Decide FMA form with two conditions, OverwrittenOpNum and isContained

* Fix op reg error in codegen

* Decide form using lastUse and isContained in no overwritten case

* Clean up code

* Separate default case overwrittenOpNum==0

* Apply format patch

* Change variable and function names

* Update regOptional for op1 and resolve some other comments

* Change var names

* Fix jit format

* Fix build node error for op1 is regOptional

* Use targetReg instead of GetResultOpNumForFMA in codegen

* Update variable names

* Refactor lsra to solve lastUse status changed caused assertion failure

* Add check to prioritize contained op in lsra

* Update for jit format

* Simplify code

* Resolve comments

* Comment out assert because of lastUse change

* Fix some copiesUpperBits related errors

* Update src/coreclr/jit/lsraxarch.cpp

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
* Add link to the new issue

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
src/coreclr/jit/gentree.cpp
src/coreclr/jit/gentree.h
src/coreclr/jit/hwintrinsiccodegenxarch.cpp
src/coreclr/jit/lowerxarch.cpp
src/coreclr/jit/lsraxarch.cpp

index 09e596a..00c58d7 100644 (file)
@@ -21898,6 +21898,53 @@ uint16_t GenTreeLclVarCommon::GetLclOffs() const
     }
 }
 
+#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
+//------------------------------------------------------------------------
+// GetResultOpNumForFMA: check if the result is written into one of the operands.
+// In the case that none of the operand is overwritten, check if any of them is lastUse.
+//
+// Return Value:
+//     The operand number overwritten or lastUse. 0 is the default value, where the result is written into
+//      a destination that is not one of the source operands and there is no last use op.
+//
+unsigned GenTreeHWIntrinsic::GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
+{
+    // only FMA intrinsic node should call into this function
+    assert(HWIntrinsicInfo::lookupIsa(gtHWIntrinsicId) == InstructionSet_FMA);
+    if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR))
+    {
+        // For store_lcl_var, check if any op is overwritten
+
+        GenTreeLclVarCommon* overwritten       = use->AsLclVarCommon();
+        unsigned             overwrittenLclNum = overwritten->GetLclNum();
+        if (op1->IsLocal() && op1->AsLclVarCommon()->GetLclNum() == overwrittenLclNum)
+        {
+            return 1;
+        }
+        else if (op2->IsLocal() && op2->AsLclVarCommon()->GetLclNum() == overwrittenLclNum)
+        {
+            return 2;
+        }
+        else if (op3->IsLocal() && op3->AsLclVarCommon()->GetLclNum() == overwrittenLclNum)
+        {
+            return 3;
+        }
+    }
+
+    // If no overwritten op, check if there is any last use op
+    // https://github.com/dotnet/runtime/issues/62215
+
+    if (op1->OperIs(GT_LCL_VAR) && op1->IsLastUse(0))
+        return 1;
+    else if (op2->OperIs(GT_LCL_VAR) && op2->IsLastUse(0))
+        return 2;
+    else if (op3->OperIs(GT_LCL_VAR) && op3->IsLastUse(0))
+        return 3;
+
+    return 0;
+}
+#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS
+
 #ifdef TARGET_ARM
 //------------------------------------------------------------------------
 // IsOffsetMisaligned: check if the field needs a special handling on arm.
index 553ed29..1a08650 100644 (file)
@@ -5526,6 +5526,7 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
     {
         return (gtFlags & GTF_SIMDASHW_OP) != 0;
     }
+    unsigned GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3);
 
     NamedIntrinsic GetHWIntrinsicId() const;
 
index 2e6f018..bb6a6da 100644 (file)
@@ -2034,67 +2034,82 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
     NamedIntrinsic intrinsicId = node->GetHWIntrinsicId();
     var_types      baseType    = node->GetSimdBaseType();
     emitAttr       attr        = emitActualTypeSize(Compiler::getSIMDTypeForSize(node->GetSimdSize()));
-    instruction    ins         = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
+    instruction    ins         = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); // 213 form
+    instruction    _132form    = (instruction)(ins - 1);
+    instruction    _231form    = (instruction)(ins + 1);
     GenTree*       op1         = node->Op(1);
     GenTree*       op2         = node->Op(2);
     GenTree*       op3         = node->Op(3);
-    regNumber      targetReg   = node->GetRegNum();
+
+    regNumber targetReg = node->GetRegNum();
 
     genConsumeMultiOpOperands(node);
 
-    regNumber op1Reg;
-    regNumber op2Reg;
+    regNumber op1NodeReg = op1->GetRegNum();
+    regNumber op2NodeReg = op2->GetRegNum();
+    regNumber op3NodeReg = op3->GetRegNum();
+
+    GenTree* emitOp1 = op1;
+    GenTree* emitOp2 = op2;
+    GenTree* emitOp3 = op3;
 
-    bool       isCommutative   = false;
     const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);
 
     // Intrinsics with CopyUpperBits semantics cannot have op1 be contained
     assert(!copiesUpperBits || !op1->isContained());
 
-    if (op2->isContained() || op2->isUsedFromSpillTemp())
+    if (op1->isContained() || op1->isUsedFromSpillTemp())
     {
-        // 132 form: op1 = (op1 * op3) + [op2]
-
-        ins    = (instruction)(ins - 1);
-        op1Reg = op1->GetRegNum();
-        op2Reg = op3->GetRegNum();
-        op3    = op2;
+        if (targetReg == op2NodeReg)
+        {
+            std::swap(emitOp1, emitOp2);
+            // op2 = ([op1] * op2) + op3
+            // 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
+            ins = _132form;
+            std::swap(emitOp2, emitOp3);
+        }
+        else
+        {
+            // targetReg == op3NodeReg or targetReg == ?
+            // op3 = ([op1] * op2) + op3
+            // 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
+            ins = _231form;
+            std::swap(emitOp1, emitOp3);
+        }
     }
-    else if (op1->isContained() || op1->isUsedFromSpillTemp())
+    else if (op2->isContained() || op2->isUsedFromSpillTemp())
     {
-        // 231 form: op3 = (op2 * op3) + [op1]
-
-        ins    = (instruction)(ins + 1);
-        op1Reg = op3->GetRegNum();
-        op2Reg = op2->GetRegNum();
-        op3    = op1;
+        if (!copiesUpperBits && (targetReg == op3NodeReg))
+        {
+            // op3 = (op1 * [op2]) + op3
+            // 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
+            ins = _231form;
+            std::swap(emitOp1, emitOp3);
+        }
+        else
+        {
+            // targetReg == op1NodeReg or targetReg == ?
+            // op1 = (op1 * [op2]) + op3
+            // 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
+            ins = _132form;
+        }
+        std::swap(emitOp2, emitOp3);
     }
     else
     {
-        // 213 form: op1 = (op2 * op1) + [op3]
-
-        op1Reg = op1->GetRegNum();
-        op2Reg = op2->GetRegNum();
-
-        isCommutative = !copiesUpperBits;
-    }
-
-    if (isCommutative && (op1Reg != targetReg) && (op2Reg == targetReg))
-    {
-        assert(node->isRMWHWIntrinsic(compiler));
-
-        // We have "reg2 = (reg1 * reg2) +/- op3" where "reg1 != reg2" on a RMW intrinsic.
-        //
-        // For non-commutative intrinsics, we should have ensured that op2 was marked
-        // delay free in order to prevent it from getting assigned the same register
-        // as target. However, for commutative intrinsics, we can just swap the operands
-        // in order to have "reg2 = reg2 op reg1" which will end up producing the right code.
-
-        op2Reg = op1Reg;
-        op1Reg = targetReg;
+        // targetReg could be op1NodeReg, op2NodeReg, or not equal to any op
+        // op1 = (op1 * op2) + [op3] or op2 = (op1 * op2) + [op3]
+        // ? = (op1 * op2) + [op3] or ? = (op1 * op2) + op3
+        // 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
+        if (!copiesUpperBits && (targetReg == op2NodeReg))
+        {
+            // op2 = (op1 * op2) + [op3]
+            // 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
+            std::swap(emitOp1, emitOp2);
+        }
     }
 
-    genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3);
+    genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, emitOp1->GetRegNum(), emitOp2->GetRegNum(), emitOp3);
     genProduceReg(node);
 }
 
index c1b55e9..548a603 100644 (file)
@@ -6000,40 +6000,53 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
                 {
                     if ((intrinsicId >= NI_FMA_MultiplyAdd) && (intrinsicId <= NI_FMA_MultiplySubtractNegatedScalar))
                     {
-                        bool supportsRegOptional = false;
+                        bool     supportsOp1RegOptional = false;
+                        bool     supportsOp2RegOptional = false;
+                        bool     supportsOp3RegOptional = false;
+                        unsigned resultOpNum            = 0;
+                        LIR::Use use;
+                        GenTree* user = nullptr;
+
+                        if (BlockRange().TryGetUse(node, &use))
+                        {
+                            user = use.User();
+                        }
+                        resultOpNum = node->GetResultOpNumForFMA(user, op1, op2, op3);
+
+                        // Prioritize Containable op. Check if any one of the op is containable first.
+                        // Set op regOptional only if none of them is containable.
 
-                        if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional))
+                        // Prefer to make op3 contained,
+                        if (resultOpNum != 3 && IsContainableHWIntrinsicOp(node, op3, &supportsOp3RegOptional))
                         {
-                            // 213 form: op1 = (op2 * op1) + [op3]
+                            // result = (op1 * op2) + [op3]
                             MakeSrcContained(node, op3);
                         }
-                        else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional))
+                        else if (resultOpNum != 2 && IsContainableHWIntrinsicOp(node, op2, &supportsOp2RegOptional))
                         {
-                            // 132 form: op1 = (op1 * op3) + [op2]
+                            // result = (op1 * [op2]) + op3
                             MakeSrcContained(node, op2);
                         }
-                        else if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional))
+                        else if (resultOpNum != 1 && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId) &&
+                                 IsContainableHWIntrinsicOp(node, op1, &supportsOp1RegOptional))
                         {
-                            // Intrinsics with CopyUpperBits semantics cannot have op1 be contained
-
-                            if (!HWIntrinsicInfo::CopiesUpperBits(intrinsicId))
-                            {
-                                // 231 form: op3 = (op2 * op3) + [op1]
-                                MakeSrcContained(node, op1);
-                            }
+                            // result = ([op1] * op2) + op3
+                            MakeSrcContained(node, op1);
                         }
-                        else
+                        else if (supportsOp3RegOptional)
                         {
-                            assert(supportsRegOptional);
-
-                            // TODO-XArch-CQ: Technically any one of the three operands can
-                            //                be reg-optional. With a limitation on op1 where
-                            //                it can only be so if CopyUpperBits is off.
-                            //                https://github.com/dotnet/runtime/issues/6358
-
-                            // 213 form: op1 = (op2 * op1) + op3
+                            assert(resultOpNum != 3);
                             op3->SetRegOptional();
                         }
+                        else if (supportsOp2RegOptional)
+                        {
+                            assert(resultOpNum != 2);
+                            op2->SetRegOptional();
+                        }
+                        else if (supportsOp1RegOptional)
+                        {
+                            op1->SetRegOptional();
+                        }
                     }
                     else
                     {
index 2926f54..ca16960 100644 (file)
@@ -2272,48 +2272,93 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree)
 
                 const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);
 
-                // Intrinsics with CopyUpperBits semantics cannot have op1 be contained
-                assert(!copiesUpperBits || !op1->isContained());
+                unsigned resultOpNum = 0;
+                LIR::Use use;
+                GenTree* user = nullptr;
 
-                if (op2->isContained())
+                if (LIR::AsRange(blockSequence[curBBSeqNum]).TryGetUse(intrinsicTree, &use))
                 {
-                    // 132 form: op1 = (op1 * op3) + [op2]
+                    user = use.User();
+                }
+                resultOpNum = intrinsicTree->GetResultOpNumForFMA(user, op1, op2, op3);
 
-                    tgtPrefUse = BuildUse(op1);
+                unsigned containedOpNum = 0;
 
-                    srcCount += 1;
-                    srcCount += BuildOperandUses(op2);
-                    srcCount += BuildDelayFreeUses(op3, op1);
+                // containedOpNum remains 0 when no operand is contained or regOptional
+                if (op1->isContained() || op1->IsRegOptional())
+                {
+                    containedOpNum = 1;
                 }
-                else if (op1->isContained())
+                else if (op2->isContained() || op2->IsRegOptional())
                 {
-                    // 231 form: op3 = (op2 * op3) + [op1]
-
-                    tgtPrefUse = BuildUse(op3);
-
-                    srcCount += BuildOperandUses(op1);
-                    srcCount += BuildDelayFreeUses(op2, op1);
-                    srcCount += 1;
+                    containedOpNum = 2;
                 }
-                else
+                else if (op3->isContained() || op3->IsRegOptional())
                 {
-                    // 213 form: op1 = (op2 * op1) + [op3]
+                    containedOpNum = 3;
+                }
 
-                    tgtPrefUse = BuildUse(op1);
-                    srcCount += 1;
+                GenTree* emitOp1 = op1;
+                GenTree* emitOp2 = op2;
+                GenTree* emitOp3 = op3;
 
-                    if (copiesUpperBits)
+                // Intrinsics with CopyUpperBits semantics must have op1 as target
+                assert(containedOpNum != 1 || !copiesUpperBits);
+
+                if (containedOpNum == 1)
+                {
+                    // https://github.com/dotnet/runtime/issues/62215
+                    // resultOpNum might change between lowering and lsra, comment out assertion for now.
+                    // assert(containedOpNum != resultOpNum);
+                    // resultOpNum is 3 or 0: op3/? = ([op1] * op2) + op3
+                    std::swap(emitOp1, emitOp3);
+
+                    if (resultOpNum == 2)
                     {
-                        srcCount += BuildDelayFreeUses(op2, op1);
+                        // op2 = ([op1] * op2) + op3
+                        std::swap(emitOp2, emitOp3);
                     }
-                    else
+                }
+                else if (containedOpNum == 3)
+                {
+                    // assert(containedOpNum != resultOpNum);
+                    if (resultOpNum == 2 && !copiesUpperBits)
                     {
-                        tgtPrefUse2 = BuildUse(op2);
-                        srcCount += 1;
+                        // op2 = (op1 * op2) + [op3]
+                        std::swap(emitOp1, emitOp2);
                     }
+                    // else: op1/? = (op1 * op2) + [op3]
+                }
+                else if (containedOpNum == 2)
+                {
+                    // assert(containedOpNum != resultOpNum);
 
-                    srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1);
+                    // op1/? = (op1 * [op2]) + op3
+                    std::swap(emitOp2, emitOp3);
+                    if (resultOpNum == 3 && !copiesUpperBits)
+                    {
+                        // op3 = (op1 * [op2]) + op3
+                        std::swap(emitOp1, emitOp2);
+                    }
                 }
+                else
+                {
+                    // containedOpNum == 0
+                    // no extra work when resultOpNum is 0 or 1
+                    if (resultOpNum == 2)
+                    {
+                        std::swap(emitOp1, emitOp2);
+                    }
+                    else if (resultOpNum == 3)
+                    {
+                        std::swap(emitOp1, emitOp3);
+                    }
+                }
+                tgtPrefUse = BuildUse(emitOp1);
+
+                srcCount += 1;
+                srcCount += BuildDelayFreeUses(emitOp2, emitOp1);
+                srcCount += emitOp3->isContained() ? BuildOperandUses(emitOp3) : BuildDelayFreeUses(emitOp3, emitOp1);
 
                 buildUses = false;
                 break;