[mlir][vector] Separate out vector transfer + tensor slice patterns
authorLei Zhang <antiagainst@google.com>
Wed, 17 May 2023 15:57:13 +0000 (08:57 -0700)
committerLei Zhang <antiagainst@google.com>
Wed, 17 May 2023 16:01:19 +0000 (09:01 -0700)
These patterns touches the structure generated from tiling so it
affects later steps like bufferization and vector hoisting.
Instead of putting them in canonicalization, this commit creates
separate entry points for them to be called explicitly.

This is NFC regarding the functionality and tests of those patterns.
It also addresses two TODO items in the codebase.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D150702

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp [new file with mode: 0644]
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 3258600..2912c02 100644 (file)
@@ -218,6 +218,17 @@ void populateBreakDownVectorBitCastOpPatterns(
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
+/// vector.transfer_write -> tensor.insert_slice op chains into vector tranfer
+/// read and write ops.
+///
+/// If `controlFn` is not nullptr, the pattern will only apply to ops where
+/// `controlFn` returns true, given the vector transfer read/write op as input.
+void populateVectorTransferTensorSliceTransforms(
+    RewritePatternSet &patterns,
+    std::function<bool(Operation *vectorOp)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
 /// `options` structure controls which operations are unrolled and the target
 /// shape.
index baabf1a..4fe9b9f 100644 (file)
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -2880,6 +2881,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
                                                        /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+  vector::populateVectorTransferTensorSliceTransforms(patterns);
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
index 80ea03f..1549237 100644 (file)
@@ -3692,108 +3692,7 @@ void TransferReadOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
-                                        unsigned trailingRank) {
-  // If no ranks are reduced at all, it's a degenerated case; always true.
-  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
-    return true;
-
-  RankedTensorType inferredType = extractOp.inferResultType(
-      extractOp.getSourceType(), extractOp.getMixedOffsets(),
-      extractOp.getMixedSizes(), extractOp.getMixedStrides());
-  return extractOp.getType().getShape().take_back(trailingRank) ==
-         inferredType.getShape().take_back(trailingRank);
-}
-
 namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-///     : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldExtractSliceIntoTransferRead
-    : public OpRewritePattern<TransferReadOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransferReadOp xferOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractOp)
-      return failure();
-    if (!extractOp.hasUnitStride())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x1x4xf32> to tensor<2x4xf32>
-    //    %1 = vector.transfer_read %0[0,0], %cst :
-    //      tensor<2x4xf32>, vector<2x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_read %t[0,0,0], %cst :
-    //      tensor<2x1x4xf32>, vector<2x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
-      return failure();
-
-    int64_t rankReduced =
-        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
-    SmallVector<Value> newIndices;
-    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
-    // indices first.
-    for (int64_t i = 0; i < rankReduced; ++i) {
-      OpFoldResult offset = extractOp.getMixedOffsets()[i];
-      newIndices.push_back(getValueOrCreateConstantIndexOp(
-          rewriter, extractOp.getLoc(), offset));
-    }
-    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
-      OpFoldResult offset =
-          extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
-    }
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
-        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
-        xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
-    return success();
-  }
-};
-
 /// Store to load forwarding for transfer operations with permuation maps.
 /// Even if the permutation maps are different we can still propagate the store
 /// into the load if the size of the dimensions read and written match. Then we
@@ -3875,13 +3774,7 @@ struct TransferReadAfterWriteToBroadcast
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  // clang-format off
-  results.add <
-               // TODO: this is brittle and should be deprecated in favor of a
-               // more general pattern that applies on-demand.
-               FoldExtractSliceIntoTransferRead,
-               TransferReadAfterWriteToBroadcast>(context);
-  // clang-format on
+  results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4217,93 +4110,6 @@ public:
   }
 };
 
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-///     : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldInsertSliceIntoTransferWrite
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    if (!insertOp.hasUnitStride())
-      return failure();
-
-    auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
-    if (!xferOp)
-      return failure();
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    // Fold only if the TransferWriteOp completely overwrites the `source` with
-    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
-    // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x4xf32>
-    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x4xf32> into tensor<2x1x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x1x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the insert_slice result
-    // tensor match the trailing dims of the inferred result tensor.
-    int64_t rankReduced =
-        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredSourceTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            insertOp.getType(), insertOp.getMixedOffsets(),
-            insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
-    if (rankReduced > 0 &&
-        actualSourceTensorShape.take_back(vectorRank) !=
-            inferredSourceTensorType.getShape().take_back(vectorRank))
-      return failure();
-
-    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
-        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
-                                                 insertOp.getDest(), indices,
-                                                 ArrayRef<bool>{inBounds});
-    return success();
-  }
-};
-
 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
 /// overwritten and inserted into another tensor. After this rewrite, the
@@ -4415,13 +4221,7 @@ public:
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  // clang-format off
-  results.add<FoldWaw,
-              // TODO: this is brittle and should be deprecated in favor of a
-              // more general pattern that applies on-demand.
-              FoldInsertSliceIntoTransferWrite,
-              SwapExtractSliceOfTransferWrite>(context);
-  // clang-format on
+  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//
index deba915..2d269ca 100644 (file)
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorTransferOpTransforms.cpp
+  VectorTransferTensorSliceTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
new file mode 100644 (file)
index 0000000..b3bd2cc
--- /dev/null
@@ -0,0 +1,237 @@
+//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Returns true if all rank reduced in the given `extractOp` happen in leading
+/// dimensions earlier than last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+                                        unsigned trailingRank) {
+  // If no ranks are reduced at all, it's a degenerated case; always true.
+  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+    return true;
+
+  RankedTensorType inferredType = extractOp.inferResultType(
+      extractOp.getSourceType(), extractOp.getMixedOffsets(),
+      extractOp.getMixedSizes(), extractOp.getMixedStrides());
+  return extractOp.getType().getShape().take_back(trailingRank) ==
+         inferredType.getShape().take_back(trailingRank);
+}
+
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+///     : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = arith.addi %a, %e : index
+/// %p1 = arith.addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldExtractSliceIntoTransferRead final
+    : public OpRewritePattern<vector::TransferReadOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldExtractSliceIntoTransferRead(MLIRContext *context,
+                                   std::function<bool(Operation *op)> controlFn,
+                                   PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (controlFn && !controlFn(xferOp))
+      return failure();
+
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    if (xferOp.getMask())
+      return failure();
+    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      return failure();
+    if (!extractOp.hasUnitStride())
+      return failure();
+
+    // Bail on illegal rank-reduction: we need to check that the rank-reduced
+    // dims are exactly the leading dims. I.e. the following is illegal:
+    // ```
+    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
+    //      tensor<2x1x4xf32> to tensor<2x4xf32>
+    //    %1 = vector.transfer_read %0[0,0], %cst :
+    //      tensor<2x4xf32>, vector<2x4xf32>
+    // ```
+    //
+    // Cannot fold into:
+    // ```
+    //    %0 = vector.transfer_read %t[0,0,0], %cst :
+    //      tensor<2x1x4xf32>, vector<2x4xf32>
+    // ```
+    // For this, check the trailing `vectorRank` dims of the extract_slice
+    // result tensor match the trailing dims of the inferred result tensor.
+    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
+      return failure();
+
+    int64_t rankReduced =
+        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+
+    SmallVector<Value> newIndices;
+    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+    // indices first.
+    for (int64_t i = 0; i < rankReduced; ++i) {
+      OpFoldResult offset = extractOp.getMixedOffsets()[i];
+      newIndices.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, extractOp.getLoc(), offset));
+    }
+    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
+      OpFoldResult offset =
+          extractOp.getMixedOffsets()[it.index() + rankReduced];
+      newIndices.push_back(rewriter.create<arith::AddIOp>(
+          xferOp->getLoc(), it.value(),
+          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
+                                          offset)));
+    }
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
+        xferOp.getPadding(), ArrayRef<bool>{inBounds});
+
+    return success();
+  }
+
+private:
+  std::function<bool(Operation *)> controlFn;
+};
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+///     : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldInsertSliceIntoTransferWrite final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldInsertSliceIntoTransferWrite(MLIRContext *context,
+                                   std::function<bool(Operation *op)> controlFn,
+                                   PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertOp.hasUnitStride())
+      return failure();
+
+    auto xferOp = insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    if (!xferOp)
+      return failure();
+    if (controlFn && !controlFn(xferOp))
+      return failure();
+
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+      return failure();
+    if (xferOp.getMask())
+      return failure();
+    // Fold only if the TransferWriteOp completely overwrites the `source` with
+    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
+    // content is the data of the vector.
+    if (!llvm::equal(xferOp.getVectorType().getShape(),
+                     xferOp.getShapedType().getShape()))
+      return failure();
+    if (!xferOp.getPermutationMap().isIdentity())
+      return failure();
+
+    // Bail on illegal rank-reduction: we need to check that the rank-reduced
+    // dims are exactly the leading dims. I.e. the following is illegal:
+    // ```
+    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
+    //      vector<2x4xf32>, tensor<2x4xf32>
+    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
+    //      tensor<2x4xf32> into tensor<2x1x4xf32>
+    // ```
+    //
+    // Cannot fold into:
+    // ```
+    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
+    //      vector<2x4xf32>, tensor<2x1x4xf32>
+    // ```
+    // For this, check the trailing `vectorRank` dims of the insert_slice result
+    // tensor match the trailing dims of the inferred result tensor.
+    int64_t rankReduced =
+        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
+    int64_t vectorRank = xferOp.getVectorType().getRank();
+    RankedTensorType inferredSourceTensorType =
+        tensor::ExtractSliceOp::inferResultType(
+            insertOp.getType(), insertOp.getMixedOffsets(),
+            insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
+    if (rankReduced > 0 &&
+        actualSourceTensorShape.take_back(vectorRank) !=
+            inferredSourceTensorType.getShape().take_back(vectorRank))
+      return failure();
+
+    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, xferOp.getVector(), insertOp.getDest(), indices,
+        ArrayRef<bool>{inBounds});
+    return success();
+  }
+
+private:
+  std::function<bool(Operation *)> controlFn;
+};
+
+} // namespace
+
+void vector::populateVectorTransferTensorSliceTransforms(
+    RewritePatternSet &patterns,
+    std::function<bool(Operation *vectorOp)> controlFn,
+    PatternBenefit benefit) {
+  patterns
+      .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
+          patterns.getContext(), controlFn, benefit);
+}
index 88c91ff..4ce4350 100644 (file)
@@ -1201,116 +1201,6 @@ func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
-//       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
-//       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
-  return %1 : vector<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
-//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
-//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
-//       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-//       CHECK:   extract_slice
-//       CHECK:   vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
-//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
-//       CHECK:   return %[[r]]
-func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
-  return %1 : tensor<?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-//       CHECK:   vector.transfer_write
-//       CHECK:   insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
-  return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
-//       CHECK:   return %[[r]]
-func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
-  return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
 //       CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: func @swap_extract_slice_transfer_write
diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
new file mode 100644 (file)
index 0000000..cc17025
--- /dev/null
@@ -0,0 +1,109 @@
+// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+  return %1 : vector<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
+//       CHECK:   extract_slice
+//       CHECK:   vector.transfer_read
+func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+//       CHECK:   return %[[r]]
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+  return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
+//       CHECK:   vector.transfer_write
+//       CHECK:   insert_slice
+func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+//       CHECK:   return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}
index d0c79ab..50dfeff 100644 (file)
@@ -679,6 +679,26 @@ struct TestVectorGatherLowering
   }
 };
 
+struct TestVectorTransferTensorSlicePatterns
+    : public PassWrapper<TestVectorTransferTensorSlicePatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorTransferTensorSlicePatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-transfer-tensor-slice-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold vector transfer and tensor slice ops";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferTensorSliceTransforms(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -713,6 +733,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestCreateVectorBroadcast>();
 
   PassRegistration<TestVectorGatherLowering>();
+
+  PassRegistration<TestVectorTransferTensorSlicePatterns>();
 }
 } // namespace test
 } // namespace mlir