Add pattern rewrite to forward vector tuple elements to their users.
authorAndy Davis <andydavis@google.com>
Tue, 17 Dec 2019 19:21:12 +0000 (11:21 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Dec 2019 19:21:45 +0000 (11:21 -0800)
User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer))) -> User(Producer)

PiperOrigin-RevId: 286020249

mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-transforms.mlir

index 85f306e..569ad44 100644 (file)
@@ -583,9 +583,42 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+// Patter rewrite which forward tuple elements to their users.
+// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
+//   -> User(Producer)
+struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
+  using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
+                                     PatternRewriter &rewriter) const override {
+    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
+    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
+        tupleGetOp.vectors()->getDefiningOp());
+    if (!extractSlicesOp)
+      return matchFailure();
+
+    // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
+    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
+        extractSlicesOp.vector()->getDefiningOp());
+    if (!insertSlicesOp)
+      return matchFailure();
+
+    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
+    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
+        insertSlicesOp.vectors()->getDefiningOp());
+    if (!tupleOp)
+      return matchFailure();
+
+    // Forward Value at tupleOp.getOperand(tupleGetOp.getIndex());
+    Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
+    rewriter.replaceOp(tupleGetOp, tupleValue);
+    return matchSuccess();
+  }
+};
+
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO(andydavis) Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<SplitTransferReadOp>(context);
+  patterns.insert<SplitTransferReadOp, TupleGetFolderOp>(context);
 }
index 71b7b7a..978b0c2 100644 (file)
@@ -40,27 +40,19 @@ func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
 // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[A4:.*]] = addf %[[TG7]], %[[TG8]] : vector<2x2xf32>
 
-// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]], %[[A3]], %[[A4]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
-
 // CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[R2]], [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 
 // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[TG10]] : vector<2x2xf32>
+// CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[A1]] : vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[A6:.*]] = addf %[[TG11]], %[[TG12]] : vector<2x2xf32>
+// CHECK-NEXT: %[[A6:.*]] = addf %[[TG11]], %[[A2]] : vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[A7:.*]] = addf %[[TG13]], %[[TG14]] : vector<2x2xf32>
+// CHECK-NEXT: %[[A7:.*]] = addf %[[TG13]], %[[A3]] : vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[A8:.*]] = addf %[[TG15]], %[[TG16]] : vector<2x2xf32>
+// CHECK-NEXT: %[[A8:.*]] = addf %[[TG15]], %[[A4]] : vector<2x2xf32>
 
 // CHECK-NEXT: %[[R3:.*]] = vector.tuple %[[A5]], %[[A6]], %[[A7]], %[[A8]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[R4:.*]] = vector.insert_slices %[[R3]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
@@ -278,7 +270,6 @@ func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
 // TODO(andydavis) Update test with VTR split transform.
 // CHECK-LABEL: func @vector_transfers
 // CHECK-COUNT-8: vector.transfer_read
-// CHECK-COUNT-2: vector.extract_slices
 // CHECK-COUNT-4: addf
 // CHECK-COUNT-1: vector.insert_slices
 //         CHECK: vector.transfer_write