[CostModel][X86] getArithmeticInstrCost - move SLM reduceVMULWidth cost handling...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 26 Aug 2022 14:09:50 +0000 (15:09 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 26 Aug 2022 15:14:12 +0000 (16:14 +0100)
This is still SLM specific atm, but converting this to more closely match the codegen from reduceVMULWidth should be straightforward

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

index 4b68fd0..90b563c 100644 (file)
@@ -254,6 +254,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     unsigned Op1MinSize = BaseT::minRequiredElementSize(Args[0], Op1Signed);
     unsigned Op2MinSize = BaseT::minRequiredElementSize(Args[1], Op2Signed);
     unsigned OpMinSize = std::max(Op1MinSize, Op2MinSize);
+    bool SignedMode = Op1Signed || Op2Signed;
 
     // If both are representable as i15 and at least one is constant,
     // zero-extended, or sign-extended from vXi16 (or less pre-SSE41) then we
@@ -275,6 +276,20 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
         LT.second =
             MVT::getVectorVT(MVT::i16, 2 * LT.second.getVectorNumElements());
     }
+
+    // Check if the vXi32 operands can be shrunk into a smaller datatype.
+    // This should match the codegen from reduceVMULWidth.
+    // TODO: Make this generic (!ST->SSE41 || ST->isPMULLDSlow()).
+    if (ST->useSLMArithCosts() && LT.second == MVT::v4i32) {
+      if (OpMinSize <= 7)
+        return LT.first * 3; // pmullw/sext
+      if (!SignedMode && OpMinSize <= 8)
+        return LT.first * 3; // pmullw/zext
+      if (OpMinSize <= 15)
+        return LT.first * 5; // pmullw/pmulhw/pshuf
+      if (!SignedMode && OpMinSize <= 16)
+        return LT.first * 5; // pmullw/pmulhw/pshuf
+    }
   }
 
   // Vector multiply by pow2 will be simplified to shifts.
@@ -372,32 +387,10 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SUB,   MVT::v2i64, {  4 } },
   };
 
-  if (ST->useSLMArithCosts()) {
-    if (Args.size() == 2 && ISD == ISD::MUL && LT.second == MVT::v4i32) {
-      // Check if the operands can be shrinked into a smaller datatype.
-      // TODO: Merge this into generiic vXi32 MUL patterns above.
-      bool Op1Signed = false;
-      unsigned Op1MinSize = BaseT::minRequiredElementSize(Args[0], Op1Signed);
-      bool Op2Signed = false;
-      unsigned Op2MinSize = BaseT::minRequiredElementSize(Args[1], Op2Signed);
-
-      bool SignedMode = Op1Signed || Op2Signed;
-      unsigned OpMinSize = std::max(Op1MinSize, Op2MinSize);
-
-      if (OpMinSize <= 7)
-        return LT.first * 3; // pmullw/sext
-      if (!SignedMode && OpMinSize <= 8)
-        return LT.first * 3; // pmullw/zext
-      if (OpMinSize <= 15)
-        return LT.first * 5; // pmullw/pmulhw/pshuf
-      if (!SignedMode && OpMinSize <= 16)
-        return LT.first * 5; // pmullw/pmulhw/pshuf
-    }
-
+  if (ST->useSLMArithCosts())
     if (const auto *Entry = CostTableLookup(SLMCostTable, ISD, LT.second))
       if (auto KindCost = Entry->Cost[CostKind])
         return LT.first * KindCost.value();
-  }
 
   static const CostKindTblEntry AVX512BWUniformConstCostTable[] = {
     { ISD::SHL,  MVT::v64i8, { 2 } }, // psllw + pand.