[MLIR][Vector] Update ShapeCastOp folder to use producer-consumer value forwarding.
authorAndy Davis <andydavis@google.com>
Wed, 8 Apr 2020 15:39:48 +0000 (08:39 -0700)
committerAndy Davis <andydavis@google.com>
Wed, 8 Apr 2020 15:55:37 +0000 (08:55 -0700)
Summary:
Update ShapeCastOp folder to use producer-consumer value forwarding.
Support is added for tracking sub-vectors through trivial shape cast operations,
where the sub-vector shape is preserved across shape cast operations and only
leading ones are added or removed.
Support is preserved for cancelling shape cast operations.
One unit test is added and two are updated.

Reviewers: aartbik, nicolasvasilache

Reviewed By: aartbik, nicolasvasilache

Subscribers: frgossen, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77253

mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir

index dbb0bf4..7a197ef 100644 (file)
@@ -676,10 +676,10 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
 
 /// Returns the producer Value of the same type as 'consumerValue', by tracking
 /// the tuple index and offsets of the consumer vector value through the
-/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
-/// from consumer to producer. Each operation in the chain is structured, and
-/// so the tuple index and offsets can be mapped from result to input, while
-/// visiting each operation in the chain.
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
+/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
+/// structured, and so the tuple index and offsets can be mapped from result to
+/// input, while visiting each operation in the chain.
 /// Returns nullptr on failure.
 static Value getProducerValue(Value consumerValue) {
   auto consumerVectorType = consumerValue.getType().cast<VectorType>();
@@ -760,8 +760,57 @@ static Value getProducerValue(Value consumerValue) {
       // Update 'tupleIndex' and next defining 'op' to visit.
       tupleIndex = -1;
       op = value.getDefiningOp();
+    } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
+      if (shapeCastOp.source().getType().isa<TupleType>())
+        return nullptr;
+      assert(tupleIndex == -1);
+      auto sourceVectorType = shapeCastOp.getSourceVectorType();
+      auto sourceVectorShape = sourceVectorType.getShape();
+      unsigned sourceVectorRank = sourceVectorType.getRank();
+      auto resultVectorType = shapeCastOp.getResultVectorType();
+      auto resultVectorShape = resultVectorType.getShape();
+      unsigned resultVectorRank = resultVectorType.getRank();
+
+      int i = sourceVectorRank - 1;
+      int j = resultVectorRank - 1;
+
+      // Check that source/result vector shape prefixes match while
+      // updating 'newOffsets'.
+      bool canShapeCastFold = true;
+      SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
+
+      auto apply = [&](int64_t sourceSize, int64_t resultSize) {
+        canShapeCastFold = sourceSize == resultSize;
+        newOffsets[i--] = offsets[j--];
+      };
+      functional::zipApply(apply, llvm::reverse(sourceVectorShape),
+                           llvm::reverse(resultVectorShape));
+      if (!canShapeCastFold)
+        return nullptr;
+
+      // Check that remaining prefix of source/result vector shapes are all 1s.
+      // Currently we only support producer/consumer tracking through trivial
+      // shape cast ops. Examples:
+      //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
+      //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
+      assert(i == -1 || j == -1);
+      if (i >= 0 &&
+          !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
+                       [](int64_t v) { return v == 1; }))
+        return nullptr;
+      if (j >= 0 &&
+          !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
+                       [](int64_t v) { return v == 1; }))
+        return nullptr;
+
+      offsets.swap(newOffsets);
+      op = shapeCastOp.source().getDefiningOp();
     } else {
-      break;
+      // Check if 'op' produces a Value with the same type as 'consumerValue'.
+      if (op->getNumResults() == 1 &&
+          op->getResult(0).getType() == consumerVectorType)
+        return op->getResult(0);
+      return nullptr;
     }
   }
   return nullptr;
@@ -788,6 +837,12 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
+    // Check if we can replace 'shapeCastOp' result with its producer.
+    if (auto producer = getProducerValue(shapeCastOp.getResult())) {
+      rewriter.replaceOp(shapeCastOp, producer);
+      return success();
+    }
+
     // Check if 'shapeCastOp' has vector source/result type.
     auto sourceVectorType =
         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
index 082afba..2e4e903 100644 (file)
@@ -341,16 +341,21 @@ func @tuple_get_producer_consumer(
   %2 = vector.extract_slices %1, [4, 8], [1, 1]
     : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
   // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
-  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
-  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
+  %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
+  %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
+  %6 = vector.extract_slices %5, [2, 4], [1, 1]
     : vector<4x8xf32> into
       tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
-  %5 = vector.tuple_get %4, 3
+  // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
+  %7 = vector.tuple_get %6, 3
     : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %5
-  return %5 : vector<2x4xf32>
+  // %arg7 == %7
+  return %7 : vector<2x4xf32>
 }
 
 // CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
@@ -381,25 +386,40 @@ func @tuple_get_producer_consumer_swizzle(
   %2 = vector.extract_slices %1, [4, 8], [1, 1]
     : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
   // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3= vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+                             tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
 
   // Extract tuple elements.
-  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]
 
   // Swizzle tuple elements.
-  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
-  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
-  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
-  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+  %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
+  // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
+  %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
+                              tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 = %7 at tupleIndex = 0, offsets = [2, 4]
+  %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
+  %9 = vector.extract_slices %8, [2, 4], [1, 1]
     : vector<4x8xf32> into
       tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
-  %8 = vector.tuple_get %7, 3
+  // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
+  %10 = vector.tuple_get %9, 3
     : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %8
-  return %8 : vector<2x4xf32>
+  // %arg7 == %10
+  return %10 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @cancelling_shape_cast_ops
+//  CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+//       CHECK: return %[[A0]] : vector<2x4xf32>
+func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
 }
 
 // CHECK-LABEL: func @vector_transfers_vector_element_type