[TTI][X86] getArithmeticInstrCost - move opcode canonicalization before all target...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 30 Aug 2021 11:24:59 +0000 (12:24 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 30 Aug 2021 11:24:59 +0000 (12:24 +0100)
The GLM/SLM special cases still get tested first but after the the MUL/DIV/REM pattern detection - this will be necessary for when we make the SLM vXi32 MUL canonicalization generic to improve PMULLW/PMULHW/PMADDDW cost support etc.

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

index 532bc33..8c0c53e 100644 (file)
@@ -206,6 +206,50 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  if ((ISD == ISD::SDIV || ISD == ISD::SREM || ISD == ISD::UDIV ||
+       ISD == ISD::UREM) &&
+      (Op2Info == TargetTransformInfo::OK_UniformConstantValue ||
+       Op2Info == TargetTransformInfo::OK_NonUniformConstantValue) &&
+      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+    if (ISD == ISD::SDIV || ISD == ISD::SREM) {
+      // On X86, vector signed division by constants power-of-two are
+      // normally expanded to the sequence SRA + SRL + ADD + SRA.
+      // The OperandValue properties may not be the same as that of the previous
+      // operation; conservatively assume OP_None.
+      InstructionCost Cost =
+          2 * getArithmeticInstrCost(Instruction::AShr, Ty, CostKind, Op1Info,
+                                     Op2Info, TargetTransformInfo::OP_None,
+                                     TargetTransformInfo::OP_None);
+      Cost += getArithmeticInstrCost(Instruction::LShr, Ty, CostKind, Op1Info,
+                                     Op2Info, TargetTransformInfo::OP_None,
+                                     TargetTransformInfo::OP_None);
+      Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind, Op1Info,
+                                     Op2Info, TargetTransformInfo::OP_None,
+                                     TargetTransformInfo::OP_None);
+
+      if (ISD == ISD::SREM) {
+        // For SREM: (X % C) is the equivalent of (X - (X/C)*C)
+        Cost += getArithmeticInstrCost(Instruction::Mul, Ty, CostKind, Op1Info,
+                                       Op2Info);
+        Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind, Op1Info,
+                                       Op2Info);
+      }
+
+      return Cost;
+    }
+
+    // Vector unsigned division/remainder will be simplified to shifts/masks.
+    if (ISD == ISD::UDIV)
+      return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind, Op1Info,
+                                    Op2Info, TargetTransformInfo::OP_None,
+                                    TargetTransformInfo::OP_None);
+
+    else // UREM
+      return getArithmeticInstrCost(Instruction::And, Ty, CostKind, Op1Info,
+                                    Op2Info, TargetTransformInfo::OP_None,
+                                    TargetTransformInfo::OP_None);
+  }
+
   static const CostTblEntry GLMCostTable[] = {
     { ISD::FDIV,  MVT::f32,   18 }, // divss
     { ISD::FDIV,  MVT::v4f32, 35 }, // divps
@@ -268,54 +312,6 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     }
   }
 
-  if ((ISD == ISD::SDIV || ISD == ISD::SREM || ISD == ISD::UDIV ||
-       ISD == ISD::UREM) &&
-      (Op2Info == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Info == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
-    if (ISD == ISD::SDIV || ISD == ISD::SREM) {
-      // On X86, vector signed division by constants power-of-two are
-      // normally expanded to the sequence SRA + SRL + ADD + SRA.
-      // The OperandValue properties may not be the same as that of the previous
-      // operation; conservatively assume OP_None.
-      InstructionCost Cost =
-          2 * getArithmeticInstrCost(Instruction::AShr, Ty, CostKind, Op1Info,
-                                     Op2Info, TargetTransformInfo::OP_None,
-                                     TargetTransformInfo::OP_None);
-      Cost += getArithmeticInstrCost(Instruction::LShr, Ty, CostKind, Op1Info,
-                                     Op2Info,
-                                     TargetTransformInfo::OP_None,
-                                     TargetTransformInfo::OP_None);
-      Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind, Op1Info,
-                                     Op2Info,
-                                     TargetTransformInfo::OP_None,
-                                     TargetTransformInfo::OP_None);
-
-      if (ISD == ISD::SREM) {
-        // For SREM: (X % C) is the equivalent of (X - (X/C)*C)
-        Cost += getArithmeticInstrCost(Instruction::Mul, Ty, CostKind, Op1Info,
-                                       Op2Info);
-        Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind, Op1Info,
-                                       Op2Info);
-      }
-
-      return Cost;
-    }
-
-    // Vector unsigned division/remainder will be simplified to shifts/masks.
-    if (ISD == ISD::UDIV)
-      return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind,
-                                    Op1Info, Op2Info,
-                                    TargetTransformInfo::OP_None,
-                                    TargetTransformInfo::OP_None);
-
-    else // UREM
-      return getArithmeticInstrCost(Instruction::And, Ty, CostKind,
-                                    Op1Info, Op2Info,
-                                    TargetTransformInfo::OP_None,
-                                    TargetTransformInfo::OP_None);
-  }
-
   static const CostTblEntry AVX512BWUniformConstCostTable[] = {
     { ISD::SHL,  MVT::v64i8,   2 }, // psllw + pand.
     { ISD::SRL,  MVT::v64i8,   2 }, // psrlw + pand.