[X86][AArch64][WebAsm][RISCV] Query operand properties instead of using enums directl...
authorPhilip Reames <preames@rivosinc.com>
Mon, 22 Aug 2022 19:03:36 +0000 (12:03 -0700)
committerPhilip Reames <listmail@philipreames.com>
Mon, 22 Aug 2022 20:37:59 +0000 (13:37 -0700)
This is part of an ongoing transition to use OperandValueInfo which combines OperandValueKind and OperandValueProperties.  This change adds some accessor methods and uses them to simplify backend code.  The primary motivation of doing so is removing uses of the parameters so that an upcoming api change is less error prone.

llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
llvm/lib/Target/X86/X86TargetTransformInfo.cpp

index 5882972..5a70120 100644 (file)
@@ -906,6 +906,16 @@ public:
   struct OperandValueInfo {
     OperandValueKind Kind = OK_AnyValue;
     OperandValueProperties Properties = OP_None;
+
+    bool isConstant() const {
+      return Kind == OK_UniformConstantValue || Kind == OK_NonUniformConstantValue;
+    }
+    bool isUniform() const {
+      return Kind == OK_UniformConstantValue || Kind == OK_UniformValue;
+    }
+    bool isPowerOf2() const {
+      return Properties == OP_PowerOf2;
+    }
   };
 
   /// \return the number of registers in the target-provided register class.
index f9a977c..66c0111 100644 (file)
@@ -1982,6 +1982,9 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
+
+  const TTI::OperandValueInfo Op2Info = {Opd2Info, Opd2PropInfo};
+
   // TODO: Handle more cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput)
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
@@ -1997,8 +2000,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                          Opd2Info, Opd1PropInfo, Opd2PropInfo);
   case ISD::SDIV:
-    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
-        Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+    if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
       // On AArch64, scalar signed division by constants power-of-two are
       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
       // The OperandValue properties many not be same as that of previous
@@ -2019,7 +2021,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     }
     [[fallthrough]];
   case ISD::UDIV: {
-    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
+    if (Op2Info.isConstant() && Op2Info.isUniform()) {
       auto VT = TLI->getValueType(DL, Ty);
       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
         // Vector signed division by constant are expanded to the
index a3d2dfe..703ad9f 100644 (file)
@@ -514,11 +514,9 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
 }
 
 InstructionCost RISCVTTIImpl::getVectorImmCost(VectorType *VecTy,
-                                               TTI::OperandValueKind OpInfo,
-                                               TTI::OperandValueProperties PropInfo,
+                                               TTI::OperandValueInfo OpInfo,
                                                TTI::TargetCostKind CostKind) {
-  assert((OpInfo == TTI::OK_UniformConstantValue ||
-          OpInfo == TTI::OK_NonUniformConstantValue) && "non constant operand?");
+  assert(OpInfo.isConstant() && "non constant operand?");
   APInt PseudoAddr = APInt::getAllOnes(DL.getPointerSizeInBits());
   // Add a cost of address load + the cost of the vector load.
   return RISCVMatInt::getIntMatCost(PseudoAddr, DL.getPointerSizeInBits(),
@@ -532,16 +530,14 @@ InstructionCost RISCVTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                               MaybeAlign Alignment,
                                               unsigned AddressSpace,
                                               TTI::TargetCostKind CostKind,
-                                              TTI::OperandValueKind OpdInfo,
+                                              TTI::OperandValueKind OpdKind,
                                               const Instruction *I) {
+  const TTI::OperandValueInfo OpInfo = {OpdKind, TTI::OP_None};
   InstructionCost Cost = 0;
-  if (Opcode == Instruction::Store && isa<VectorType>(Src) &&
-      (OpdInfo == TTI::OK_UniformConstantValue ||
-       OpdInfo == TTI::OK_NonUniformConstantValue)) {
-    Cost += getVectorImmCost(cast<VectorType>(Src), OpdInfo, TTI::OP_None, CostKind);
-  }
+  if (Opcode == Instruction::Store && isa<VectorType>(Src) && OpInfo.isConstant())
+    Cost += getVectorImmCost(cast<VectorType>(Src), OpInfo, CostKind);
   return Cost + BaseT::getMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
-                                       CostKind, OpdInfo, I);
+                                       CostKind, OpInfo.Kind, I);
 }
 
 void RISCVTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
index 0d424b2..43be1cf 100644 (file)
@@ -54,8 +54,7 @@ public:
   /// Return the cost of materializing a vector immediate, assuming it does
   /// not get folded into the using instruction(s).
   InstructionCost getVectorImmCost(VectorType *VecTy,
-                                   TTI::OperandValueKind OpInfo,
-                                   TTI::OperandValueProperties PropInfo,
+                                   TTI::OperandValueInfo OpInfo,
                                    TTI::TargetCostKind CostKind);
 
   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
index c1cd58d..8df77a8 100644 (file)
@@ -57,6 +57,8 @@ InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
 
+  const TTI::OperandValueInfo Op2Info = {Opd2Info, Opd2PropInfo};
+
   InstructionCost Cost =
       BasicTTIImplBase<WebAssemblyTTIImpl>::getArithmeticInstrCost(
           Opcode, Ty, CostKind, Opd1Info, Opd2Info, Opd1PropInfo, Opd2PropInfo);
@@ -69,8 +71,7 @@ InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
       // SIMD128's shifts currently only accept a scalar shift count. For each
       // element, we'll need to extract, op, insert. The following is a rough
       // approximation.
-      if (Opd2Info != TTI::OK_UniformValue &&
-          Opd2Info != TTI::OK_UniformConstantValue)
+      if (!Op2Info.isUniform())
         Cost =
             cast<FixedVectorType>(VTy)->getNumElements() *
             (TargetTransformInfo::TCC_Basic +
index 74885eb..b3d88b8 100644 (file)
@@ -180,6 +180,9 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
+
+  const TTI::OperandValueInfo Op2Info = {Op2Kind, Opd2PropInfo};
+
   // vXi8 multiplications are always promoted to vXi16.
   if (Opcode == Instruction::Mul && Ty->isVectorTy() &&
       Ty->getScalarSizeInBits() == 8) {
@@ -232,10 +235,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   }
 
   // Vector multiply by pow2 will be simplified to shifts.
-  if (ISD == ISD::MUL &&
-      (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2)
+  if (ISD == ISD::MUL && Op2Info.isConstant() && Op2Info.isPowerOf2())
     return getArithmeticInstrCost(Instruction::Shl, Ty, CostKind, Op1Kind,
                                   Op2Kind, TargetTransformInfo::OP_None,
                                   TargetTransformInfo::OP_None);
@@ -245,9 +245,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   // The OperandValue properties may not be the same as that of the previous
   // operation; conservatively assume OP_None.
   if ((ISD == ISD::SDIV || ISD == ISD::SREM) &&
-      (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+      Op2Info.isConstant() && Op2Info.isPowerOf2()) {
     InstructionCost Cost =
         2 * getArithmeticInstrCost(Instruction::AShr, Ty, CostKind, Op1Kind,
                                    Op2Kind, TargetTransformInfo::OP_None,
@@ -272,9 +270,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
 
   // Vector unsigned division/remainder will be simplified to shifts/masks.
   if ((ISD == ISD::UDIV || ISD == ISD::UREM) &&
-      (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+      Op2Info.isConstant() && Op2Info.isPowerOf2()) {
     if (ISD == ISD::UDIV)
       return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind, Op1Kind,
                                     Op2Kind, TargetTransformInfo::OP_None,
@@ -372,8 +368,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA,  MVT::v64i8,   4 }, // psrlw, pand, pxor, psubb.
   };
 
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
-      ST->hasBWI()) {
+  if (Op2Info.isUniform() && Op2Info.isConstant() && ST->hasBWI()) {
     if (const auto *Entry = CostTableLookup(AVX512BWUniformConstCostTable, ISD,
                                             LT.second))
       return LT.first * Entry->Cost;
@@ -394,8 +389,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v16i32,  7 }, // pmuludq+mul+sub sequence
   };
 
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
-      ST->hasAVX512()) {
+  if (Op2Info.isUniform() && Op2Info.isConstant() && ST->hasAVX512()) {
     if (const auto *Entry = CostTableLookup(AVX512UniformConstCostTable, ISD,
                                             LT.second))
       return LT.first * Entry->Cost;
@@ -414,8 +408,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v8i32,   7 }, // pmuludq+mul+sub sequence
   };
 
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
-      ST->hasAVX2()) {
+  if (Op2Info.isUniform() && Op2Info.isConstant() && ST->hasAVX2()) {
     if (const auto *Entry = CostTableLookup(AVX2UniformConstCostTable, ISD,
                                             LT.second))
       return LT.first * Entry->Cost;
@@ -441,7 +434,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   };
 
   // XOP has faster vXi8 shifts.
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
+  if (Op2Info.isUniform() && Op2Info.isConstant() &&
       ST->hasSSE2() && !ST->hasXOP()) {
     if (const auto *Entry =
             CostTableLookup(SSE2UniformConstCostTable, ISD, LT.second))
@@ -459,9 +452,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v32i16,  8 }, // vpmulhuw+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasBWI()) {
+  if (Op2Info.isConstant() && ST->hasBWI()) {
     if (const auto *Entry =
             CostTableLookup(AVX512BWConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -482,9 +473,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v32i16, 16 }, // 2*vpmulhuw+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasAVX512()) {
+  if (Op2Info.isConstant() && ST->hasAVX512()) {
     if (const auto *Entry =
             CostTableLookup(AVX512ConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -505,9 +494,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v8i32,  19 }, // vpmuludq+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasAVX2()) {
+  if (Op2Info.isConstant() && ST->hasAVX2()) {
     if (const auto *Entry = CostTableLookup(AVX2ConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
   }
@@ -539,9 +526,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v4i32,    20 }, // pmuludq+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasSSE2()) {
+  if (Op2Info.isConstant() && ST->hasSSE2()) {
     // pmuldq sequence.
     if (ISD == ISD::SDIV && LT.second == MVT::v8i32 && ST->hasAVX())
       return LT.first * 32;
@@ -598,9 +583,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRL,  MVT::v4i64,  1 }, // psrlq
   };
 
-  if (ST->hasAVX2() &&
-      ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue) ||
-       (Op2Kind == TargetTransformInfo::OK_UniformValue))) {
+  if (ST->hasAVX2() && Op2Info.isUniform()) {
     if (const auto *Entry =
             CostTableLookup(AVX2UniformCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -620,9 +603,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA,  MVT::v4i32,  1 }, // psrad.
   };
 
-  if (ST->hasSSE2() &&
-      ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue) ||
-       (Op2Kind == TargetTransformInfo::OK_UniformValue))) {
+  if (ST->hasSSE2() && Op2Info.isUniform()) {
     if (const auto *Entry =
             CostTableLookup(SSE2UniformCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -717,9 +698,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   };
 
   if (ST->hasAVX512()) {
-    if (ISD == ISD::SHL && LT.second == MVT::v32i16 &&
-        (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-         Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue))
+    if (ISD == ISD::SHL && LT.second == MVT::v32i16 && Op2Info.isConstant())
       // On AVX512, a packed v32i16 shift left by a constant build_vector
       // is lowered into a vector multiply (vpmullw).
       return getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
@@ -731,8 +710,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   // Look for AVX2 lowering tricks (XOP is always better at v4i32 shifts).
   if (ST->hasAVX2() && !(ST->hasXOP() && LT.second == MVT::v4i32)) {
     if (ISD == ISD::SHL && LT.second == MVT::v16i16 &&
-        (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-         Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue))
+        Op2Info.isConstant())
       // On AVX2, a packed v16i16 shift left by a constant build_vector
       // is lowered into a vector multiply (vpmullw).
       return getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
@@ -778,9 +756,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     // If the right shift is constant then we'll fold the negation so
     // it's as cheap as a left shift.
     int ShiftISD = ISD;
-    if ((ShiftISD == ISD::SRL || ShiftISD == ISD::SRA) &&
-        (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-         Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue))
+    if ((ShiftISD == ISD::SRL || ShiftISD == ISD::SRA) && Op2Info.isConstant())
       ShiftISD = ISD::SHL;
     if (const auto *Entry =
             CostTableLookup(XOPShiftCostTable, ShiftISD, LT.second))
@@ -803,9 +779,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA,  MVT::v4i64,  8+2 }, // 2*(2*psrad + shuffle) + split.
   };
 
-  if (ST->hasSSE2() &&
-      ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue) ||
-       (Op2Kind == TargetTransformInfo::OK_UniformValue))) {
+  if (ST->hasSSE2() && Op2Info.isUniform()) {
 
     // Handle AVX2 uniform v4i64 ISD::SRA, it's not worth a table.
     if (ISD == ISD::SRA && LT.second == MVT::v4i64 && ST->hasAVX2())
@@ -4069,8 +4043,10 @@ InstructionCost X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                             MaybeAlign Alignment,
                                             unsigned AddressSpace,
                                             TTI::TargetCostKind CostKind,
-                                            TTI::OperandValueKind OpdInfo,
+                                            TTI::OperandValueKind OpdKind,
                                             const Instruction *I) {
+  const TTI::OperandValueInfo OpInfo = {OpdKind, TTI::OP_None};
+
   // TODO: Handle other cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput) {
     if (auto *SI = dyn_cast_or_null<StoreInst>(I)) {
@@ -4099,9 +4075,7 @@ InstructionCost X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
   InstructionCost Cost = 0;
 
   // Add a cost for constant load to vector.
-  if (Opcode == Instruction::Store &&
-      (OpdInfo == TTI::OK_UniformConstantValue ||
-       OpdInfo == TTI::OK_NonUniformConstantValue))
+  if (Opcode == Instruction::Store && OpInfo.isConstant())
     Cost += getMemoryOpCost(Instruction::Load, Src, DL.getABITypeAlign(Src),
                             /*AddressSpace=*/0, CostKind);