[mlir] VectorToSCF target rank is a pass option
authorMatthias Springer <springerm@google.com>
Fri, 14 May 2021 00:56:28 +0000 (09:56 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 14 May 2021 01:30:43 +0000 (10:30 +0900)
Make "target rank" a pass option of VectorToSCF.

Depends On D102101

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D102123

mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

index b26b708..b440578 100644 (file)
@@ -519,6 +519,8 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
   let options = [
     Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
            "Perform full unrolling when converting vector transfers to SCF">,
+    Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
+           "Target vector rank to which transfer ops should be lowered">,
   ];
 }
 
index 5a42b9a..03765cb 100644 (file)
@@ -49,10 +49,17 @@ class RewritePatternSet;
 
 struct VectorTransferToSCFOptions {
   bool unroll = false;
+  unsigned targetRank = 1;
+
   VectorTransferToSCFOptions &setUnroll(bool u) {
     unroll = u;
     return *this;
   }
+
+  VectorTransferToSCFOptions &setTargetRank(unsigned r) {
+    targetRank = r;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + std.
index 5b5769c..a209bc4 100644 (file)
@@ -38,8 +38,16 @@ namespace {
 /// Attribute name used for labeling transfer ops during progressive lowering.
 static const char kPassLabel[] = "__vector_to_scf_lowering__";
 
-/// Lower to 1D transfer ops. Target-specific lowering will lower those.
-static const int64_t kTargetRank = 1;
+/// Patterns that inherit from this struct have access to
+/// VectorTransferToSCFOptions.
+template <typename OpTy>
+struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
+  explicit VectorToSCFPattern(MLIRContext *context,
+                              VectorTransferToSCFOptions opt)
+      : OpRewritePattern<OpTy>(context), options(opt) {}
+
+  VectorTransferToSCFOptions options;
+};
 
 /// Given a MemRefType with VectorType element type, unpack one dimension from
 /// the VectorType into the MemRefType.
@@ -270,8 +278,9 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
 /// Add the pass label to a vector transfer op if its rank is not the target
 /// rank.
 template <typename OpTy>
-static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) {
-  if (newXferOp.getVectorType().getRank() > kTargetRank)
+static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
+                                unsigned targetRank) {
+  if (newXferOp.getVectorType().getRank() > targetRank)
     newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
 }
 
@@ -347,8 +356,10 @@ struct Strategy<TransferReadOp> {
   /// Note: The loop and type cast are generated in TransferOpConversion.
   ///       The original TransferReadOp and store op are deleted in `cleanup`.
   /// Note: The `mask` operand is set in TransferOpConversion.
-  static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
-                                  Value buffer, Value iv) {
+  static TransferReadOp rewriteOp(OpBuilder &builder,
+                                  VectorTransferToSCFOptions options,
+                                  TransferReadOp xferOp, Value buffer,
+                                  Value iv) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -367,7 +378,8 @@ struct Strategy<TransferReadOp> {
             .value;
 
     maybeApplyPassLabel(builder,
-                        dyn_cast<TransferReadOp>(newXfer.getDefiningOp()));
+                        dyn_cast<TransferReadOp>(newXfer.getDefiningOp()),
+                        options.targetRank);
 
     memref_store(newXfer, buffer, storeIndices);
     return newXfer.getDefiningOp<TransferReadOp>();
@@ -428,8 +440,10 @@ struct Strategy<TransferWriteOp> {
   ///    to memory.
   ///
   /// Note: For more details, see comments on Strategy<TransferReadOp>.
-  static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {
+  static TransferWriteOp rewriteOp(OpBuilder &builder,
+                                   VectorTransferToSCFOptions options,
+                                   TransferWriteOp xferOp, Value buffer,
+                                   Value iv) {
     SmallVector<Value, 8> loadIndices;
     getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
@@ -444,7 +458,7 @@ struct Strategy<TransferWriteOp> {
         AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
         inBoundsAttr);
 
-    maybeApplyPassLabel(builder, newXfer.op);
+    maybeApplyPassLabel(builder, newXfer.op, options.targetRank);
 
     return newXfer;
   }
@@ -460,10 +474,10 @@ struct Strategy<TransferWriteOp> {
 };
 
 template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp) {
+LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
   if (xferOp->hasAttr(kPassLabel))
     return failure();
-  if (xferOp.getVectorType().getRank() <= kTargetRank)
+  if (xferOp.getVectorType().getRank() <= targetRank)
     return failure();
   return success();
 }
@@ -491,12 +505,13 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
 /// ```
 ///
 /// Note: A second temporary buffer may be allocated for the `mask` operand.
-struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
-  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+struct PrepareTransferReadConversion
+    : public VectorToSCFPattern<TransferReadOp> {
+  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
 
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (checkPrepareXferOp(xferOp).failed())
+    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -539,12 +554,12 @@ struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
 ///
 /// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferWriteConversion
-    : public OpRewritePattern<TransferWriteOp> {
-  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+    : public VectorToSCFPattern<TransferWriteOp> {
+  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
 
   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (checkPrepareXferOp(xferOp).failed())
+    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -583,8 +598,8 @@ struct PrepareTransferWriteConversion
 ///    out-of-bounds, generate an if-check and handle both cases separately.
 /// 3. Clean up according to the corresponding Strategy<OpTy>.
 template <typename OpTy>
-struct TransferOpConversion : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
+  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
 
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
@@ -635,8 +650,8 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
               /*inBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {
                 // Create new transfer op.
-                OpTy newXfer =
-                    Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+                OpTy newXfer = Strategy<OpTy>::rewriteOp(
+                    b, this->options, xferOp, castedDataBuffer, iv);
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
@@ -731,8 +746,9 @@ static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
 /// Note: As an optimization, if the result of the original TransferReadOp
 /// was directly inserted into another vector, no new %v_init vector is created.
 /// Instead, the new TransferReadOp results are inserted into that vector.
-struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
-  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+struct UnrollTransferReadConversion
+    : public VectorToSCFPattern<TransferReadOp> {
+  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
 
   /// Return the vector into which the newly created TransferReadOp results
   /// are inserted.
@@ -770,7 +786,7 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
   /// accesses, and broadcasts and transposes in permutation maps.
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (xferOp.getVectorType().getRank() <= kTargetRank)
+    if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -861,8 +877,8 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
 /// doing so, `a` may become dead, and the number of ExtractOps generated during
 /// recursive application of this pattern will be minimal.
 struct UnrollTransferWriteConversion
-    : public OpRewritePattern<TransferWriteOp> {
-  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+    : public VectorToSCFPattern<TransferWriteOp> {
+  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
 
   /// Return the vector from which newly generated ExtracOps will extract.
   Value getDataVector(TransferWriteOp xferOp) const {
@@ -893,7 +909,7 @@ struct UnrollTransferWriteConversion
   /// accesses, and broadcasts and transposes in permutation maps.
   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (xferOp.getVectorType().getRank() <= kTargetRank)
+    if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -1062,8 +1078,8 @@ static bool isLastMemrefDimUnitStride(MemRefType type) {
 /// }
 /// ```
 template <typename OpTy>
-struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
+  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
 
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
@@ -1106,17 +1122,18 @@ void populateVectorToSCFConversionPatterns(
     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
   if (options.unroll) {
     patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
-        patterns.getContext());
+        patterns.getContext(), options);
   } else {
     patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
                  TransferOpConversion<TransferReadOp>,
-                 TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+                 TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
+                                                        options);
   }
 
-  if (kTargetRank == 1) {
+  if (options.targetRank == 1) {
     patterns.add<TransferOp1dConversion<TransferReadOp>,
-                 TransferOp1dConversion<TransferWriteOp>>(
-        patterns.getContext());
+                 TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
+                                                          options);
   }
 }
 
@@ -1129,12 +1146,16 @@ struct ConvertVectorToSCFPass
   ConvertVectorToSCFPass() = default;
   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
     this->fullUnroll = options.unroll;
+    this->targetRank = options.targetRank;
   }
 
   void runOnFunction() override {
+    VectorTransferToSCFOptions options;
+    options.setUnroll(fullUnroll);
+    options.setTargetRank(targetRank);
+
     RewritePatternSet patterns(getFunction().getContext());
-    populateVectorToSCFConversionPatterns(
-        patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+    populateVectorToSCFConversionPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };