[Analysis] TTI: Add CastContextHint for getCastInstrCost
authorDavid Green <david.green@arm.com>
Wed, 29 Jul 2020 12:32:53 +0000 (13:32 +0100)
committerDavid Green <david.green@arm.com>
Wed, 29 Jul 2020 12:32:53 +0000 (13:32 +0100)
Currently, getCastInstrCost has limited information about the cast it's
rating, often just the opcode and types.  Sometimes there is a context
instruction as well, but it isn't trustworthy: for instance, when the
vectorizer is rating a plan, it calls getCastInstrCost with the old
instructions when, in fact, it's trying to evaluate the cost of the
instruction post-vectorization.  Thus, the current system can get the
cost of certain casts incorrect as the correct cost can vary greatly
based on the context in which it's used.

For example, if the vectorizer queries getCastInstrCost to evaluate the
cost of a sext(load) with tail predication enabled, getCastInstrCost
will think it's free most of the time, but it's not always free. On ARM
MVE, a VLD2 group cannot be extended like a normal VLDR can. Similar
situations can come up with how masked loads can be extended when being
split.

To fix that, this path adds a new parameter to getCastInstrCost to give
it a hint about the context of the cast. It adds a CastContextHint enum
which contains the type of the load/store being created by the
vectorizer - one for each of the types it can produce.

Original patch by Pierre van Houtryve

Differential Revision: https://reviews.llvm.org/D79162

20 files changed:
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/CodeGen/BasicTTIImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
llvm/lib/Target/ARM/ARMTargetTransformInfo.h
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
llvm/lib/Target/X86/X86TargetTransformInfo.cpp
llvm/lib/Target/X86/X86TargetTransformInfo.h
llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

index e1426e3..092c974 100644 (file)
@@ -1021,10 +1021,47 @@ public:
   int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index = 0,
                      VectorType *SubTp = nullptr) const;
 
+  /// Represents a hint about the context in which a cast is used.
+  ///
+  /// For zext/sext, the context of the cast is the operand, which must be a
+  /// load of some kind. For trunc, the context is of the cast is the single
+  /// user of the instruction, which must be a store of some kind.
+  ///
+  /// This enum allows the vectorizer to give getCastInstrCost an idea of the
+  /// type of cast it's dealing with, as not every cast is equal. For instance,
+  /// the zext of a load may be free, but the zext of an interleaving load can
+  //// be (very) expensive!
+  ///
+  /// See \c getCastContextHint to compute a CastContextHint from a cast
+  /// Instruction*. Callers can use it if they don't need to override the
+  /// context and just want it to be calculated from the instruction.
+  ///
+  /// FIXME: This handles the types of load/store that the vectorizer can
+  /// produce, which are the cases where the context instruction is most
+  /// likely to be incorrect. There are other situations where that can happen
+  /// too, which might be handled here but in the long run a more general
+  /// solution of costing multiple instructions at the same times may be better.
+  enum class CastContextHint : uint8_t {
+    None,          ///< The cast is not used with a load/store of any kind.
+    Normal,        ///< The cast is used with a normal load/store.
+    Masked,        ///< The cast is used with a masked load/store.
+    GatherScatter, ///< The cast is used with a gather/scatter.
+    Interleave,    ///< The cast is used with an interleaved load/store.
+    Reversed,      ///< The cast is used with a reversed load/store.
+  };
+
+  /// Calculates a CastContextHint from \p I.
+  /// This should be used by callers of getCastInstrCost if they wish to
+  /// determine the context from some instruction.
+  /// \returns the CastContextHint for ZExt/SExt/Trunc, None if \p I is nullptr,
+  /// or if it's another type of cast.
+  static CastContextHint getCastContextHint(const Instruction *I);
+
   /// \return The expected cost of cast instructions, such as bitcast, trunc,
   /// zext, etc. If there is an existing instruction that holds Opcode, it
   /// may be passed in the 'I' parameter.
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                       TTI::CastContextHint CCH,
                        TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency,
                        const Instruction *I = nullptr) const;
 
@@ -1454,6 +1491,7 @@ public:
   virtual int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index,
                              VectorType *SubTp) = 0;
   virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                               CastContextHint CCH,
                                TTI::TargetCostKind CostKind,
                                const Instruction *I) = 0;
   virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -1882,9 +1920,9 @@ public:
     return Impl.getShuffleCost(Kind, Tp, Index, SubTp);
   }
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I) override {
-    return Impl.getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    return Impl.getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   }
   int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
                                unsigned Index) override {
index 73e5ff6..4dc0a90 100644 (file)
@@ -423,6 +423,7 @@ public:
   }
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                            TTI::CastContextHint CCH,
                             TTI::TargetCostKind CostKind,
                             const Instruction *I) {
     switch (Opcode) {
@@ -915,7 +916,8 @@ public:
     case Instruction::SExt:
     case Instruction::ZExt:
     case Instruction::AddrSpaceCast:
-      return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
+      return TargetTTI->getCastInstrCost(
+          Opcode, Ty, OpTy, TTI::getCastContextHint(I), CostKind, I);
     case Instruction::Store: {
       auto *SI = cast<StoreInst>(U);
       Type *ValTy = U->getOperand(0)->getType();
index aabfb74..e9f6329 100644 (file)
@@ -716,9 +716,10 @@ public:
   }
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                            TTI::CastContextHint CCH,
                             TTI::TargetCostKind CostKind,
                             const Instruction *I = nullptr) {
-    if (BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I) == 0)
+    if (BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I) == 0)
       return 0;
 
     const TargetLoweringBase *TLI = getTLI();
@@ -756,15 +757,12 @@ public:
         return 0;
       LLVM_FALLTHROUGH;
     case Instruction::SExt:
-      if (!I)
-        break;
-
-      if (getTLI()->isExtFree(I))
+      if (I && getTLI()->isExtFree(I))
         return 0;
 
       // If this is a zext/sext of a load, return 0 if the corresponding
       // extending load exists on target.
-      if (I && isa<LoadInst>(I->getOperand(0))) {
+      if (CCH == TTI::CastContextHint::Normal) {
         EVT ExtVT = EVT::getEVT(Dst);
         EVT LoadVT = EVT::getEVT(Src);
         unsigned LType =
@@ -839,7 +837,7 @@ public:
         unsigned SplitCost =
             (!SplitSrc || !SplitDst) ? TTI->getVectorSplitCost() : 0;
         return SplitCost +
-               (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy,
+               (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy, CCH,
                                           CostKind, I));
       }
 
@@ -847,7 +845,7 @@ public:
       // the operation will get scalarized.
       unsigned Num = cast<FixedVectorType>(DstVTy)->getNumElements();
       unsigned Cost = thisT()->getCastInstrCost(
-          Opcode, Dst->getScalarType(), Src->getScalarType(), CostKind, I);
+          Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind, I);
 
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values.
@@ -872,7 +870,7 @@ public:
     return thisT()->getVectorInstrCost(Instruction::ExtractElement, VecTy,
                                        Index) +
            thisT()->getCastInstrCost(Opcode, Dst, VecTy->getElementType(),
-                                     TTI::TCK_RecipThroughput);
+                                     TTI::CastContextHint::None, TTI::TCK_RecipThroughput);
   }
 
   unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
@@ -1522,13 +1520,14 @@ public:
 
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
+      TTI::CastContextHint CCH = TTI::CastContextHint::None;
 
       unsigned Cost = 0;
-      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, RetTy, CostKind);
+      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, RetTy, CCH, CostKind);
       Cost +=
           thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
       Cost += 2 * thisT()->getCastInstrCost(Instruction::Trunc, RetTy, ExtTy,
-                                            CostKind);
+                                            CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::LShr, RetTy,
                                               CostKind, TTI::OK_AnyValue,
                                               TTI::OK_UniformConstantValue);
@@ -1587,13 +1586,14 @@ public:
 
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
+      TTI::CastContextHint CCH = TTI::CastContextHint::None;
 
       unsigned Cost = 0;
-      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, MulTy, CostKind);
+      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, MulTy, CCH, CostKind);
       Cost +=
           thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
       Cost += 2 * thisT()->getCastInstrCost(Instruction::Trunc, MulTy, ExtTy,
-                                            CostKind);
+                                            CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::LShr, MulTy,
                                               CostKind, TTI::OK_AnyValue,
                                               TTI::OK_UniformConstantValue);
index c9e702c..944e621 100644 (file)
@@ -730,12 +730,57 @@ int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, VectorType *Ty,
   return Cost;
 }
 
+TTI::CastContextHint
+TargetTransformInfo::getCastContextHint(const Instruction *I) {
+  if (!I)
+    return CastContextHint::None;
+
+  auto getLoadStoreKind = [](const Value *V, unsigned LdStOp, unsigned MaskedOp,
+                             unsigned GatScatOp) {
+    const Instruction *I = dyn_cast<Instruction>(V);
+    if (!I)
+      return CastContextHint::None;
+
+    if (I->getOpcode() == LdStOp)
+      return CastContextHint::Normal;
+
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      if (II->getIntrinsicID() == MaskedOp)
+        return TTI::CastContextHint::Masked;
+      if (II->getIntrinsicID() == GatScatOp)
+        return TTI::CastContextHint::GatherScatter;
+    }
+
+    return TTI::CastContextHint::None;
+  };
+
+  switch (I->getOpcode()) {
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPExt:
+    return getLoadStoreKind(I->getOperand(0), Instruction::Load,
+                            Intrinsic::masked_load, Intrinsic::masked_gather);
+  case Instruction::Trunc:
+  case Instruction::FPTrunc:
+    if (I->hasOneUse())
+      return getLoadStoreKind(*I->user_begin(), Instruction::Store,
+                              Intrinsic::masked_store,
+                              Intrinsic::masked_scatter);
+    break;
+  default:
+    return CastContextHint::None;
+  }
+
+  return TTI::CastContextHint::None;
+}
+
 int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                          CastContextHint CCH,
                                           TTI::TargetCostKind CostKind,
                                           const Instruction *I) const {
   assert((I == nullptr || I->getOpcode() == Opcode) &&
          "Opcode should reflect passed instruction.");
-  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
index cf6de79..5f5da63 100644 (file)
@@ -270,6 +270,7 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
 }
 
 int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                     TTI::CastContextHint CCH,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -306,7 +307,8 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+    return AdjustCost(
+        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 
   static const TypeConversionCostTblEntry
   ConversionTbl[] = {
@@ -410,7 +412,8 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                                  SrcTy.getSimpleVT()))
     return AdjustCost(Entry->Cost);
 
-  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+  return AdjustCost(
+      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -442,12 +445,14 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
   // we may get the extension for free. If not, get the default cost for the
   // extend.
   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
-    return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                   CostKind);
 
   // The destination type should be larger than the element type. If not, get
   // the default cost for the extend.
   if (DstVT.getSizeInBits() < SrcVT.getSizeInBits())
-    return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                   CostKind);
 
   switch (Opcode) {
   default:
@@ -466,7 +471,8 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
   }
 
   // If we are unable to perform the extend for free, get the default cost.
-  return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+  return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                 CostKind);
 }
 
 unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
index 1f02968..5d1371f 100644 (file)
@@ -114,7 +114,7 @@ public:
   unsigned getMaxInterleaveFactor(unsigned VF);
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
 
   int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
index 25d0a4c..59f33e6 100644 (file)
@@ -301,6 +301,7 @@ int ARMTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Im
 }
 
 int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -317,7 +318,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+    return AdjustCost(
+        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 
   // The extend of a load is free
   if (I && isa<LoadInst>(I->getOperand(0))) {
@@ -388,8 +390,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     };
     if (SrcTy.isVector() && ST->hasMVEIntegerOps()) {
       if (const auto *Entry =
-              ConvertCostTableLookup(MVELoadConversionTbl, ISD, SrcTy.getSimpleVT(),
-                                     DstTy.getSimpleVT()))
+              ConvertCostTableLookup(MVELoadConversionTbl, ISD,
+                                     SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
         return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
     }
 
@@ -399,8 +401,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     };
     if (SrcTy.isVector() && ST->hasMVEFloatOps()) {
       if (const auto *Entry =
-              ConvertCostTableLookup(MVEFLoadConversionTbl, ISD, SrcTy.getSimpleVT(),
-                                     DstTy.getSimpleVT()))
+              ConvertCostTableLookup(MVEFLoadConversionTbl, ISD,
+                                     SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
         return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
     }
   }
@@ -672,7 +674,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                      ? ST->getMVEVectorCostFactor()
                      : 1;
   return AdjustCost(
-    BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+      BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
index 093dfbb..ac7d037 100644 (file)
@@ -210,7 +210,7 @@ public:
   }
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
 
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
index 80c8736..68efaf7 100644 (file)
@@ -270,7 +270,9 @@ unsigned HexagonTTIImpl::getArithmeticInstrCost(
 }
 
 unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
-      Type *SrcTy, TTI::TargetCostKind CostKind, const Instruction *I) {
+                                          Type *SrcTy, TTI::CastContextHint CCH,
+                                          TTI::TargetCostKind CostKind,
+                                          const Instruction *I) {
   if (SrcTy->isFPOrFPVectorTy() || DstTy->isFPOrFPVectorTy()) {
     unsigned SrcN = SrcTy->isFPOrFPVectorTy() ? getTypeNumElements(SrcTy) : 0;
     unsigned DstN = DstTy->isFPOrFPVectorTy() ? getTypeNumElements(DstTy) : 0;
index 5fe3974..07e59fb 100644 (file)
@@ -146,8 +146,9 @@ public:
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr);
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-            TTI::TargetCostKind CostKind,
-            const Instruction *I = nullptr);
+                            TTI::CastContextHint CCH,
+                            TTI::TargetCostKind CostKind,
+                            const Instruction *I = nullptr);
   unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
 
   unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
index ee8842f..4171d4a 100644 (file)
@@ -879,11 +879,12 @@ int PPCTTIImpl::getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
 }
 
 int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
 
-  int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
   // TODO: Allow non-throughput costs that aren't binary.
   if (CostKind != TTI::TCK_RecipThroughput)
index a3453b0..d9aab29 100644 (file)
@@ -106,7 +106,7 @@ public:
       const Instruction *CxtI = nullptr);
   int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
index 864200e..8758dde 100644 (file)
@@ -699,11 +699,12 @@ getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
 }
 
 int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                     TTI::CastContextHint CCH,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
   // FIXME: Can the logic below also be used for these cost kinds?
   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) {
-    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
     return BaseCost == 0 ? BaseCost : 1;
   }
 
@@ -786,8 +787,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values. Base implementation does not
       // realize float->int gets scalarized.
-      unsigned ScalarCost = getCastInstrCost(Opcode, Dst->getScalarType(),
-                                             Src->getScalarType(), CostKind);
+      unsigned ScalarCost = getCastInstrCost(
+          Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind);
       unsigned TotCost = VF * ScalarCost;
       bool NeedsInserts = true, NeedsExtracts = true;
       // FP128 registers do not get inserted or extracted.
@@ -828,7 +829,7 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     }
   }
 
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 }
 
 // Scalar i8 / i16 operations will typically be made after first extending
index 7f8f7f6..1aa31ff 100644 (file)
@@ -93,7 +93,7 @@ public:
   unsigned getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
                                          const Instruction *I);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          TTI::TargetCostKind CostKind,
index 491078f..a484114 100644 (file)
@@ -1367,6 +1367,7 @@ int X86TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *BaseTp,
 }
 
 int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -1988,7 +1989,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
 
   // The function getSimpleVT only handles simple value types.
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind));
+    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind));
 
   MVT SimpleSrcTy = SrcTy.getSimpleVT();
   MVT SimpleDstTy = DstTy.getSimpleVT();
@@ -2049,7 +2050,8 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
       return AdjustCost(Entry->Cost);
   }
 
-  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+  return AdjustCost(
+      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
index fb47190..8d2fa27 100644 (file)
@@ -130,7 +130,7 @@ public:
   int getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index,
                      VectorType *SubTp);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          TTI::TargetCostKind CostKind,
index dc2ad14..c344c6c 100644 (file)
@@ -2025,8 +2025,8 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
 
       Type *SrcTy = CI->getOperand(0)->getType();
       Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy,
-                                   TargetTransformInfo::TCK_SizeAndLatency,
-                                   CI);
+                                   TTI::getCastContextHint(CI),
+                                   TargetTransformInfo::TCK_SizeAndLatency, CI);
 
     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
       // Cost of the address calculation
index 3527ca4..c99d612 100644 (file)
@@ -2150,8 +2150,9 @@ bool SCEVExpander::isHighCostExpansionHelper(
       llvm_unreachable("There are no other cast types.");
     }
     const SCEV *Op = CastExpr->getOperand();
-    BudgetRemaining -= TTI.getCastInstrCost(Opcode, /*Dst=*/S->getType(),
-                                            /*Src=*/Op->getType(), CostKind);
+    BudgetRemaining -= TTI.getCastInstrCost(
+        Opcode, /*Dst=*/S->getType(),
+        /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind);
     Worklist.emplace_back(Op);
     return false; // Will answer upon next entry into this function.
   }
index d207ca2..4b51951 100644 (file)
@@ -6458,13 +6458,54 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
   case Instruction::Trunc:
   case Instruction::FPTrunc:
   case Instruction::BitCast: {
+    // Computes the CastContextHint from a Load/Store instruction.
+    auto ComputeCCH = [&](Instruction *I) -> TTI::CastContextHint {
+      assert((isa<LoadInst>(I) || isa<StoreInst>(I)) &&
+             "Expected a load or a store!");
+
+      if (VF == 1)
+        return TTI::CastContextHint::Normal;
+
+      switch (getWideningDecision(I, VF)) {
+      case LoopVectorizationCostModel::CM_GatherScatter:
+        return TTI::CastContextHint::GatherScatter;
+      case LoopVectorizationCostModel::CM_Interleave:
+        return TTI::CastContextHint::Interleave;
+      case LoopVectorizationCostModel::CM_Scalarize:
+      case LoopVectorizationCostModel::CM_Widen:
+        return Legal->isMaskRequired(I) ? TTI::CastContextHint::Masked
+                                        : TTI::CastContextHint::Normal;
+      case LoopVectorizationCostModel::CM_Widen_Reverse:
+        return TTI::CastContextHint::Reversed;
+      case LoopVectorizationCostModel::CM_Unknown:
+        llvm_unreachable("Instr did not go through cost modelling?");
+      }
+
+      llvm_unreachable("Unhandled case!");
+    };
+
+    unsigned Opcode = I->getOpcode();
+    TTI::CastContextHint CCH = TTI::CastContextHint::None;
+    // For Trunc, the context is the only user, which must be a StoreInst.
+    if (Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) {
+      if (I->hasOneUse())
+        if (StoreInst *Store = dyn_cast<StoreInst>(*I->user_begin()))
+          CCH = ComputeCCH(Store);
+    }
+    // For Z/Sext, the context is the operand, which must be a LoadInst.
+    else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+             Opcode == Instruction::FPExt) {
+      if (LoadInst *Load = dyn_cast<LoadInst>(I->getOperand(0)))
+        CCH = ComputeCCH(Load);
+    }
+
     // We optimize the truncation of induction variables having constant
     // integer steps. The cost of these truncations is the same as the scalar
     // operation.
     if (isOptimizableIVTruncate(I, VF)) {
       auto *Trunc = cast<TruncInst>(I);
       return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(),
-                                  Trunc->getSrcTy(), CostKind, Trunc);
+                                  Trunc->getSrcTy(), CCH, CostKind, Trunc);
     }
 
     Type *SrcScalarTy = I->getOperand(0)->getType();
@@ -6477,12 +6518,11 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
       //
       // Calculate the modified src and dest types.
       Type *MinVecTy = VectorTy;
-      if (I->getOpcode() == Instruction::Trunc) {
+      if (Opcode == Instruction::Trunc) {
         SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy);
         VectorTy =
             largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
-      } else if (I->getOpcode() == Instruction::ZExt ||
-                 I->getOpcode() == Instruction::SExt) {
+      } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) {
         SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy);
         VectorTy =
             smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
@@ -6490,8 +6530,8 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
     }
 
     unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1;
-    return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy,
-                                    CostKind, I);
+    return N *
+           TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
   }
   case Instruction::Call: {
     bool NeedToScalarize;
index 5fb8ad5..d575b32 100644 (file)
@@ -3399,8 +3399,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
                   Ext->getOpcode(), Ext->getType(), VecTy, i);
               // Add back the cost of s|zext which is subtracted separately.
               DeadCost += TTI->getCastInstrCost(
-                  Ext->getOpcode(), Ext->getType(), E->getType(), CostKind,
-                  Ext);
+                  Ext->getOpcode(), Ext->getType(), E->getType(),
+                  TTI::getCastContextHint(Ext), CostKind, Ext);
               continue;
             }
           }
@@ -3424,8 +3424,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
     case Instruction::BitCast: {
       Type *SrcTy = VL0->getOperand(0)->getType();
       int ScalarEltCost =
-          TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, CostKind,
-                                VL0);
+          TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy,
+                                TTI::getCastContextHint(VL0), CostKind, VL0);
       if (NeedToShuffleReuses) {
         ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost;
       }
@@ -3437,9 +3437,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       int VecCost = 0;
       // Check if the values are candidates to demote.
       if (!MinBWs.count(VL0) || VecTy != SrcVecTy) {
-        VecCost = ReuseShuffleCost +
-                  TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
-                                        CostKind, VL0);
+        VecCost =
+            ReuseShuffleCost +
+            TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
+                                  TTI::getCastContextHint(VL0), CostKind, VL0);
       }
       return VecCost - ScalarCost;
     }
@@ -3644,9 +3645,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
         auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size());
         auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size());
         VecCost = TTI->getCastInstrCost(E->getOpcode(), VecTy, Src0Ty,
-                                        CostKind);
+                                        TTI::CastContextHint::None, CostKind);
         VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
-                                         CostKind);
+                                         TTI::CastContextHint::None, CostKind);
       }
       VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0);
       return ReuseShuffleCost + VecCost - ScalarCost;