[mlir][vector] Refactor TransferReadToVectorLoadLowering
author: Matthias Springer <springerm@google.com>
Sat, 17 Jul 2021 04:52:20 +0000 (13:52 +0900)
committer: Matthias Springer <springerm@google.com>
Sat, 17 Jul 2021 04:53:09 +0000 (13:53 +0900)
* TransferReadToVectorLoadLowering no longer generates memref.load ops.
* Add new pattern VectorLoadToMemrefLoadLowering that lowers scalar vector.loads to memref.loads.
* Add vector::BroadcastOp canonicalization pattern that folds broadcast chains.

Differential Revision: https://reviews.llvm.org/D106117

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

index 1674f0e..9fbc6c3 100644 (file)
@@ -1346,11 +1346,25 @@ public:
   }
 };
 
+// Fold broadcast1(broadcast2(x)) into broadcast1(x).
+struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
+    if (!srcBroadcast)
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<BroadcastToShapeCast>(context);
+  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
index c02a79d..4bd1ee1 100644 (file)
@@ -2510,32 +2510,39 @@ struct TransferReadToVectorLoadLowering
       return failure();
     if (read.mask())
       return failure();
-    Operation *loadOp;
-    if (!broadcastedDims.empty() &&
-        unbroadcastedVectorType.getNumElements() == 1) {
-      // If broadcasting is required and the number of loaded elements is 1 then
-      // we can create `memref.load` instead of `vector.load`.
-      loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
-                                               read.indices());
-    } else {
-      // Otherwise create `vector.load`.
-      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
-                                               unbroadcastedVectorType,
-                                               read.source(), read.indices());
-    }
 
+    auto loadOp = rewriter.create<vector::LoadOp>(
+        read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
     // Insert a broadcasting op if required.
     if (!broadcastedDims.empty()) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          read, read.getVectorType(), loadOp->getResult(0));
+          read, read.getVectorType(), loadOp.result());
     } else {
-      rewriter.replaceOp(read, loadOp->getResult(0));
+      rewriter.replaceOp(read, loadOp.result());
     }
 
     return success();
   }
 };
 
+/// Replace a scalar vector.load with a memref.load.
+struct VectorLoadToMemrefLoadLowering
+    : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = loadOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    auto memrefLoad = rewriter.create<memref::LoadOp>(
+        loadOp.getLoc(), loadOp.base(), loadOp.indices());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad);
+    return success();
+  }
+};
+
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - The op writes to a memref with the default layout.
@@ -3674,8 +3681,9 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
+           VectorLoadToMemrefLoadLowering>(patterns.getContext());
   populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
 
index cc2d59a..81a4074 100644 (file)
@@ -613,6 +613,18 @@ func @broadcast_folding2() -> vector<4x16xi32> {
 
 // -----
 
+// CHECK-LABEL: @fold_consecutive_broadcasts(
+//  CHECK-SAME:                              %[[ARG0:.*]]: i32
+//       CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
+//       CHECK: return %[[RESULT]]
+func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
+  %1 = vector.broadcast %a : i32 to vector<16xi32>
+  %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+  return %2 : vector<4x16xi32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_constant
 //       CHECK-DAG: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
index 60bbadf..931c3ba 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL:   func @transfer_to_load(
@@ -174,6 +174,21 @@ func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
 
 // -----
 
+// CHECK-LABEL:   func @transfer_scalar(
+// CHECK-SAME:                          %[[MEM:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                          %[[IDX:.*]]: index) -> vector<1xf32> {
+// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<1xf32>
+// CHECK-NEXT:    }
+func @transfer_scalar(%mem : memref<?x?xf32>, %i : index) -> vector<1xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+
+// -----
+
 // An example with two broadcasted dimensions.
 // CHECK-LABEL:   func @transfer_broadcasting_2D(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,