From 37289615c01d08915626c58f825bc797c3a036e4 Mon Sep 17 00:00:00 2001 From: Sam Parker Date: Tue, 26 May 2020 14:28:34 +0100 Subject: [PATCH] [NFCI][CostModel] Unify getCmpSelInstrCost Add cases for icmp, fcmp and select into the switch statement of the generic getUserCost implementation with getInstructionThroughput then calling into it. The BasicTTI and backend implementations have be set to return a default value (1) when a cost other than throughput is being queried. Differential Revision: https://reviews.llvm.org/D80550 --- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 11 +++++++++++ llvm/include/llvm/CodeGen/BasicTTIImpl.h | 4 ++++ llvm/lib/Analysis/TargetTransformInfo.cpp | 14 +++----------- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 3 +++ llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp | 4 ++++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp | 2 +- llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp | 3 +++ llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp | 3 +++ llvm/lib/Target/X86/X86TargetTransformInfo.cpp | 4 ++++ 9 files changed, 36 insertions(+), 12 deletions(-) diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index e74d01b..9408485e 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -865,6 +865,17 @@ public: LI->getPointerAddressSpace(), CostKind, I); } + case Instruction::Select: { + Type *CondTy = U->getOperand(0)->getType(); + return TargetTTI->getCmpSelInstrCost(Opcode, U->getType(), CondTy, + CostKind, I); + } + case Instruction::ICmp: + case Instruction::FCmp: { + Type *ValTy = U->getOperand(0)->getType(); + return TargetTTI->getCmpSelInstrCost(Opcode, ValTy, U->getType(), + CostKind, I); + } } // By default, just classify everything as 'basic'. return TTI::TCC_Basic; diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 8d83306..361d981 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -838,6 +838,10 @@ public: int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // Selects on vectors are actually vector selects. if (ISD == ISD::SELECT) { assert(CondTy && "CondTy must exist"); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 4bcadfe..ed5db33 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1296,18 +1296,10 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { Op1VK, Op2VK, Op1VP, Op2VP, Operands, I); } - case Instruction::Select: { - const SelectInst *SI = cast(I); - Type *CondTy = SI->getCondition()->getType(); - return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, - CostKind, I); - } + case Instruction::Select: case Instruction::ICmp: - case Instruction::FCmp: { - Type *ValTy = I->getOperand(0)->getType(); - return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), - CostKind, I); - } + case Instruction::FCmp: + return getUserCost(I, CostKind); case Instruction::Store: case Instruction::Load: return getUserCost(I, CostKind); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 3030c94..cc5157f 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -620,6 +620,9 @@ int AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); int ISD = TLI->InstructionOpcodeToISD(Opcode); // We don't lower some vector selects well that are wider than the register diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index 0709f7c..6b4899f 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -508,6 +508,10 @@ int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy, int ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + int ISD = TLI->InstructionOpcodeToISD(Opcode); // On NEON a vector select gets lowered to vbsl. if (ST->hasNEON() && ValTy->isVectorTy() && ISD == ISD::SELECT) { diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index e0eaf55..b70386f 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -236,7 +236,7 @@ unsigned HexagonTTIImpl::getInterleavedMemoryOpCost(unsigned Opcode, unsigned HexagonTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { - if (ValTy->isVectorTy()) { + if (ValTy->isVectorTy() && CostKind == TTI::TCK_RecipThroughput) { std::pair LT = TLI.getTypeLegalizationCost(DL, ValTy); if (Opcode == Instruction::FCmp) return LT.first + FloatFactor * getTypeNumElements(ValTy); diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp index 8c77d74..bc88976 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -779,6 +779,9 @@ int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { int Cost = BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return Cost; return vectorCostAdjustment(Cost, Opcode, ValTy, nullptr); } diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp index ef26a91..f7a4c76 100644 --- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -837,6 +837,9 @@ int SystemZTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind); + if (!ValTy->isVectorTy()) { switch (Opcode) { case Instruction::ICmp: { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 3980909..964ac06 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -2051,6 +2051,10 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // Legalize the type. std::pair LT = TLI->getTypeLegalizationCost(DL, ValTy); -- 2.7.4