GlobalISel: Rewrite getLCMType
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Sat, 6 Jun 2020 23:37:43 +0000 (19:37 -0400)
committerMatt Arsenault <Matthew.Arsenault@amd.com>
Tue, 21 Jul 2020 01:06:30 +0000 (21:06 -0400)
Try to make the behavior more consistent with getGCDType, and bias
towards returning something closer to the source type whenever there's
an ambiguity.

llvm/include/llvm/CodeGen/GlobalISel/Utils.h
llvm/lib/CodeGen/GlobalISel/Utils.cpp
llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp

index 066bbd8..8eb95a8 100644 (file)
@@ -190,12 +190,12 @@ inline bool isKnownNeverSNaN(Register Val, const MachineRegisterInfo &MRI) {
 
 Align inferAlignFromPtrInfo(MachineFunction &MF, const MachinePointerInfo &MPO);
 
-/// Return the least common multiple type of \p Ty0 and \p Ty1, by changing
-/// the number of vector elements or scalar bitwidth. The intent is a
-/// G_MERGE_VALUES can be constructed from \p Ty0 elements, and unmerged into
-/// \p Ty1.
+/// Return the least common multiple type of \p OrigTy and \p TargetTy, by changing the
+/// number of vector elements or scalar bitwidth. The intent is a
+/// G_MERGE_VALUES, G_BUILD_VECTOR, or G_CONCAT_VECTORS can be constructed from
+/// \p OrigTy elements, and unmerged into \p TargetTy
 LLVM_READNONE
-LLT getLCMType(LLT Ty0, LLT Ty1);
+LLT getLCMType(LLT OrigTy, LLT TargetTy);
 
 /// Return a type where the total size is the greatest common divisor of \p
 /// OrigTy and \p TargetTy. This will try to either change the number of vector
index 6e7b334..584a691 100644 (file)
@@ -510,35 +510,55 @@ void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) {
   AU.addPreserved<StackProtector>();
 }
 
-LLT llvm::getLCMType(LLT Ty0, LLT Ty1) {
-  if (!Ty0.isVector() && !Ty1.isVector()) {
-    unsigned Mul = Ty0.getSizeInBits() * Ty1.getSizeInBits();
-    int GCDSize = greatestCommonDivisor(Ty0.getSizeInBits(),
-                                        Ty1.getSizeInBits());
-    return LLT::scalar(Mul / GCDSize);
-  }
+static unsigned getLCMSize(unsigned OrigSize, unsigned TargetSize) {
+  unsigned Mul = OrigSize * TargetSize;
+  unsigned GCDSize = greatestCommonDivisor(OrigSize, TargetSize);
+  return Mul / GCDSize;
+}
 
-  if (Ty0.isVector() && !Ty1.isVector()) {
-    assert(Ty0.getElementType() == Ty1 && "not yet handled");
-    return Ty0;
-  }
+LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) {
+  const unsigned OrigSize = OrigTy.getSizeInBits();
+  const unsigned TargetSize = TargetTy.getSizeInBits();
 
-  if (Ty1.isVector() && !Ty0.isVector()) {
-    assert(Ty1.getElementType() == Ty0 && "not yet handled");
-    return Ty1;
-  }
+  if (OrigSize == TargetSize)
+    return OrigTy;
+
+  if (OrigTy.isVector()) {
+    const LLT OrigElt = OrigTy.getElementType();
+
+    if (TargetTy.isVector()) {
+      const LLT TargetElt = TargetTy.getElementType();
 
-  if (Ty0.isVector() && Ty1.isVector()) {
-    assert(Ty0.getElementType() == Ty1.getElementType() && "not yet handled");
+      if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
+        int GCDElts = greatestCommonDivisor(OrigTy.getNumElements(),
+                                            TargetTy.getNumElements());
+        // Prefer the original element type.
+        int Mul = OrigTy.getNumElements() * TargetTy.getNumElements();
+        return LLT::vector(Mul / GCDElts, OrigTy.getElementType());
+      }
+    } else {
+      if (OrigElt.getSizeInBits() == TargetSize)
+        return OrigTy;
+    }
 
-    int GCDElts = greatestCommonDivisor(Ty0.getNumElements(),
-                                        Ty1.getNumElements());
+    unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
+    return LLT::vector(LCMSize / OrigElt.getSizeInBits(), OrigElt);
+  }
 
-    int Mul = Ty0.getNumElements() * Ty1.getNumElements();
-    return LLT::vector(Mul / GCDElts, Ty0.getElementType());
+  if (TargetTy.isVector()) {
+    unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
+    return LLT::vector(LCMSize / OrigSize, OrigTy);
   }
 
-  llvm_unreachable("not yet handled");
+  unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
+
+  // Preserve pointer types.
+  if (LCMSize == OrigSize)
+    return OrigTy;
+  if (LCMSize == TargetSize)
+    return TargetTy;
+
+  return LLT::scalar(LCMSize);
 }
 
 LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
index 6f96d69..15d3bcb 100644 (file)
@@ -22,6 +22,7 @@ static const LLT P1 = LLT::pointer(1, 32);
 
 static const LLT V2S8 = LLT::vector(2, 8);
 static const LLT V4S8 = LLT::vector(4, 8);
+static const LLT V8S8 = LLT::vector(8, 8);
 
 static const LLT V2S16 = LLT::vector(2, 16);
 static const LLT V3S16 = LLT::vector(3, 16);
@@ -33,6 +34,7 @@ static const LLT V4S32 = LLT::vector(4, 32);
 static const LLT V6S32 = LLT::vector(6, 32);
 
 static const LLT V2S64 = LLT::vector(2, 64);
+static const LLT V3S64 = LLT::vector(3, 64);
 static const LLT V4S64 = LLT::vector(4, 64);
 
 static const LLT V2P0 = LLT::vector(2, P0);
@@ -157,18 +159,18 @@ TEST(GISelUtilsTest, getLCMType) {
   EXPECT_EQ(S32, getLCMType(S16, S32));
 
   EXPECT_EQ(S64, getLCMType(S64, P0));
-  EXPECT_EQ(S64, getLCMType(P0, S64));
+  EXPECT_EQ(P0, getLCMType(P0, S64));
 
-  EXPECT_EQ(S64, getLCMType(S32, P0));
-  EXPECT_EQ(S64, getLCMType(P0, S32));
+  EXPECT_EQ(P0, getLCMType(S32, P0));
+  EXPECT_EQ(P0, getLCMType(P0, S32));
 
   EXPECT_EQ(S32, getLCMType(S32, P1));
-  EXPECT_EQ(S32, getLCMType(P1, S32));
-  EXPECT_EQ(S64, getLCMType(P0, P0));
-  EXPECT_EQ(S32, getLCMType(P1, P1));
+  EXPECT_EQ(P1, getLCMType(P1, S32));
+  EXPECT_EQ(P0, getLCMType(P0, P0));
+  EXPECT_EQ(P1, getLCMType(P1, P1));
 
-  EXPECT_EQ(S64, getLCMType(P0, P1));
-  EXPECT_EQ(S64, getLCMType(P1, P0));
+  EXPECT_EQ(P0, getLCMType(P0, P1));
+  EXPECT_EQ(P0, getLCMType(P1, P0));
 
   EXPECT_EQ(V2S32, getLCMType(V2S32, V2S32));
   EXPECT_EQ(V2S32, getLCMType(V2S32, S32));
@@ -188,11 +190,55 @@ TEST(GISelUtilsTest, getLCMType) {
   EXPECT_EQ(LLT::vector(12, P0), getLCMType(V4P0, V3P0));
   EXPECT_EQ(LLT::vector(12, P0), getLCMType(V3P0, V4P0));
 
-  // FIXME
-  // EXPECT_EQ(V2S32, getLCMType(V2S32, S64));
+  EXPECT_EQ(LLT::vector(12, S64), getLCMType(V4S64, V3P0));
+  EXPECT_EQ(LLT::vector(12, P0), getLCMType(V3P0, V4S64));
 
-  // FIXME
-  //EXPECT_EQ(S64, getLCMType(S64, V2S32));
+  EXPECT_EQ(LLT::vector(12, P0), getLCMType(V4P0, V3S64));
+  EXPECT_EQ(LLT::vector(12, S64), getLCMType(V3S64, V4P0));
+
+  EXPECT_EQ(V2P0, getLCMType(V2P0, S32));
+  EXPECT_EQ(V4S32, getLCMType(S32, V2P0));
+  EXPECT_EQ(V2P0, getLCMType(V2P0, S64));
+  EXPECT_EQ(V2S64, getLCMType(S64, V2P0));
+
+
+  EXPECT_EQ(V2P0, getLCMType(V2P0, V2P1));
+  EXPECT_EQ(V4P1, getLCMType(V2P1, V2P0));
+
+  EXPECT_EQ(V2P0, getLCMType(V2P0, V4P1));
+  EXPECT_EQ(V4P1, getLCMType(V4P1, V2P0));
+
+
+  EXPECT_EQ(V2S32, getLCMType(V2S32, S64));
+  EXPECT_EQ(S64, getLCMType(S64, V2S32));
+
+  EXPECT_EQ(V4S16, getLCMType(V4S16, V2S32));
+  EXPECT_EQ(V2S32, getLCMType(V2S32, V4S16));
+
+  EXPECT_EQ(V2S32, getLCMType(V2S32, V4S8));
+  EXPECT_EQ(V8S8, getLCMType(V4S8, V2S32));
+
+  EXPECT_EQ(V2S16, getLCMType(V2S16, V4S8));
+  EXPECT_EQ(V4S8, getLCMType(V4S8, V2S16));
+
+  EXPECT_EQ(LLT::vector(6, S16), getLCMType(V3S16, V4S8));
+  EXPECT_EQ(LLT::vector(12, S8), getLCMType(V4S8, V3S16));
+  EXPECT_EQ(V4S16, getLCMType(V4S16, V4S8));
+  EXPECT_EQ(V8S8, getLCMType(V4S8, V4S16));
+
+  EXPECT_EQ(LLT::vector(6, 4), getLCMType(LLT::vector(3, 4), S8));
+  EXPECT_EQ(LLT::vector(3, 8), getLCMType(S8, LLT::vector(3, 4)));
+
+  EXPECT_EQ(LLT::vector(6, 4),
+            getLCMType(LLT::vector(3, 4), LLT::pointer(4, 8)));
+  EXPECT_EQ(LLT::vector(3, LLT::pointer(4, 8)),
+            getLCMType(LLT::pointer(4, 8), LLT::vector(3, 4)));
+
+  EXPECT_EQ(V2S64, getLCMType(V2S64, P0));
+  EXPECT_EQ(V2P0, getLCMType(P0, V2S64));
+
+  EXPECT_EQ(V2S64, getLCMType(V2S64, P1));
+  EXPECT_EQ(V4P1, getLCMType(P1, V2S64));
 }
 
 }