Fix a corner case in vector.shape_cast when the trailing dimensions are of size 1.
authorWen-Heng (Jack) Chung <whchung@gmail.com>
Mon, 22 Jun 2020 14:52:12 +0000 (09:52 -0500)
committerWen-Heng (Jack) Chung <whchung@gmail.com>
Tue, 23 Jun 2020 03:00:45 +0000 (22:00 -0500)
Differential Revision: https://reviews.llvm.org/D82304

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/ops.mlir

index 019f5fd..5d3a916 100644 (file)
@@ -1633,6 +1633,14 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
     if (dimA != dimB)
       break;
     ++i;
+
+    // Handle the case when trailing dimensions are of size 1.
+    // Include them into the contiguous sequence.
+    auto isOne = [](int64_t v) { return v == 1; };
+    if (i < rankA && llvm::all_of(a.slice(i), isOne))
+      i = rankA;
+    if (j < rankB && llvm::all_of(b.slice(j), isOne))
+      j = rankB;
   }
 
   return i == rankA && j == rankB;
index 02ee4dd..4ea7286 100644 (file)
@@ -266,8 +266,10 @@ func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
 
 // CHECK-LABEL: @shape_cast
 func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
-                 %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>)
-  -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>) {
+                 %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>,
+                 %arg2 : vector<8x1xf32>,
+                 %arg3 : vector<16x1x1xf32>)
+  -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>) {
 
   // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32>
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
@@ -276,7 +278,16 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
   %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                  tuple<vector<20x2xf32>, vector<12x2xf32>>
 
-  return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x1xf32> to vector<8xf32>
+  %2 = vector.shape_cast %arg2 : vector<8x1xf32> to vector<8xf32>
+
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16xf32>
+  %3 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16xf32>
+
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16x1xf32>
+  %4 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16x1xf32>
+
+  return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
 }
 
 // CHECK-LABEL: @vector_fma