From fa8bff0cd1ad28b78a8910ebb1be077ef010f91f Mon Sep 17 00:00:00 2001 From: Sam Parker Date: Fri, 5 Jun 2020 08:42:03 +0100 Subject: [PATCH] [CostModel] Unify getArithmeticInstrCost Add the remaining arithmetic opcodes into the generic implementation of getUserCost and then call this from getInstructionThroughput. Most of the backends have been modified to return the base implementation for cost kinds other than RecipThroughput. The outlier here is AMDGPU which already uses getArithmeticInstrCost for all the cost kinds. This change means that most of the opcodes can be removed from that backend's implementation of getUserCost. Differential Revision: https://reviews.llvm.org/D80992 --- .../llvm/Analysis/TargetTransformInfoImpl.h | 45 +++++++++++++++++++--- llvm/include/llvm/CodeGen/BasicTTIImpl.h | 7 ++++ llvm/lib/Analysis/TargetTransformInfo.cpp | 12 +----- .../Target/AArch64/AArch64TargetTransformInfo.cpp | 6 +++ .../Target/AMDGPU/AMDGPUTargetTransformInfo.cpp | 28 +++----------- llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp | 6 +++ .../Target/Hexagon/HexagonTargetTransformInfo.cpp | 6 +++ llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp | 5 +++ .../Target/SystemZ/SystemZTargetTransformInfo.cpp | 6 +++ llvm/lib/Target/X86/X86TargetTransformInfo.cpp | 5 +++ 10 files changed, 88 insertions(+), 38 deletions(-) diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 9408485e..96153be 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -376,6 +376,20 @@ public: TTI::OperandValueProperties Opd2PropInfo, ArrayRef Args, const Instruction *CxtI = nullptr) { + // FIXME: A number of transformation tests seem to require these values + // which seems a little odd for how arbitrary they are. 
+ switch (Opcode) { + default: + break; + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::UDiv: + case Instruction::URem: + // FIXME: Unlikely to be true for CodeSize. + return TTI::TCC_Expensive; + } return 1; } @@ -830,14 +844,33 @@ public: GEP->getPointerOperand(), Operands.drop_front()); } - case Instruction::FDiv: - case Instruction::FRem: - case Instruction::SDiv: - case Instruction::SRem: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: case Instruction::URem: - // FIXME: Unlikely to be true for CodeSize. - return TTI::TCC_Expensive; + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + TargetTransformInfo::OperandValueKind Op1VK, Op2VK; + TargetTransformInfo::OperandValueProperties Op1VP, Op2VP; + Op1VK = TTI::getOperandInfo(U->getOperand(0), Op1VP); + Op2VK = TTI::getOperandInfo(U->getOperand(1), Op2VP); + SmallVector Operands(U->operand_values()); + return TargetTTI->getArithmeticInstrCost(Opcode, Ty, CostKind, + Op1VK, Op2VK, + Op1VP, Op2VP, Operands, I); + } case Instruction::IntToPtr: case Instruction::PtrToInt: case Instruction::SIToFP: diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 361d981..3880bf6 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -608,6 +608,13 @@ public: int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // TODO: Handle more cost kinds. 
+ if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, + Opd1Info, Opd2Info, + Opd1PropInfo, Opd2PropInfo, + Args, CxtI); + std::pair LT = TLI->getTypeLegalizationCost(DL, Ty); bool IsFloat = Ty->isFPOrFPVectorTy(); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index ed5db33..7325597 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1275,16 +1275,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { case Instruction::AShr: case Instruction::And: case Instruction::Or: - case Instruction::Xor: { - TargetTransformInfo::OperandValueKind Op1VK, Op2VK; - TargetTransformInfo::OperandValueProperties Op1VP, Op2VP; - Op1VK = getOperandInfo(I->getOperand(0), Op1VP); - Op2VK = getOperandInfo(I->getOperand(1), Op2VP); - SmallVector Operands(I->operand_values()); - return getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind, - Op1VK, Op2VK, - Op1VP, Op2VP, Operands, I); - } + case Instruction::Xor: + return getUserCost(I, CostKind); case Instruction::FNeg: { TargetTransformInfo::OperandValueKind Op1VK, Op2VK; TargetTransformInfo::OperandValueProperties Op1VP, Op2VP; diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index cc5157f..1ec88b3 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -500,6 +500,12 @@ int AArch64TTIImpl::getArithmeticInstrCost( TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo, TTI::OperandValueProperties Opd2PropInfo, ArrayRef Args, const Instruction *CxtI) { + // TODO: Handle more cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info, + Opd2Info, Opd1PropInfo, + Opd2PropInfo, Args, CxtI); + // Legalize the type. 
std::pair LT = TLI->getTypeLegalizationCost(DL, Ty); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp index 085ba47..8086410 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -437,8 +437,11 @@ int GCNTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty, const Instruction *CxtI) { EVT OrigTy = TLI->getValueType(DL, Ty); if (!OrigTy.isSimple()) { - return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info, - Opd2Info, + // FIXME: We're having to query the throughput cost so that the basic + // implementation tries to generate legalize and scalarization costs. Maybe + // we could hoist the scalarization code here? + return BaseT::getArithmeticInstrCost(Opcode, Ty, TTI::TCK_RecipThroughput, + Opd1Info, Opd2Info, Opd1PropInfo, Opd2PropInfo); } @@ -1036,29 +1039,10 @@ GCNTTIImpl::getUserCost(const User *U, ArrayRef Operands, return getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, 0, nullptr); } - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - case Instruction::FNeg: { + case Instruction::FNeg: return getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind, TTI::OK_AnyValue, TTI::OK_AnyValue, TTI::OP_None, TTI::OP_None, Operands, I); - } default: break; } diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index 6b4899f..fb2c4c6 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -759,6 
+759,12 @@ int ARMTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::OperandValueProperties Opd2PropInfo, ArrayRef Args, const Instruction *CxtI) { + // TODO: Handle more cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, + Op2Info, Opd1PropInfo, + Opd2PropInfo, Args, CxtI); + int ISDOpcode = TLI->InstructionOpcodeToISD(Opcode); std::pair LT = TLI->getTypeLegalizationCost(DL, Ty); diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index b70386f..cc1aaf9 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -250,6 +250,12 @@ unsigned HexagonTTIImpl::getArithmeticInstrCost( TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo, TTI::OperandValueProperties Opd2PropInfo, ArrayRef Args, const Instruction *CxtI) { + // TODO: Handle more cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info, + Opd2Info, Opd1PropInfo, + Opd2PropInfo, Args, CxtI); + if (Ty->isVectorTy()) { std::pair LT = TLI.getTypeLegalizationCost(DL, Ty); if (LT.second.isFloatingPoint()) diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp index bc88976..5ea5b2e 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -740,6 +740,11 @@ int PPCTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty, ArrayRef Args, const Instruction *CxtI) { assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode"); + // TODO: Handle more cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, + Op2Info, Opd1PropInfo, + Opd2PropInfo, Args, CxtI); // Fallback to the default implementation. 
int Cost = BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp index f7a4c76..cebc3fd 100644 --- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -374,6 +374,12 @@ int SystemZTTIImpl::getArithmeticInstrCost( TTI::OperandValueProperties Opd2PropInfo, ArrayRef Args, const Instruction *CxtI) { + // TODO: Handle more cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, + Op2Info, Opd1PropInfo, + Opd2PropInfo, Args, CxtI); + // TODO: return a good value for BB-VECTORIZER that includes the // immediate loads, which we do not want to count for the loop // vectorizer, since they are hopefully hoisted out of the loop. This diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 964ac06..c9ebc5e 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -177,6 +177,11 @@ int X86TTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::OperandValueProperties Opd2PropInfo, ArrayRef Args, const Instruction *CxtI) { + // TODO: Handle more cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, + Op2Info, Opd1PropInfo, + Opd2PropInfo, Args, CxtI); // Legalize the type. std::pair LT = TLI->getTypeLegalizationCost(DL, Ty); -- 2.7.4