GlobalISel: Handle more cases in getGCDType
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Sun, 7 Jun 2020 01:24:02 +0000 (21:24 -0400)
committerMatt Arsenault <Matthew.Arsenault@amd.com>
Tue, 21 Jul 2020 00:53:35 +0000 (20:53 -0400)
Try harder to find a canonical unmerge type when trying to cover the
desired target type. Handle finding a compatible unmerge type for two
vectors with different element types. This will return the largest
multiple of the source vector element that will evenly divide the
target vector type.

Also make the handling mixing scalars and vectors, and prefer the
source element type as the unmerge target type.

llvm/include/llvm/CodeGen/GlobalISel/Utils.h
llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
llvm/lib/CodeGen/GlobalISel/Utils.cpp
llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp

index 42d8691..066bbd8 100644 (file)
@@ -194,12 +194,23 @@ Align inferAlignFromPtrInfo(MachineFunction &MF, const MachinePointerInfo &MPO);
 /// the number of vector elements or scalar bitwidth. The intent is a
 /// G_MERGE_VALUES can be constructed from \p Ty0 elements, and unmerged into
 /// \p Ty1.
+LLVM_READNONE
 LLT getLCMType(LLT Ty0, LLT Ty1);
 
-/// Return a type that is greatest common divisor of \p OrigTy and \p
-/// TargetTy. This will either change the number of vector elements, or
-/// bitwidth of scalars. The intent is the result type can be used as the
-/// result of a G_UNMERGE_VALUES from \p OrigTy.
+/// Return a type where the total size is the greatest common divisor of \p
+/// OrigTy and \p TargetTy. This will try to either change the number of vector
+/// elements, or bitwidth of scalars. The intent is the result type can be used
+/// as the result of a G_UNMERGE_VALUES from \p OrigTy, and then some
+/// combination of G_MERGE_VALUES, G_BUILD_VECTOR and G_CONCAT_VECTORS (possibly
+/// with intermediate casts) can re-form \p TargetTy.
+///
+/// If these are vectors with different element types, this will try to produce
+/// a vector with a compatible total size, but the element type of \p OrigTy. If
+/// this can't be satisfied, this will produce a scalar smaller than the
+/// original vector elements.
+///
+/// In the worst case, this returns LLT::scalar(1)
+LLVM_READNONE
 LLT getGCDType(LLT OrigTy, LLT TargetTy);
 
 } // End namespace llvm.
index 3ec04d9..4efb50c 100644 (file)
@@ -252,7 +252,7 @@ LLT LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts, LLT DstTy,
                                     LLT NarrowTy, Register SrcReg) {
   LLT SrcTy = MRI.getType(SrcReg);
 
-  LLT GCDTy = getGCDType(DstTy, getGCDType(SrcTy, NarrowTy));
+  LLT GCDTy = getGCDType(getGCDType(SrcTy, NarrowTy), DstTy);
   if (SrcTy == GCDTy) {
     // If the source already evenly divides the result type, we don't need to do
     // anything.
index 8a7fb4f..6e7b334 100644 (file)
@@ -542,22 +542,45 @@ LLT llvm::getLCMType(LLT Ty0, LLT Ty1) {
 }
 
 LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
-  if (OrigTy.isVector() && TargetTy.isVector()) {
-    assert(OrigTy.getElementType() == TargetTy.getElementType());
-    int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
-                                    TargetTy.getNumElements());
-    return LLT::scalarOrVector(GCD, OrigTy.getElementType());
-  }
+  const unsigned OrigSize = OrigTy.getSizeInBits();
+  const unsigned TargetSize = TargetTy.getSizeInBits();
+
+  if (OrigSize == TargetSize)
+    return OrigTy;
+
+  if (OrigTy.isVector()) {
+    LLT OrigElt = OrigTy.getElementType();
+    if (TargetTy.isVector()) {
+      LLT TargetElt = TargetTy.getElementType();
+      if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
+        int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
+                                        TargetTy.getNumElements());
+        return LLT::scalarOrVector(GCD, OrigElt);
+      }
+    } else {
+      // If the source is a vector of pointers, return a pointer element.
+      if (OrigElt.getSizeInBits() == TargetSize)
+        return OrigElt;
+    }
+
+    unsigned GCD = greatestCommonDivisor(OrigSize, TargetSize);
+    if (GCD == OrigElt.getSizeInBits())
+      return OrigElt;
 
-  if (OrigTy.isVector() && !TargetTy.isVector()) {
-    assert(OrigTy.getElementType() == TargetTy);
-    return TargetTy;
+    // If we can't produce the original element type, we have to use a smaller
+    // scalar.
+    if (GCD < OrigElt.getSizeInBits())
+      return LLT::scalar(GCD);
+    return LLT::vector(GCD / OrigElt.getSizeInBits(), OrigElt);
   }
 
-  assert(!OrigTy.isVector() && !TargetTy.isVector() &&
-         "GCD type of vector and scalar not implemented");
+  if (TargetTy.isVector()) {
+    // Try to preserve the original element type.
+    LLT TargetElt = TargetTy.getElementType();
+    if (TargetElt.getSizeInBits() == OrigSize)
+      return OrigTy;
+  }
 
-  int GCD = greatestCommonDivisor(OrigTy.getSizeInBits(),
-                                  TargetTy.getSizeInBits());
+  unsigned GCD = greatestCommonDivisor(OrigSize, TargetSize);
   return LLT::scalar(GCD);
 }
index cd788be..6f96d69 100644 (file)
@@ -13,13 +13,18 @@ using namespace llvm;
 
 namespace {
 static const LLT S1 = LLT::scalar(1);
+static const LLT S8 = LLT::scalar(8);
 static const LLT S16 = LLT::scalar(16);
 static const LLT S32 = LLT::scalar(32);
 static const LLT S64 = LLT::scalar(64);
 static const LLT P0 = LLT::pointer(0, 64);
 static const LLT P1 = LLT::pointer(1, 32);
 
+static const LLT V2S8 = LLT::vector(2, 8);
+static const LLT V4S8 = LLT::vector(4, 8);
+
 static const LLT V2S16 = LLT::vector(2, 16);
+static const LLT V3S16 = LLT::vector(3, 16);
 static const LLT V4S16 = LLT::vector(4, 16);
 
 static const LLT V2S32 = LLT::vector(2, 32);
@@ -27,11 +32,17 @@ static const LLT V3S32 = LLT::vector(3, 32);
 static const LLT V4S32 = LLT::vector(4, 32);
 static const LLT V6S32 = LLT::vector(6, 32);
 
+static const LLT V2S64 = LLT::vector(2, 64);
+static const LLT V4S64 = LLT::vector(4, 64);
+
 static const LLT V2P0 = LLT::vector(2, P0);
 static const LLT V3P0 = LLT::vector(3, P0);
 static const LLT V4P0 = LLT::vector(4, P0);
 static const LLT V6P0 = LLT::vector(6, P0);
 
+static const LLT V2P1 = LLT::vector(2, P1);
+static const LLT V4P1 = LLT::vector(4, P1);
+
 TEST(GISelUtilsTest, getGCDType) {
   EXPECT_EQ(S1, getGCDType(S1, S1));
   EXPECT_EQ(S32, getGCDType(S32, S32));
@@ -56,7 +67,7 @@ TEST(GISelUtilsTest, getGCDType) {
   EXPECT_EQ(S32, getGCDType(P0, S32));
   EXPECT_EQ(S32, getGCDType(S32, P0));
 
-  EXPECT_EQ(S64, getGCDType(P0, S64));
+  EXPECT_EQ(P0, getGCDType(P0, S64));
   EXPECT_EQ(S64, getGCDType(S64, P0));
 
   EXPECT_EQ(S32, getGCDType(P0, P1));
@@ -64,6 +75,76 @@ TEST(GISelUtilsTest, getGCDType) {
 
   EXPECT_EQ(P0, getGCDType(V3P0, V2P0));
   EXPECT_EQ(P0, getGCDType(V2P0, V3P0));
+
+  EXPECT_EQ(P0, getGCDType(P0, V2P0));
+  EXPECT_EQ(P0, getGCDType(V2P0, P0));
+
+
+  EXPECT_EQ(V2P0, getGCDType(V2P0, V2P0));
+  EXPECT_EQ(P0, getGCDType(V3P0, V2P0));
+  EXPECT_EQ(P0, getGCDType(V2P0, V3P0));
+  EXPECT_EQ(V2P0, getGCDType(V4P0, V2P0));
+
+  EXPECT_EQ(V2P0, getGCDType(V2P0, V4P1));
+  EXPECT_EQ(V4P1, getGCDType(V4P1, V2P0));
+
+  EXPECT_EQ(V2P0, getGCDType(V4P0, V4P1));
+  EXPECT_EQ(V4P1, getGCDType(V4P1, V4P0));
+
+  // Elements have same size, but have different pointeriness, so prefer the
+  // original element type.
+  EXPECT_EQ(V2P0, getGCDType(V2P0, V4S64));
+  EXPECT_EQ(V2S64, getGCDType(V4S64, V2P0));
+
+  EXPECT_EQ(V2S16, getGCDType(V2S16, V4P1));
+  EXPECT_EQ(P1, getGCDType(V4P1, V2S16));
+  EXPECT_EQ(V2P1, getGCDType(V4P1, V4S16));
+  EXPECT_EQ(V4S16, getGCDType(V4S16, V2P1));
+
+  EXPECT_EQ(P0, getGCDType(P0, V2S64));
+  EXPECT_EQ(S64, getGCDType(V2S64, P0));
+
+  EXPECT_EQ(S16, getGCDType(V2S16, V3S16));
+  EXPECT_EQ(S16, getGCDType(V3S16, V2S16));
+  EXPECT_EQ(S16, getGCDType(V3S16, S16));
+  EXPECT_EQ(S16, getGCDType(S16, V3S16));
+
+  EXPECT_EQ(V2S16, getGCDType(V2S16, V2S32));
+  EXPECT_EQ(S32, getGCDType(V2S32, V2S16));
+
+  EXPECT_EQ(V4S8, getGCDType(V4S8, V2S32));
+  EXPECT_EQ(S32, getGCDType(V2S32, V4S8));
+
+  // Test cases where neither element type nicely divides.
+  EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(3, 5), LLT::vector(2, 6)));
+  EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(2, 6), LLT::vector(3, 5)));
+
+  // Have to go smaller than a pointer element.
+  EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(2, LLT::pointer(3, 6)),
+                                       LLT::vector(3, 5)));
+  EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(3, 5),
+                                       LLT::vector(2, LLT::pointer(3, 6))));
+
+  EXPECT_EQ(V4S8, getGCDType(V4S8, S32));
+  EXPECT_EQ(S32, getGCDType(S32, V4S8));
+  EXPECT_EQ(V4S8, getGCDType(V4S8, P1));
+  EXPECT_EQ(P1, getGCDType(P1, V4S8));
+
+  EXPECT_EQ(V2S8, getGCDType(V2S8, V4S16));
+  EXPECT_EQ(S16, getGCDType(V4S16, V2S8));
+
+  EXPECT_EQ(S8, getGCDType(V2S8, LLT::vector(4, 2)));
+  EXPECT_EQ(LLT::vector(4, 2), getGCDType(LLT::vector(4, 2), S8));
+
+
+  EXPECT_EQ(LLT::pointer(4, 8), getGCDType(LLT::vector(2, LLT::pointer(4, 8)),
+                                           LLT::vector(4, 2)));
+
+  EXPECT_EQ(LLT::vector(4, 2), getGCDType(LLT::vector(4, 2),
+                                          LLT::vector(2, LLT::pointer(4, 8))));
+
+  EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::vector(3, 4), S8));
+  EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::vector(3, 4)));
 }
 
 TEST(GISelUtilsTest, getLCMType) {