From 361a27c155ec8b222e3318488a208c0eb39624c8 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 28 Sep 2022 07:57:08 -0700 Subject: [PATCH] [Hexagon] Recognize idioms for fixed-point vector multiplication Recognize Q.15*Q.15 and Q.31*Q.31, with and without rounding. --- .../Target/Hexagon/HexagonISelLowering.cpp | 14 + .../Target/Hexagon/HexagonVectorCombine.cpp | 669 ++++++++++++++++-- llvm/test/CodeGen/Hexagon/autohvx/qmul.ll | 144 ++++ 3 files changed, 782 insertions(+), 45 deletions(-) create mode 100644 llvm/test/CodeGen/Hexagon/autohvx/qmul.ll diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp index 157026c1f2fd..8723ac678bdc 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -1808,6 +1808,7 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FMUL, MVT::f64, Legal); } + setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::VSELECT); if (Subtarget.useHVXOps()) @@ -3400,6 +3401,19 @@ HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) return VSel; } } + } else if (Opc == ISD::TRUNCATE) { + SDValue Op0 = Op.getOperand(0); + // fold (truncate (build pair x, y)) -> (truncate x) or x + if (Op0.getOpcode() == ISD::BUILD_PAIR) { + MVT TruncTy = ty(Op); + SDValue Elem0 = Op0.getOperand(0); + // if we match the low element of the pair, just return it. + if (ty(Elem0) == TruncTy) + return Elem0; + // otherwise, if the low part is still too large, apply the truncate. + if (ty(Elem0).bitsGT(TruncTy)) + return DCI.DAG.getNode(ISD::TRUNCATE, dl, TruncTy, Elem0); + } } return SDValue(); diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp index 7169154e4b93..4e16d4f8b7ae 100644 --- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/InstSimplifyFolder.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -31,12 +32,14 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsHexagon.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Utils/Local.h" #include "HexagonSubtarget.h" #include "HexagonTargetMachine.h" @@ -92,10 +95,14 @@ public: int getSizeOf(const Value *Val, SizeKind Kind = Store) const; int getSizeOf(const Type *Ty, SizeKind Kind = Store) const; int getTypeAlignment(Type *Ty) const; + size_t length(Value *Val) const; size_t length(Type *Ty) const; Constant *getNullValue(Type *Ty) const; Constant *getFullValue(Type *Ty) const; + Constant *getConstSplat(Type *Ty, int Val) const; + + Value *simplify(Value *Val) const; Value *insertb(IRBuilderBase &Builder, Value *Dest, Value *Src, int Start, int Length, int Where) const; @@ -110,12 +117,27 @@ public: Type *ToTy) const; Value *vlsb(IRBuilderBase &Builder, Value *Val) const; Value *vbytes(IRBuilderBase &Builder, Value *Val) const; + Value *subvector(IRBuilderBase &Builder, Value *Val, unsigned Start, + unsigned Length) const; + Value *sublo(IRBuilderBase &Builder, Value *Val) const; + Value *subhi(IRBuilderBase &Builder, Value *Val) const; + Value *vdeal(IRBuilderBase &Builder, Value *Val0, Value *Val1) const; + Value *vshuff(IRBuilderBase &Builder, Value *Val0, Value *Val1) const; Value *createHvxIntrinsic(IRBuilderBase &Builder, Intrinsic::ID IntID, Type *RetTy, ArrayRef Args) const; + SmallVector splitVectorElements(IRBuilderBase &Builder, Value *Vec, + unsigned ToWidth) const; + Value *joinVectorElements(IRBuilderBase &Builder, ArrayRef Values, + VectorType *ToType) const; std::optional calculatePointerDifference(Value *Ptr0, Value *Ptr1) const; + unsigned getNumSignificantBits(const Value *V, + const Instruction *CtxI = nullptr) const; + KnownBits getKnownBits(const Value *V, + const Instruction *CtxI = nullptr) const; + template > bool isSafeToMoveBeforeInBB(const Instruction &In, BasicBlock::const_iterator To, @@ -296,6 +318,62 @@ raw_ostream &operator<<(raw_ostream &OS, const AlignVectors::ByteSpan &BS) { return OS; } +class HvxIdioms { +public: + HvxIdioms(const HexagonVectorCombine &HVC_) : HVC(HVC_) { + auto *Int32Ty = HVC.getIntTy(32); + HvxI32Ty = HVC.getHvxTy(Int32Ty, /*Pair=*/false); + HvxP32Ty = HVC.getHvxTy(Int32Ty, /*Pair=*/true); + } + + bool run(); + +private: + struct FxpOp { + unsigned Opcode; + unsigned Frac; // Number of fraction bits + Value *X, *Y; + // If present, add 1 << RoundAt before shift: + std::optional RoundAt; + }; + + // Value + sign + // This is to distinguish multiplications: s*s, s*u, u*s, u*u. + struct SValue { + Value *Val; + bool Signed; + }; + + std::optional matchFxpMul(Instruction &In) const; + Value *processFxpMul(Instruction &In, const FxpOp &Op) const; + Value *createMulQ15(IRBuilderBase &Builder, Value *X, Value *Y, + bool Rounding) const; + Value *createMulQ31(IRBuilderBase &Builder, Value *X, Value *Y, + bool Rounding) const; + std::pair createMul32(IRBuilderBase &Builder, SValue X, + SValue Y) const; + + VectorType *HvxI32Ty; + VectorType *HvxP32Ty; + const HexagonVectorCombine &HVC; + + friend raw_ostream &operator<<(raw_ostream &, const FxpOp &); +}; + +[[maybe_unused]] raw_ostream &operator<<(raw_ostream &OS, + const HvxIdioms::FxpOp &Op) { + OS << Instruction::getOpcodeName(Op.Opcode) << '.' << Op.Frac; + if (Op.RoundAt.has_value()) { + if (Op.Frac != 0 && Op.RoundAt.value() == Op.Frac - 1) { + OS << ":rnd"; + } else { + OS << " + 1<<" << Op.RoundAt.value(); + } + } + OS << "\n X:" << *Op.X << "\n Y:" << *Op.Y; + return OS; +} + } // namespace namespace { @@ -945,11 +1023,342 @@ auto AlignVectors::run() -> bool { // --- End AlignVectors +// --- Begin HvxIdioms + +// Match +// (X * Y) [>> N], or +// ((X * Y) + (1 << N-1)) >> N +auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional { + using namespace PatternMatch; + auto *Ty = In.getType(); + + if (!Ty->isVectorTy() || !Ty->getScalarType()->isIntegerTy()) + return std::nullopt; + + unsigned Width = cast(Ty->getScalarType())->getBitWidth(); + + FxpOp Op; + Value *Exp = &In; + + // Fixed-point multiplication is always shifted right (except when the + // fraction is 0 bits). + const APInt *Qn = nullptr; + if (Value * T; match(Exp, m_LShr(m_Value(T), m_APInt(Qn)))) { + Op.Frac = Qn->getZExtValue(); + Exp = T; + } else { + Op.Frac = 0; + } + + if (Op.Frac > Width) + return std::nullopt; + + // Check if there is rounding added. + const APInt *C = nullptr; + if (Value * T; Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_APInt(C)))) { + unsigned CV = C->getZExtValue(); + if (CV != 0 && !isPowerOf2_32(CV)) + return std::nullopt; + if (CV != 0) + Op.RoundAt = Log2_32(CV); + Exp = T; + } + + // Check if the rest is a multiplication. + if (match(Exp, m_Mul(m_Value(Op.X), m_Value(Op.Y)))) { + Op.Opcode = Instruction::Mul; + return Op; + } + + return std::nullopt; +} + +auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const + -> Value * { + // TODO: Make it general. + if (Op.Frac != 15 && Op.Frac != 31) + return nullptr; + + auto *OrigTy = dyn_cast(Op.X->getType()); + if (OrigTy == nullptr) + return nullptr; + + unsigned BitsX = HVC.getNumSignificantBits(Op.X, &In); + unsigned BitsY = HVC.getNumSignificantBits(Op.Y, &In); + + unsigned SigBits = std::max(BitsX, BitsY); + unsigned Width = PowerOf2Ceil(SigBits); + auto *TruncTy = VectorType::get(HVC.getIntTy(Width), OrigTy); + + IRBuilder Builder(In.getParent(), In.getIterator(), + InstSimplifyFolder(HVC.DL)); + // These may end up dead, but should be removed in isel. + Value *NewX = Builder.CreateTrunc(Op.X, TruncTy); + Value *NewY = Builder.CreateTrunc(Op.Y, TruncTy); + + if (!Op.RoundAt || *Op.RoundAt == Op.Frac - 1) { + bool Rounding = Op.RoundAt.has_value(); + if (Width == Op.Frac + 1) { + Value *QMul = nullptr; + if (Width == 16) { + QMul = createMulQ15(Builder, NewX, NewY, Rounding); + } else if (Width == 32) { + QMul = createMulQ31(Builder, NewX, NewY, Rounding); + } + if (QMul != nullptr) + return Builder.CreateSExt(QMul, OrigTy); + } + } + + // FIXME: make it general, _64, addcarry + if (!HVC.HST.useHVXV62Ops()) + return nullptr; + + // The check for Frac will make sure of this, but keep this check for when + // this function handles all Frac cases. + assert(Width > 32); + if (Width > 64) + return nullptr; + + // At this point, NewX and NewY may be truncated to different element + // widths to save on the number of multiplications to perform. + unsigned WidthX = PowerOf2Ceil(BitsX); + unsigned WidthY = PowerOf2Ceil(BitsY); + NewX = Builder.CreateTrunc( + NewX, VectorType::get(HVC.getIntTy(WidthX), HVC.length(NewX), false)); + NewY = Builder.CreateTrunc( + NewY, VectorType::get(HVC.getIntTy(WidthY), HVC.length(NewY), false)); + + // Break up the arguments NewX and NewY into vectors of smaller widths + // in preparation of doing the multiplication via HVX intrinsics. + // TODO: + // Make sure that the number of elements in NewX/NewY is 32. In the future + // add generic code that will break up a (presumable long) vector into + // shorter pieces, pad the last one, then concatenate all the pieces back. + if (HVC.length(NewX) != 32) + return nullptr; + auto WordX = HVC.splitVectorElements(Builder, NewX, /*ToWidth=*/32); + auto WordY = HVC.splitVectorElements(Builder, NewY, /*ToWidth=*/32); + auto HvxWordTy = WordX[0]->getType(); + + SmallVector> Products(WordX.size() + WordY.size()); + + // WordX[i] * WordY[j] produces words i+j and i+j+1 of the results, + // that is halves 2(i+j), 2(i+j)+1, 2(i+j)+2, 2(i+j)+3. + for (int i = 0, e = WordX.size(); i != e; ++i) { + for (int j = 0, f = WordY.size(); j != f; ++j) { + bool SgnX = (i + 1 == e), SgnY = (j + 1 == f); + auto [Lo, Hi] = createMul32(Builder, {WordX[i], SgnX}, {WordY[j], SgnY}); + Products[i + j + 0].push_back(Lo); + Products[i + j + 1].push_back(Hi); + } + } + + // Add the optional rounding to the proper word. + if (Op.RoundAt.has_value()) { + Products[*Op.RoundAt / 32].push_back( + HVC.getConstSplat(HvxWordTy, 1 << (*Op.RoundAt % 32))); + } + + auto V6_vaddcarry = HVC.HST.getIntrinsicId(Hexagon::V6_vaddcarry); + Value *NoCarry = HVC.getNullValue(HVC.getBoolTy(HVC.length(HvxWordTy))); + auto pop_back_or_zero = [this, HvxWordTy](auto &Vector) -> Value * { + if (Vector.empty()) + return HVC.getNullValue(HvxWordTy); + auto Last = Vector.back(); + Vector.pop_back(); + return Last; + }; + + for (int i = 0, e = Products.size(); i != e; ++i) { + while (Products[i].size() > 1) { + Value *Carry = NoCarry; + for (int j = i; j != e; ++j) { + auto &ProdJ = Products[j]; + Value *Ret = HVC.createHvxIntrinsic( + Builder, V6_vaddcarry, nullptr, + {pop_back_or_zero(ProdJ), pop_back_or_zero(ProdJ), Carry}); + ProdJ.insert(ProdJ.begin(), Builder.CreateExtractValue(Ret, {0})); + Carry = Builder.CreateExtractValue(Ret, {1}); + } + } + } + + SmallVector WordP; + for (auto &P : Products) { + assert(P.size() == 1 && "Should have been added together"); + WordP.push_back(P.front()); + } + + // Shift all products right by Op.Frac. + unsigned SkipWords = Op.Frac / 32; + Constant *ShiftAmt = HVC.getConstSplat(HvxWordTy, Op.Frac % 32); + + for (int Dst = 0, End = WordP.size() - SkipWords; Dst != End; ++Dst) { + int Src = Dst + SkipWords; + Value *Lo = WordP[Src]; + if (Src + 1 < End) { + Value *Hi = WordP[Src + 1]; + WordP[Dst] = Builder.CreateIntrinsic(HvxWordTy, Intrinsic::fshr, + {Hi, Lo, ShiftAmt}); + } else { + // The shift of the most significant word. + WordP[Dst] = Builder.CreateAShr(Lo, ShiftAmt); + } + } + if (SkipWords != 0) + WordP.resize(WordP.size() - SkipWords); + + return HVC.joinVectorElements(Builder, WordP, OrigTy); +} + +auto HvxIdioms::createMulQ15(IRBuilderBase &Builder, Value *X, Value *Y, + bool Rounding) const -> Value * { + assert(X->getType() == Y->getType()); + assert(X->getType()->getScalarType() == HVC.getIntTy(16)); + if (!HVC.HST.isHVXVectorType(EVT::getEVT(X->getType(), false))) + return nullptr; + + unsigned HwLen = HVC.HST.getVectorLength(); + + if (Rounding) { + auto V6_vmpyhvsrs = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhvsrs); + return HVC.createHvxIntrinsic(Builder, V6_vmpyhvsrs, X->getType(), {X, Y}); + } + // No rounding, do i16*i16 -> i32, << 1, take upper half. + auto V6_vmpyhv = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhv); + + // i16*i16 -> i32 / interleaved + Value *V1 = HVC.createHvxIntrinsic(Builder, V6_vmpyhv, HvxP32Ty, {X, Y}); + // <<1 + Value *V2 = Builder.CreateAdd(V1, V1); + // i32 -> i32 deinterleave + SmallVector DeintMask; + for (int i = 0; i != static_cast(HwLen) / 4; ++i) { + DeintMask.push_back(i); + DeintMask.push_back(i + HwLen / 4); + } + + Value *V3 = + HVC.vdeal(Builder, HVC.sublo(Builder, V2), HVC.subhi(Builder, V2)); + // High halves: i32 -> i16 + SmallVector HighMask; + for (int i = 0; i != static_cast(HwLen) / 2; ++i) { + HighMask.push_back(2 * i + 1); + } + auto *HvxP16Ty = HVC.getHvxTy(HVC.getIntTy(16), /*Pair=*/true); + Value *V4 = Builder.CreateBitCast(V3, HvxP16Ty); + return Builder.CreateShuffleVector(V4, HighMask); +} + +auto HvxIdioms::createMulQ31(IRBuilderBase &Builder, Value *X, Value *Y, + bool Rounding) const -> Value * { + assert(X->getType() == Y->getType()); + assert(X->getType()->getScalarType() == HVC.getIntTy(32)); + if (!HVC.HST.isHVXVectorType(EVT::getEVT(X->getType(), false))) + return nullptr; + + auto V6_vmpyewuh = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyewuh); + auto MpyOddAcc = Rounding + ? HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_rnd_sacc) + : HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_sacc); + Value *V1 = + HVC.createHvxIntrinsic(Builder, V6_vmpyewuh, X->getType(), {X, Y}); + return HVC.createHvxIntrinsic(Builder, MpyOddAcc, X->getType(), {V1, X, Y}); +} + +auto HvxIdioms::createMul32(IRBuilderBase &Builder, SValue X, SValue Y) const + -> std::pair { + assert(X.Val->getType() == Y.Val->getType()); + assert(X.Val->getType() == HVC.getHvxTy(HVC.getIntTy(32), /*Pair=*/false)); + + assert(HVC.HST.useHVXV62Ops()); + + auto simplifyOrSame = [this](Value *V) { + if (Value *S = HVC.simplify(V)) + return S; + return V; + }; + Value *VX = simplifyOrSame(X.Val); + Value *VY = simplifyOrSame(Y.Val); + + if (isa(VX) || isa(VY)) { + auto getSplatValue = [](Constant *CV) -> ConstantInt * { + if (auto T = dyn_cast(CV)) + return dyn_cast(T->getSplatValue()); + if (auto T = dyn_cast(CV)) + return dyn_cast(T->getSplatValue()); + return nullptr; + }; + + if (isa(VX) && isa(VY)) { + // Both are constants, fold the multiplication. + auto *Ty = cast(VX->getType()); + auto *ExtTy = VectorType::getExtendedElementVectorType(Ty); + Value *EX = X.Signed ? Builder.CreateSExt(VX, ExtTy) + : Builder.CreateZExt(VX, ExtTy); + Value *EY = Y.Signed ? Builder.CreateSExt(VY, ExtTy) + : Builder.CreateZExt(VY, ExtTy); + Value *EXY = simplifyOrSame(Builder.CreateMul(EX, EY)); + auto WordXY = HVC.splitVectorElements(Builder, EXY, /*ToWidth=*/32); + return {simplifyOrSame(WordXY[0]), simplifyOrSame(WordXY[1])}; + } + // Make VX = constant. + if (isa(VY)) + std::swap(VX, VY); + + if (auto *SplatX = getSplatValue(cast(VX))) { + APInt S = SplatX->getValue(); + if (S == 1) { + if (!X.Signed && !Y.Signed) + return {VY, HVC.getConstSplat(HvxI32Ty, 0)}; + return {VY, Builder.CreateAShr(VY, HVC.getConstSplat(HvxI32Ty, 31))}; + } + } + } + + auto V6_vmpyewuh_64 = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyewuh_64); + auto V6_vmpyowh_64_acc = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_64_acc); + + Value *Vxx = + HVC.createHvxIntrinsic(Builder, V6_vmpyewuh_64, HvxP32Ty, {X.Val, Y.Val}); + Value *Vdd = HVC.createHvxIntrinsic(Builder, V6_vmpyowh_64_acc, HvxP32Ty, + {Vxx, X.Val, Y.Val}); + + return {HVC.sublo(Builder, Vdd), HVC.subhi(Builder, Vdd)}; +} + +auto HvxIdioms::run() -> bool { + bool Changed = false; + + for (BasicBlock &B : HVC.F) { + for (auto It = B.rbegin(); It != B.rend(); ++It) { + if (auto Fxm = matchFxpMul(*It)) { + Value *New = processFxpMul(*It, *Fxm); + if (!New) + continue; + bool StartOver = !isa(New); + It->replaceAllUsesWith(New); + RecursivelyDeleteTriviallyDeadInstructions(&*It, &HVC.TLI); + It = StartOver ? B.rbegin() + : cast(New)->getReverseIterator(); + } + } + } + + return Changed; +} + +// --- End HvxIdioms + auto HexagonVectorCombine::run() -> bool { if (!HST.useHVXOps()) return false; - bool Changed = AlignVectors(*this).run(); + bool Changed = false; + Changed |= AlignVectors(*this).run(); + Changed |= HvxIdioms(*this).run(); + return Changed; } @@ -1032,6 +1441,10 @@ auto HexagonVectorCombine::getTypeAlignment(Type *Ty) const -> int { return DL.getABITypeAlign(Ty).value(); } +auto HexagonVectorCombine::length(Value *Val) const -> size_t { + return length(Val->getType()); +} + auto HexagonVectorCombine::length(Type *Ty) const -> size_t { auto *VecTy = dyn_cast(Ty); assert(VecTy && "Must be a vector type"); @@ -1054,6 +1467,25 @@ auto HexagonVectorCombine::getFullValue(Type *Ty) const -> Constant * { return Minus1; } +auto HexagonVectorCombine::getConstSplat(Type *Ty, int Val) const + -> Constant * { + assert(Ty->isVectorTy()); + auto VecTy = cast(Ty); + Type *ElemTy = VecTy->getElementType(); + // Add support for floats if needed. + auto *Splat = ConstantVector::getSplat(VecTy->getElementCount(), + ConstantInt::get(ElemTy, Val)); + return Splat; +} + +auto HexagonVectorCombine::simplify(Value *V) const -> Value * { + if (auto *In = dyn_cast(V)) { + SimplifyQuery Q(DL, &TLI, &DT, &AC, In); + return simplifyInstruction(In, Q); + } + return nullptr; +} + // Insert bytes [Start..Start+Length) of Src into Dst at byte Where. auto HexagonVectorCombine::insertb(IRBuilderBase &Builder, Value *Dst, Value *Src, int Start, int Length, @@ -1263,67 +1695,201 @@ auto HexagonVectorCombine::vbytes(IRBuilderBase &Builder, Value *Val) const return Builder.CreateSExt(Val, getByteTy()); } +auto HexagonVectorCombine::subvector(IRBuilderBase &Builder, Value *Val, + unsigned Start, unsigned Length) const + -> Value * { + assert(Start + Length <= length(Val)); + return getElementRange(Builder, Val, /*Ignored*/ Val, Start, Length); +} + +auto HexagonVectorCombine::sublo(IRBuilderBase &Builder, Value *Val) const + -> Value * { + size_t Len = length(Val); + assert(Len % 2 == 0 && "Length should be even"); + return subvector(Builder, Val, 0, Len / 2); +} + +auto HexagonVectorCombine::subhi(IRBuilderBase &Builder, Value *Val) const + -> Value * { + size_t Len = length(Val); + assert(Len % 2 == 0 && "Length should be even"); + return subvector(Builder, Val, Len / 2, Len / 2); +} + +auto HexagonVectorCombine::vdeal(IRBuilderBase &Builder, Value *Val0, + Value *Val1) const -> Value * { + assert(Val0->getType() == Val1->getType()); + int Len = length(Val0); + SmallVector Mask(2 * Len); + + for (int i = 0; i != Len; ++i) { + Mask[i] = 2 * i; // Even + Mask[i + Len] = 2 * i + 1; // Odd + } + return Builder.CreateShuffleVector(Val0, Val1, Mask); +} + +auto HexagonVectorCombine::vshuff(IRBuilderBase &Builder, Value *Val0, + Value *Val1) const -> Value * { // + assert(Val0->getType() == Val1->getType()); + int Len = length(Val0); + SmallVector Mask(2 * Len); + + for (int i = 0; i != Len; ++i) { + Mask[2 * i + 0] = i; // Val0 + Mask[2 * i + 1] = i + Len; // Val1 + } + return Builder.CreateShuffleVector(Val0, Val1, Mask); +} + auto HexagonVectorCombine::createHvxIntrinsic(IRBuilderBase &Builder, Intrinsic::ID IntID, Type *RetTy, ArrayRef Args) const -> Value * { - int HwLen = HST.getVectorLength(); - Type *BoolTy = Type::getInt1Ty(F.getContext()); - Type *Int32Ty = Type::getInt32Ty(F.getContext()); - // HVX vector -> v16i32/v32i32 - // HVX vector predicate -> v512i1/v1024i1 - auto getTypeForIntrin = [&](Type *Ty) -> Type * { - if (HST.isTypeForHVX(Ty, /*IncludeBool=*/true)) { - Type *ElemTy = cast(Ty)->getElementType(); - if (ElemTy == Int32Ty) - return Ty; - if (ElemTy == BoolTy) - return VectorType::get(BoolTy, 8 * HwLen, /*Scalable*/ false); - return getHvxTy(Int32Ty); - } - // Non-HVX type. It should be a scalar. - assert(Ty == Int32Ty || Ty->isIntegerTy(64)); - return Ty; - }; - auto getCast = [&](IRBuilderBase &Builder, Value *Val, Type *DestTy) -> Value * { Type *SrcTy = Val->getType(); if (SrcTy == DestTy) return Val; - if (HST.isTypeForHVX(SrcTy, /*IncludeBool=*/true)) { - if (cast(SrcTy)->getElementType() == BoolTy) { - // This should take care of casts the other way too, for example - // v1024i1 -> v32i1. - Intrinsic::ID TC = HwLen == 64 - ? Intrinsic::hexagon_V6_pred_typecast - : Intrinsic::hexagon_V6_pred_typecast_128B; - Function *FI = Intrinsic::getDeclaration(F.getParent(), TC, - {DestTy, Val->getType()}); - return Builder.CreateCall(FI, {Val}); - } - // Non-predicate HVX vector. - return Builder.CreateBitCast(Val, DestTy); - } + // Non-HVX type. It should be a scalar, and it should already have // a valid type. - llvm_unreachable("Unexpected type"); + assert(HST.isTypeForHVX(SrcTy, /*IncludeBool=*/true)); + + Type *BoolTy = Type::getInt1Ty(F.getContext()); + if (cast(SrcTy)->getElementType() != BoolTy) + return Builder.CreateBitCast(Val, DestTy); + + // Predicate HVX vector. + unsigned HwLen = HST.getVectorLength(); + Intrinsic::ID TC = HwLen == 64 ? Intrinsic::hexagon_V6_pred_typecast + : Intrinsic::hexagon_V6_pred_typecast_128B; + Function *FI = + Intrinsic::getDeclaration(F.getParent(), TC, {DestTy, Val->getType()}); + return Builder.CreateCall(FI, {Val}); }; - SmallVector IntOps; - for (Value *A : Args) - IntOps.push_back(getCast(Builder, A, getTypeForIntrin(A->getType()))); - Function *FI = Intrinsic::getDeclaration(F.getParent(), IntID); - Value *Call = Builder.CreateCall(FI, IntOps); + Function *IntrFn = Intrinsic::getDeclaration(F.getParent(), IntID); + FunctionType *IntrTy = IntrFn->getFunctionType(); + + SmallVector IntrArgs; + for (int i = 0, e = Args.size(); i != e; ++i) { + Value *A = Args[i]; + Type *T = IntrTy->getParamType(i); + if (A->getType() != T) { + IntrArgs.push_back(getCast(Builder, A, T)); + } else { + IntrArgs.push_back(A); + } + } + Value *Call = Builder.CreateCall(IntrFn, IntrArgs); Type *CallTy = Call->getType(); - if (CallTy == RetTy) + if (RetTy == nullptr || CallTy == RetTy) return Call; // Scalar types should have RetTy matching the call return type. assert(HST.isTypeForHVX(CallTy, /*IncludeBool=*/true)); - if (cast(CallTy)->getElementType() == BoolTy) - return getCast(Builder, Call, RetTy); - return Builder.CreateBitCast(Call, RetTy); + return getCast(Builder, Call, RetTy); +} + +auto HexagonVectorCombine::splitVectorElements(IRBuilderBase &Builder, + Value *Vec, + unsigned ToWidth) const + -> SmallVector { + // Break a vector of wide elements into a series of vectors with narrow + // elements: + // (...c0:b0:a0, ...c1:b1:a1, ...c2:b2:a2, ...) + // --> + // (a0, a1, a2, ...) // lowest "ToWidth" bits + // (b0, b1, b2, ...) // the next lowest... + // (c0, c1, c2, ...) // ... + // ... + // + // The number of elements in each resulting vector is the same as + // in the original vector. + + auto *VecTy = cast(Vec->getType()); + assert(VecTy->getElementType()->isIntegerTy()); + unsigned FromWidth = VecTy->getScalarSizeInBits(); + assert(isPowerOf2_32(ToWidth) && isPowerOf2_32(FromWidth)); + + assert(ToWidth <= FromWidth && "Breaking up into wider elements?"); + unsigned NumResults = FromWidth / ToWidth; + + SmallVector Results(NumResults); + Results[0] = Vec; + unsigned Length = length(VecTy); + + // Do it by splitting in half, since those operations correspond to deal + // instructions. + auto splitInHalf = [&](unsigned Begin, unsigned End, auto splitFunc) -> void { + // Take V = Results[Begin], split it in L, H. + // Store Results[Begin] = L, Results[(Begin+End)/2] = H + // Call itself recursively split(Begin, Half), split(Half+1, End) + if (Begin + 1 == End) + return; + + Value *Val = Results[Begin]; + unsigned Width = Val->getType()->getScalarSizeInBits(); + + auto *VTy = VectorType::get(getIntTy(Width / 2), 2 * Length, false); + Value *VVal = Builder.CreateBitCast(Val, VTy); + + Value *Res = vdeal(Builder, sublo(Builder, VVal), subhi(Builder, VVal)); + + unsigned Half = (Begin + End) / 2; + Results[Begin] = sublo(Builder, Res); + Results[Half] = subhi(Builder, Res); + + splitFunc(Begin, Half, splitFunc); + splitFunc(Half, End, splitFunc); + }; + + splitInHalf(0, NumResults, splitInHalf); + return Results; +} + +auto HexagonVectorCombine::joinVectorElements(IRBuilderBase &Builder, + ArrayRef Values, + VectorType *ToType) const + -> Value * { + assert(ToType->getElementType()->isIntegerTy()); + + // If the list of values does not have power-of-2 elements, append copies + // of the sign bit to it, to make the size be 2^n. + // The reason for this is that the values will be joined in pairs, because + // otherwise the shuffles will result in convoluted code. With pairwise + // joins, the shuffles will hopefully be folded into a perfect shuffle. + // The output will need to be sign-extended to a type with element width + // being a power-of-2 anyways. + SmallVector Inputs(Values.begin(), Values.end()); + + unsigned ToWidth = ToType->getScalarSizeInBits(); + unsigned Width = Inputs.front()->getType()->getScalarSizeInBits(); + assert(Width <= ToWidth); + assert(isPowerOf2_32(Width) && isPowerOf2_32(ToWidth)); + unsigned Length = length(Inputs.front()->getType()); + + unsigned NeedInputs = ToWidth / Width; + if (Inputs.size() != NeedInputs) { + Value *Last = Inputs.back(); + Value *Sign = + Builder.CreateAShr(Last, getConstSplat(Last->getType(), Width - 1)); + Inputs.resize(NeedInputs, Sign); + } + + while (Inputs.size() > 1) { + Width *= 2; + auto *VTy = VectorType::get(getIntTy(Width), Length, false); + for (int i = 0, e = Inputs.size(); i < e; i += 2) { + Value *Res = vshuff(Builder, Inputs[i], Inputs[i + 1]); + Inputs[i / 2] = Builder.CreateBitCast(Res, VTy); + } + Inputs.resize(Inputs.size() / 2); + } + + assert(Inputs.front()->getType() == ToType); + return Inputs.front(); } auto HexagonVectorCombine::calculatePointerDifference(Value *Ptr0, @@ -1419,6 +1985,19 @@ auto HexagonVectorCombine::calculatePointerDifference(Value *Ptr0, #undef CallBuilder } +auto HexagonVectorCombine::getNumSignificantBits(const Value *V, + const Instruction *CtxI) const + -> unsigned { + return ComputeMaxSignificantBits(V, DL, /*Depth=*/0, &AC, CtxI, &DT); +} + +auto HexagonVectorCombine::getKnownBits(const Value *V, + const Instruction *CtxI) const + -> KnownBits { + return computeKnownBits(V, DL, /*Depth=*/0, &AC, CtxI, &DT, /*ORE=*/nullptr, + /*UseInstrInfo=*/true); +} + template auto HexagonVectorCombine::isSafeToMoveBeforeInBB(const Instruction &In, BasicBlock::const_iterator To, @@ -1494,7 +2073,7 @@ auto HexagonVectorCombine::isByteVecTy(Type *Ty) const -> bool { auto HexagonVectorCombine::getElementRange(IRBuilderBase &Builder, Value *Lo, Value *Hi, int Start, int Length) const -> Value * { - assert(0 <= Start && Start < Length); + assert(0 <= Start && size_t(Start + Length) < length(Lo) + length(Hi)); SmallVector SMask(Length); std::iota(SMask.begin(), SMask.end(), Start); return Builder.CreateShuffleVector(Lo, Hi, SMask); diff --git a/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll b/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll new file mode 100644 index 000000000000..00940adbe51d --- /dev/null +++ b/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll @@ -0,0 +1,144 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -march=hexagon < %s | FileCheck %s + +define void @f0(ptr %a0, ptr %a1, ptr %a2) #0 { +; CHECK-LABEL: f0: +; CHECK: // %bb.0: // %b0 +; CHECK-NEXT: { +; CHECK-NEXT: v0 = vmem(r0+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1 = vmem(r1+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v2.w = vmpye(v0.w,v1.uh) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v2.w += vmpyo(v0.w,v1.h):<<1:sat:shift +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: vmem(r2+#0) = v2 +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: jumpr r31 +; CHECK-NEXT: } +b0: + %v0 = load <32 x i32>, ptr %a0, align 128 + %v1 = load <32 x i32>, ptr %a1, align 128 + %v2 = sext <32 x i32> %v0 to <32 x i64> + %v3 = sext <32 x i32> %v1 to <32 x i64> + %v4 = mul nsw <32 x i64> %v2, %v3 + %v5 = lshr <32 x i64> %v4, + %v6 = trunc <32 x i64> %v5 to <32 x i32> + store <32 x i32> %v6, ptr %a2, align 128 + ret void +} + +define void @f1(ptr %a0, ptr %a1, ptr %a2) #0 { +; CHECK-LABEL: f1: +; CHECK: // %bb.0: // %b0 +; CHECK-NEXT: { +; CHECK-NEXT: v0 = vmem(r0+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1 = vmem(r1+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v2.w = vmpye(v0.w,v1.uh) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v2.w += vmpyo(v0.w,v1.h):<<1:rnd:sat:shift +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: vmem(r2+#0) = v2 +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: jumpr r31 +; CHECK-NEXT: } +b0: + %v0 = load <32 x i32>, ptr %a0, align 128 + %v1 = load <32 x i32>, ptr %a1, align 128 + %v2 = sext <32 x i32> %v0 to <32 x i64> + %v3 = sext <32 x i32> %v1 to <32 x i64> + %v4 = mul nsw <32 x i64> %v2, %v3 + %v5 = add nsw <32 x i64> %v4, + %v6 = lshr <32 x i64> %v5, + %v7 = trunc <32 x i64> %v6 to <32 x i32> + store <32 x i32> %v7, ptr %a2, align 128 + ret void +} + +define void @f2(ptr %a0, ptr %a1, ptr %a2) #0 { +; CHECK-LABEL: f2: +; CHECK: // %bb.0: // %b0 +; CHECK-NEXT: { +; CHECK-NEXT: v0 = vmem(r0+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: r7 = #-4 +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1 = vmem(r1+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1:0.w = vmpy(v0.h,v1.h) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1:0.w = vadd(v1:0.w,v1:0.w) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1:0 = vshuff(v1,v0,r7) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v0.h = vpacko(v1.w,v0.w) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: vmem(r2+#0) = v0 +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: jumpr r31 +; CHECK-NEXT: } +b0: + %v0 = load <64 x i16>, ptr %a0, align 128 + %v1 = load <64 x i16>, ptr %a1, align 128 + %v2 = sext <64 x i16> %v0 to <64 x i32> + %v3 = sext <64 x i16> %v1 to <64 x i32> + %v4 = mul nsw <64 x i32> %v2, %v3 + %v5 = lshr <64 x i32> %v4, + %v6 = trunc <64 x i32> %v5 to <64 x i16> + store <64 x i16> %v6, ptr %a2, align 128 + ret void +} + +define void @f3(ptr %a0, ptr %a1, ptr %a2) #0 { +; CHECK-LABEL: f3: +; CHECK: // %bb.0: // %b0 +; CHECK-NEXT: { +; CHECK-NEXT: v0 = vmem(r0+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v1 = vmem(r1+#0) +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: v0.h = vmpy(v0.h,v1.h):<<1:rnd:sat +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: vmem(r2+#0) = v0 +; CHECK-NEXT: } +; CHECK-NEXT: { +; CHECK-NEXT: jumpr r31 +; CHECK-NEXT: } +b0: + %v0 = load <64 x i16>, ptr %a0, align 128 + %v1 = load <64 x i16>, ptr %a1, align 128 + %v2 = sext <64 x i16> %v0 to <64 x i32> + %v3 = sext <64 x i16> %v1 to <64 x i32> + %v4 = mul nsw <64 x i32> %v2, %v3 + %v5 = add nsw <64 x i32> %v4, + %v6 = lshr <64 x i32> %v5, + %v7 = trunc <64 x i32> %v6 to <64 x i16> + store <64 x i16> %v7, ptr %a2, align 128 + ret void +} + +attributes #0 = { nounwind "target-features"="+v68,+hvxv68,+hvx-length128b,-packets" } -- 2.34.1