[mlir] Use VectorTransferPermutationMapLoweringPatterns in VectorToSCF
authorMatthias Springer <springerm@google.com>
Mon, 17 May 2021 05:37:32 +0000 (14:37 +0900)
committerMatthias Springer <springerm@google.com>
Wed, 19 May 2021 05:46:19 +0000 (14:46 +0900)
VectorTransferPermutationMapLoweringPatterns can be enabled via a pass option. These additional patterns lower permutation maps to minor identity maps with broadcasting, if possible, allowing for more efficient vector load/stores. The option is deactivated by default.

Differential Revision: https://reviews.llvm.org/D102593

mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir [new file with mode: 0644]
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir [deleted file]
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir

index b440578..a1c40a7 100644 (file)
@@ -521,6 +521,9 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
            "Perform full unrolling when converting vector transfers to SCF">,
     Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
            "Target vector rank to which transfer ops should be lowered">,
+    Option<"lowerPermutationMaps", "lower-permutation-maps", "bool",
+           /*default=*/"false", "Replace permutation maps with vector "
+           "transposes/broadcasts before lowering transfer ops">
   ];
 }
 
index 03765cb..a999c4a 100644 (file)
@@ -50,6 +50,7 @@ class RewritePatternSet;
 struct VectorTransferToSCFOptions {
   bool unroll = false;
   unsigned targetRank = 1;
+  bool lowerPermutationMaps = false;
 
   VectorTransferToSCFOptions &setUnroll(bool u) {
     unroll = u;
@@ -60,6 +61,11 @@ struct VectorTransferToSCFOptions {
     targetRank = r;
     return *this;
   }
+
+  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
+    lowerPermutationMaps = l;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + std.
index d0e65a1..4b5f5ce 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
@@ -86,9 +87,16 @@ void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
 /// Collect a set of transfer read/write lowering patterns.
 ///
 /// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`.
+/// `vector.store` and `vector.broadcast`. Includes all patterns of
+/// populateVectorTransferPermutationMapLoweringPatterns.
 void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes.
+void populateVectorTransferPermutationMapLoweringPatterns(
+    RewritePatternSet &patterns);
+
 /// These patterns materialize masks for various vector ops such as transfers.
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool enableIndexOptimizations);
index 9972bcf..54783a7 100644 (file)
@@ -264,6 +264,7 @@ static BufferAllocs allocBuffers(OpTy xferOp) {
   if (xferOp.mask()) {
     auto maskType = MemRefType::get({}, xferOp.mask().getType());
     auto maskBuffer = memref_alloca(maskType).value;
+    b.setInsertionPoint(xferOp);
     memref_store(xferOp.mask(), maskBuffer);
     result.maskBuffer = memref_load(maskBuffer);
   }
@@ -476,10 +477,11 @@ struct Strategy<TransferWriteOp> {
 };
 
 template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
+LogicalResult checkPrepareXferOp(OpTy xferOp,
+                                 VectorTransferToSCFOptions options) {
   if (xferOp->hasAttr(kPassLabel))
     return failure();
-  if (xferOp.getVectorType().getRank() <= targetRank)
+  if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
   return success();
 }
@@ -513,7 +515,7 @@ struct PrepareTransferReadConversion
 
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+    if (checkPrepareXferOp(xferOp, options).failed())
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -561,7 +563,7 @@ struct PrepareTransferWriteConversion
 
   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+    if (checkPrepareXferOp(xferOp, options).failed())
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -1160,12 +1162,23 @@ struct ConvertVectorToSCFPass
   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
+    this->lowerPermutationMaps = options.lowerPermutationMaps;
   }
 
   void runOnFunction() override {
     VectorTransferToSCFOptions options;
-    options.setUnroll(fullUnroll);
-    options.setTargetRank(targetRank);
+    options.unroll = fullUnroll;
+    options.targetRank = targetRank;
+    options.lowerPermutationMaps = lowerPermutationMaps;
+
+    // Lower permutation maps first.
+    if (lowerPermutationMaps) {
+      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
+      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+          lowerTransferPatterns);
+      (void)applyPatternsAndFoldGreedily(getFunction(),
+                                         std::move(lowerTransferPatterns));
+    }
 
     RewritePatternSet patterns(getFunction().getContext());
     populateVectorToSCFConversionPatterns(patterns, options);
index c7a0623..f529167 100644 (file)
@@ -2934,8 +2934,8 @@ struct TransferWriteInsertPattern
 /// - The op has no mask.
 struct TransferReadToVectorLoadLowering
     : public OpRewritePattern<vector::TransferReadOp> {
-  TransferReadToVectorLoadLowering(MLIRContext *context)
-      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
     SmallVector<unsigned, 4> broadcastedDims;
@@ -3009,8 +3009,8 @@ struct TransferReadToVectorLoadLowering
 /// - The op has no mask.
 struct TransferWriteToVectorStoreLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
-  TransferWriteToVectorStoreLowering(MLIRContext *context)
-      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
     // TODO: Support non-minor-identity maps
@@ -3086,6 +3086,7 @@ struct TransferReadPermutationLowering
     if (permutationMap.isIdentity())
       return failure();
 
+    permutationMap = map.getPermutationMap(permutation, op.getContext());
     // Calculate the map of the new read by applying the inverse permutation.
     permutationMap = inversePermutation(permutationMap);
     AffineMap newMap = permutationMap.compose(map);
@@ -4149,13 +4150,18 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
                                     patterns.getContext());
 }
 
+void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadPermutationLowering,
+               TransferWritePermutationLowering, TransferOpReduceRank>(
+      patterns.getContext());
+}
+
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
-           TransferReadPermutationLowering, TransferWritePermutationLowering,
-           TransferOpReduceRank>(
-          patterns.getContext());
+  patterns.add<TransferReadToVectorLoadLowering,
+               TransferWriteToVectorStoreLowering>(patterns.getContext());
+  populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir
new file mode 100644 (file)
index 0000000..5547a79
--- /dev/null
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -split-input-file | FileCheck %s
+
+// Ensure that the permutation map is lowered (by inserting a transpose op)
+// before lowering the vector.transfer_read.
+
+// CHECK-LABEL: func @transfer_read_2d_mask_transposed(
+//   CHECK-DAG:   %[[PADDING:.*]] = constant dense<-4.200000e+01> : vector<9xf32>
+//   CHECK-DAG:   %[[MASK:.*]] = constant dense<{{.*}}> : vector<9x4xi1>
+//       CHECK:   %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
+//       CHECK:   %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
+//       CHECK:   memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
+//       CHECK:   %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
+//       CHECK:   scf.for {{.*}} {
+//       CHECK:     scf.if {{.*}} {
+//       CHECK:       %[[MASK_LOADED:.*]] = memref.load %[[MASK_CASTED]][%{{.*}}] : memref<4xvector<9xi1>>
+//       CHECK:       %[[READ:.*]] = vector.transfer_read %{{.*}}, %{{.*}}, %[[MASK_LOADED]] : memref<?x?xf32>, vector<9xf32>
+//       CHECK:       memref.store %[[READ]], %{{.*}} : memref<4xvector<9xf32>>
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   %[[RESULT:.*]] = memref.load %{{.*}} : memref<vector<4x9xf32>>
+//       CHECK:   %[[RESULT_T:.*]] = vector.transpose %[[RESULT]], [1, 0] : vector<4x9xf32> to vector<9x4xf32>
+//       CHECK:   return %[[RESULT_T]] : vector<9x4xf32>
+
+// Vector load with mask + transpose.
+func @transfer_read_2d_mask_transposed(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
+                          [1, 1, 1, 1], [0, 1, 1, 0],
+                          [1, 1, 1, 1], [1, 1, 1, 1],
+                          [1, 1, 1, 1], [0, 0, 0, 0],
+                          [1, 1, 1, 1]]> : vector<9x4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<9x4xf32>
+  return %f : vector<9x4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
deleted file mode 100644 (file)
index ad43b14..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-// Run test with and without test-vector-transfer-lowering-patterns.
-
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-
-memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
-                                                       [10., 11., 12., 13.],
-                                                       [20., 21., 22., 23.]]>
-
-// Vector load with transpose.
-func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
-  %fm42 = constant -42.0: f32
-  %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
-    memref<?x?xf32>, vector<3x9xf32>
-  vector.print %f: vector<3x9xf32>
-  return
-}
-
-func @entry() {
-  %c0 = constant 0: index
-  %c1 = constant 1: index
-  %c2 = constant 2: index
-  %c3 = constant 3: index
-  %0 = memref.get_global @gv : memref<3x4xf32>
-  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
-
-  // 1. Read 2D vector from 2D memref with transpose.
-  call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( 20, 0, -42, -42, -42, -42, -42, -42, -42 ) )
-
-  return
-}
index 20216cc..7e6ef94 100644 (file)
@@ -3,7 +3,17 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
index 03cdc3d..e3154e6 100644 (file)
@@ -3,7 +3,17 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
index 00da927..344167b 100644 (file)
@@ -3,7 +3,17 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s