From 1ef3ed0eb49d39d6eda84dec7e702aee7f55e9ae Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Sat, 6 Jun 2020 19:37:43 -0400 Subject: [PATCH] GlobalISel: Rewrite getLCMType Try to make the behavior more consistent with getGCDType, and bias towards returning something closer to the source type whenever there's an ambiguity. --- llvm/include/llvm/CodeGen/GlobalISel/Utils.h | 10 ++-- llvm/lib/CodeGen/GlobalISel/Utils.cpp | 64 +++++++++++++------- .../CodeGen/GlobalISel/GISelUtilsTest.cpp | 70 ++++++++++++++++++---- 3 files changed, 105 insertions(+), 39 deletions(-) diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index 066bbd8..8eb95a8 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -190,12 +190,12 @@ inline bool isKnownNeverSNaN(Register Val, const MachineRegisterInfo &MRI) { Align inferAlignFromPtrInfo(MachineFunction &MF, const MachinePointerInfo &MPO); -/// Return the least common multiple type of \p Ty0 and \p Ty1, by changing -/// the number of vector elements or scalar bitwidth. The intent is a -/// G_MERGE_VALUES can be constructed from \p Ty0 elements, and unmerged into -/// \p Ty1. +/// Return the least common multiple type of \p OrigTy and \p TargetTy, by changing the +/// number of vector elements or scalar bitwidth. The intent is a +/// G_MERGE_VALUES, G_BUILD_VECTOR, or G_CONCAT_VECTORS can be constructed from +/// \p OrigTy elements, and unmerged into \p TargetTy LLVM_READNONE -LLT getLCMType(LLT Ty0, LLT Ty1); +LLT getLCMType(LLT OrigTy, LLT TargetTy); /// Return a type where the total size is the greatest common divisor of \p /// OrigTy and \p TargetTy. This will try to either change the number of vector diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index 6e7b334..584a691 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -510,35 +510,55 @@ void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) { AU.addPreserved(); } -LLT llvm::getLCMType(LLT Ty0, LLT Ty1) { - if (!Ty0.isVector() && !Ty1.isVector()) { - unsigned Mul = Ty0.getSizeInBits() * Ty1.getSizeInBits(); - int GCDSize = greatestCommonDivisor(Ty0.getSizeInBits(), - Ty1.getSizeInBits()); - return LLT::scalar(Mul / GCDSize); - } +static unsigned getLCMSize(unsigned OrigSize, unsigned TargetSize) { + unsigned Mul = OrigSize * TargetSize; + unsigned GCDSize = greatestCommonDivisor(OrigSize, TargetSize); + return Mul / GCDSize; +} - if (Ty0.isVector() && !Ty1.isVector()) { - assert(Ty0.getElementType() == Ty1 && "not yet handled"); - return Ty0; - } +LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) { + const unsigned OrigSize = OrigTy.getSizeInBits(); + const unsigned TargetSize = TargetTy.getSizeInBits(); - if (Ty1.isVector() && !Ty0.isVector()) { - assert(Ty1.getElementType() == Ty0 && "not yet handled"); - return Ty1; - } + if (OrigSize == TargetSize) + return OrigTy; + + if (OrigTy.isVector()) { + const LLT OrigElt = OrigTy.getElementType(); + + if (TargetTy.isVector()) { + const LLT TargetElt = TargetTy.getElementType(); - if (Ty0.isVector() && Ty1.isVector()) { - assert(Ty0.getElementType() == Ty1.getElementType() && "not yet handled"); + if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) { + int GCDElts = greatestCommonDivisor(OrigTy.getNumElements(), + TargetTy.getNumElements()); + // Prefer the original element type. + int Mul = OrigTy.getNumElements() * TargetTy.getNumElements(); + return LLT::vector(Mul / GCDElts, OrigTy.getElementType()); + } + } else { + if (OrigElt.getSizeInBits() == TargetSize) + return OrigTy; + } - int GCDElts = greatestCommonDivisor(Ty0.getNumElements(), - Ty1.getNumElements()); + unsigned LCMSize = getLCMSize(OrigSize, TargetSize); + return LLT::vector(LCMSize / OrigElt.getSizeInBits(), OrigElt); + } - int Mul = Ty0.getNumElements() * Ty1.getNumElements(); - return LLT::vector(Mul / GCDElts, Ty0.getElementType()); + if (TargetTy.isVector()) { + unsigned LCMSize = getLCMSize(OrigSize, TargetSize); + return LLT::vector(LCMSize / OrigSize, OrigTy); } - llvm_unreachable("not yet handled"); + unsigned LCMSize = getLCMSize(OrigSize, TargetSize); + + // Preserve pointer types. + if (LCMSize == OrigSize) + return OrigTy; + if (LCMSize == TargetSize) + return TargetTy; + + return LLT::scalar(LCMSize); } LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) { diff --git a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp index 6f96d69..15d3bcb 100644 --- a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp @@ -22,6 +22,7 @@ static const LLT P1 = LLT::pointer(1, 32); static const LLT V2S8 = LLT::vector(2, 8); static const LLT V4S8 = LLT::vector(4, 8); +static const LLT V8S8 = LLT::vector(8, 8); static const LLT V2S16 = LLT::vector(2, 16); static const LLT V3S16 = LLT::vector(3, 16); @@ -33,6 +34,7 @@ static const LLT V4S32 = LLT::vector(4, 32); static const LLT V6S32 = LLT::vector(6, 32); static const LLT V2S64 = LLT::vector(2, 64); +static const LLT V3S64 = LLT::vector(3, 64); static const LLT V4S64 = LLT::vector(4, 64); static const LLT V2P0 = LLT::vector(2, P0); @@ -157,18 +159,18 @@ TEST(GISelUtilsTest, getLCMType) { EXPECT_EQ(S32, getLCMType(S16, S32)); EXPECT_EQ(S64, getLCMType(S64, P0)); - EXPECT_EQ(S64, getLCMType(P0, S64)); + EXPECT_EQ(P0, getLCMType(P0, S64)); - EXPECT_EQ(S64, getLCMType(S32, P0)); - EXPECT_EQ(S64, getLCMType(P0, S32)); + EXPECT_EQ(P0, getLCMType(S32, P0)); + EXPECT_EQ(P0, getLCMType(P0, S32)); EXPECT_EQ(S32, getLCMType(S32, P1)); - EXPECT_EQ(S32, getLCMType(P1, S32)); - EXPECT_EQ(S64, getLCMType(P0, P0)); - EXPECT_EQ(S32, getLCMType(P1, P1)); + EXPECT_EQ(P1, getLCMType(P1, S32)); + EXPECT_EQ(P0, getLCMType(P0, P0)); + EXPECT_EQ(P1, getLCMType(P1, P1)); - EXPECT_EQ(S64, getLCMType(P0, P1)); - EXPECT_EQ(S64, getLCMType(P1, P0)); + EXPECT_EQ(P0, getLCMType(P0, P1)); + EXPECT_EQ(P0, getLCMType(P1, P0)); EXPECT_EQ(V2S32, getLCMType(V2S32, V2S32)); EXPECT_EQ(V2S32, getLCMType(V2S32, S32)); @@ -188,11 +190,55 @@ TEST(GISelUtilsTest, getLCMType) { EXPECT_EQ(LLT::vector(12, P0), getLCMType(V4P0, V3P0)); EXPECT_EQ(LLT::vector(12, P0), getLCMType(V3P0, V4P0)); - // FIXME - // EXPECT_EQ(V2S32, getLCMType(V2S32, S64)); + EXPECT_EQ(LLT::vector(12, S64), getLCMType(V4S64, V3P0)); + EXPECT_EQ(LLT::vector(12, P0), getLCMType(V3P0, V4S64)); - // FIXME - //EXPECT_EQ(S64, getLCMType(S64, V2S32)); + EXPECT_EQ(LLT::vector(12, P0), getLCMType(V4P0, V3S64)); + EXPECT_EQ(LLT::vector(12, S64), getLCMType(V3S64, V4P0)); + + EXPECT_EQ(V2P0, getLCMType(V2P0, S32)); + EXPECT_EQ(V4S32, getLCMType(S32, V2P0)); + EXPECT_EQ(V2P0, getLCMType(V2P0, S64)); + EXPECT_EQ(V2S64, getLCMType(S64, V2P0)); + + + EXPECT_EQ(V2P0, getLCMType(V2P0, V2P1)); + EXPECT_EQ(V4P1, getLCMType(V2P1, V2P0)); + + EXPECT_EQ(V2P0, getLCMType(V2P0, V4P1)); + EXPECT_EQ(V4P1, getLCMType(V4P1, V2P0)); + + + EXPECT_EQ(V2S32, getLCMType(V2S32, S64)); + EXPECT_EQ(S64, getLCMType(S64, V2S32)); + + EXPECT_EQ(V4S16, getLCMType(V4S16, V2S32)); + EXPECT_EQ(V2S32, getLCMType(V2S32, V4S16)); + + EXPECT_EQ(V2S32, getLCMType(V2S32, V4S8)); + EXPECT_EQ(V8S8, getLCMType(V4S8, V2S32)); + + EXPECT_EQ(V2S16, getLCMType(V2S16, V4S8)); + EXPECT_EQ(V4S8, getLCMType(V4S8, V2S16)); + + EXPECT_EQ(LLT::vector(6, S16), getLCMType(V3S16, V4S8)); + EXPECT_EQ(LLT::vector(12, S8), getLCMType(V4S8, V3S16)); + EXPECT_EQ(V4S16, getLCMType(V4S16, V4S8)); + EXPECT_EQ(V8S8, getLCMType(V4S8, V4S16)); + + EXPECT_EQ(LLT::vector(6, 4), getLCMType(LLT::vector(3, 4), S8)); + EXPECT_EQ(LLT::vector(3, 8), getLCMType(S8, LLT::vector(3, 4))); + + EXPECT_EQ(LLT::vector(6, 4), + getLCMType(LLT::vector(3, 4), LLT::pointer(4, 8))); + EXPECT_EQ(LLT::vector(3, LLT::pointer(4, 8)), + getLCMType(LLT::pointer(4, 8), LLT::vector(3, 4))); + + EXPECT_EQ(V2S64, getLCMType(V2S64, P0)); + EXPECT_EQ(V2P0, getLCMType(P0, V2S64)); + + EXPECT_EQ(V2S64, getLCMType(V2S64, P1)); + EXPECT_EQ(V4P1, getLCMType(P1, V2S64)); } } -- 2.7.4