From e5dfb08fcb8040ebd39b2098ca74bc809cf8e914 Mon Sep 17 00:00:00 2001 From: Matthew Simpson Date: Wed, 27 Apr 2016 15:20:21 +0000 Subject: [PATCH] [TTI] Add hook for vector extract with extension This change adds a new hook for estimating the cost of vector extracts followed by zero- and sign-extensions. The motivating example for this change is the SMOV and UMOV instructions on AArch64. These instructions move data from vector to general purpose registers while performing the corresponding extension (sign-extend for SMOV and zero-extend for UMOV) at the same time. For these operations, TargetTransformInfo can assume the extensions are free and only report the cost of the vector extract. The SLP vectorizer has been updated to make use of the new hook. Differential Revision: http://reviews.llvm.org/D18523 llvm-svn: 267725 --- llvm/include/llvm/Analysis/TargetTransformInfo.h | 11 +++++ .../llvm/Analysis/TargetTransformInfoImpl.h | 5 ++ llvm/include/llvm/CodeGen/BasicTTIImpl.h | 8 ++++ llvm/lib/Analysis/TargetTransformInfo.cpp | 8 ++++ .../Target/AArch64/AArch64TargetTransformInfo.cpp | 55 ++++++++++++++++++++++ .../Target/AArch64/AArch64TargetTransformInfo.h | 3 ++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 7 +-- .../SLPVectorizer/AArch64/gather-reduce.ll | 33 ++++++------- 8 files changed, 108 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index e566e2a..bde7a7b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -476,6 +476,11 @@ public: /// zext, etc. int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const; + /// \return The expected cost of a sign- or zero-extended vector extract. Use + /// -1 to indicate that there is no information about the index value. 
+ int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index = -1) const; + /// \return The expected cost of control-flow related instructions such as /// Phi, Ret, Br. int getCFInstrCost(unsigned Opcode) const; @@ -662,6 +667,8 @@ public: virtual int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index, Type *SubTp) = 0; virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) = 0; + virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, unsigned Index) = 0; virtual int getCFInstrCost(unsigned Opcode) = 0; virtual int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) = 0; @@ -855,6 +862,10 @@ public: int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) override { return Impl.getCastInstrCost(Opcode, Dst, Src); } + int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index) override { + return Impl.getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + } int getCFInstrCost(unsigned Opcode) override { return Impl.getCFInstrCost(Opcode); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 6ddfc02..c5cbf4e 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -293,6 +293,11 @@ public: unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { return 1; } + unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, unsigned Index) { + return 1; + } + unsigned getCFInstrCost(unsigned Opcode) { return 1; } unsigned getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) { diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index b2c7ef6..ded19c5 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -435,6 +435,14 @@ public: llvm_unreachable("Unhandled cast"); } + 
unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                    VectorType *VecTy, unsigned Index) {
+    return static_cast<T *>(this)->getVectorInstrCost(
+               Instruction::ExtractElement, VecTy, Index) +
+           static_cast<T *>(this)->getCastInstrCost(Opcode, Dst,
+                                                    VecTy->getElementType());
+  }
+
   unsigned getCFInstrCost(unsigned Opcode) {
     // Branches are assumed to be predicted.
     return 0;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index c8f116a..c5793ad 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -267,6 +267,14 @@ int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst,
   return Cost;
 }
 
+int TargetTransformInfo::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                                  VectorType *VecTy,
+                                                  unsigned Index) const {
+  int Cost = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const {
   int Cost = TTIImpl->getCFInstrCost(Opcode);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 87f96f8..8e832ff 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -291,6 +291,61 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) {
   return BaseT::getCastInstrCost(Opcode, Dst, Src);
 }
 
+int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                             VectorType *VecTy,
+                                             unsigned Index) {
+
+  // Make sure we were given a valid extend opcode.
+  assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
+         "Invalid opcode");
+
+  // We are extending an element we extract from a vector, so the source type
+  // of the extend is the element type of the vector.
+  auto *Src = VecTy->getElementType();
+
+  // Sign- and zero-extends are for integer types only.
+  assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
+
+  // Get the cost for the extract. We compute the cost (if any) for the extend
+  // below.
+  auto Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
+
+  // Legalize the types.
+  auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
+  auto DstVT = TLI->getValueType(DL, Dst);
+  auto SrcVT = TLI->getValueType(DL, Src);
+
+  // If the resulting type is still a vector and the destination type is legal,
+  // we may get the extension for free. If not, get the default cost for the
+  // extend.
+  if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
+    return Cost + getCastInstrCost(Opcode, Dst, Src);
+
+  // The destination type should be larger than the element type. If not, get
+  // the default cost for the extend.
+  if (DstVT.getSizeInBits() < SrcVT.getSizeInBits())
+    return Cost + getCastInstrCost(Opcode, Dst, Src);
+
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Opcode should be either SExt or ZExt");
+
+  // For sign-extends, we only need a smov, which performs the extension
+  // automatically.
+  case Instruction::SExt:
+    return Cost;
+
+  // For zero-extends, the extend is performed automatically by a umov unless
+  // the destination type is i64 and the element type is i8 or i16.
+  case Instruction::ZExt:
+    if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
+      return Cost;
+  }
+
+  // If we are unable to perform the extend for free, get the default cost.
+ return Cost + getCastInstrCost(Opcode, Dst, Src); +} + int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) { assert(Val->isVectorTy() && "This must be a vector type"); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 93a84b7..4f2e831 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -99,6 +99,9 @@ public: int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src); + int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index); + int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); int getArithmeticInstrCost( diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index f1bfbc2..b92a975 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1830,11 +1830,12 @@ int BoUpSLP::getTreeCost() { if (MinBWs.count(ScalarRoot)) { auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); VecTy = VectorType::get(MinTy, BundleWidth); + ExtractCost += TTI->getExtractWithExtendCost( + Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane); + } else { ExtractCost += - TTI->getCastInstrCost(Instruction::SExt, EU.Scalar->getType(), MinTy); + TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); } - ExtractCost += - TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); } int SpillCost = getSpillCost(); diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll index 9c06b24..d74e26e 100644 --- a/llvm/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll +++ b/llvm/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll @@ -1,5 +1,5 @@ -; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s 
--check-prefix=PROFITABLE -; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE +; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=GENERIC +; RUN: opt -S -mcpu=kryo -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=KRYO target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" target triple = "aarch64--linux-gnu" @@ -19,13 +19,13 @@ target triple = "aarch64--linux-gnu" ; return sum; ; } -; PROFITABLE-LABEL: @gather_reduce_8x16_i32 +; GENERIC-LABEL: @gather_reduce_8x16_i32 ; -; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32> -; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; PROFITABLE: sext i32 [[X]] to i64 +; GENERIC: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; GENERIC: zext <8 x i16> [[L]] to <8 x i32> +; GENERIC: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; GENERIC: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; GENERIC: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: @@ -138,18 +138,13 @@ for.body: br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body } -; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64 +; KRYO-LABEL: @gather_reduce_8x16_i64 ; -; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32> -; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; UNPROFITABLE: sext i32 [[X]] to i64 -; -; TODO: Although we can now vectorize this case while converting the i64 -; subtractions to i32, the cost model currently finds vectorization to be -; unprofitable. The cost model is penalizing the sign and zero -; extensions in the vectorized version, but they are actually free. 
+; KRYO: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; KRYO: zext <8 x i16> [[L]] to <8 x i32> +; KRYO: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; KRYO: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; KRYO: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: -- 2.7.4