[mlir][Tensor] Add a FoldTensorSubsetOps pass and patterns
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 21 Mar 2023 21:41:20 +0000 (14:41 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 23 Mar 2023 11:03:27 +0000 (04:03 -0700)
These patterns follow FoldMemRefAliasOps which is further refactored for reuse.
In the process, fix FoldMemRefAliasOps handling of strides for vector.transfer ops which was previously incorrect.

These opt-in patterns generalize the existing canonicalizations on vector.transfer ops.
In the future the blanket canonicalizations will be retired.
They are kept for now to minimize porting disruptions.

Differential Revision: https://reviews.llvm.org/D146624

20 files changed:
mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp [new file with mode: 0644]
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir [new file with mode: 0644]
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 3fac940..42156ac 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
+class RewriterBase;
 
 /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
 /// when combining a producer slice **into** a consumer slice.
@@ -21,6 +22,7 @@ namespace mlir {
 /// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
 /// - Combined sizes = consumer_sizes
 /// - Combined strides = producer_strides * consumer_strides
+// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
 LogicalResult
 mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
                             ArrayRef<OpFoldResult> producerOffsets,
@@ -36,6 +38,7 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
 
 /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
 /// when combining a `producer` slice op **into** a `consumer` slice op.
+// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
 LogicalResult
 mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
                             OffsetSizeAndStrideOpInterface producer,
@@ -45,6 +48,30 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
                             SmallVector<OpFoldResult> &combinedSizes,
                             SmallVector<OpFoldResult> &combinedStrides);
 
+/// Given the 'indicesVals' of a load/store operation operating on an op with
+/// offsets and strides, return the combined indices.
+///
+/// For example, using `memref.load` and `memref.subview` as an illustration:
+///
+/// ```
+///    %0 = ... : memref<12x42xf32>
+///    %1 = memref.subview %0[%arg0, %arg1][...][%stride1, %stride2] :
+///      memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
+///    %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+/// ```
+///
+/// could be folded into:
+///
+/// ```
+///    %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
+///         memref<12x42xf32>
+/// ```
+void resolveSourceIndicesOffsetsAndStrides(
+    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedStrides,
+    const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
+    SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
index 66d6dcc..721615f 100644 (file)
@@ -858,6 +858,10 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
       return {rank, rank, rank};
     }
 
+    /// Return the dimensions of the dest that are omitted to insert a source
+    /// when the result is rank-extended.
+    llvm::SmallBitVector getDroppedDims();
+
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
index df695db..48f9066 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+namespace tensor {
 
-#define GEN_PASS_DECL
-#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
 
-/// Creates an instance of `tensor` dialect bufferization pass.
+/// Creates an instance of the `tensor` subset folding pass.
+std::unique_ptr<Pass> createFoldTensorSubsetOpsPass();
+
+/// Creates an instance of the `tensor` dialect bufferization pass.
 std::unique_ptr<Pass> createTensorBufferizePass();
 
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
 
-namespace tensor {
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
-} // namespace tensor
 
+} // namespace tensor
 } // namespace mlir
 
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
index 2bf774d..b467359 100644 (file)
 
 include "mlir/Pass/PassBase.td"
 
+def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> {
+  let summary = "Fold tensor subset ops into producer/consumer ops";
+  let description = [{
+    The pass folds tensor subset ops into producer/consumer ops.
+
+    At the moment, the following foldings occur when possible:
+      - tensor.extract_slice into vector.transfer_read
+      - vector.transfer_write into tensor.insert_slice
+
+  }];
+  let constructor = "mlir::tensor::createFoldTensorSubsetOpsPass()";
+  let dependentDialects = [
+      "AffineDialect", "tensor::TensorDialect", "vector::VectorDialect"
+  ];
+}
+
 def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> {
   let summary = "Bufferize the `tensor` dialect";
-  let constructor = "mlir::createTensorBufferizePass()";
+  let constructor = "mlir::tensor::createTensorBufferizePass()";
 }
 
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
index 4cdf360..c0c46e9 100644 (file)
@@ -18,11 +18,9 @@ struct TilingResult;
 
 namespace tensor {
 
-/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
-/// to separate the cases where we don't need padding (all pad sizes are
-/// actually zeros) and where we indeed need padding.
-void populateSplitPaddingPatterns(RewritePatternSet &patterns,
-                                  PatternBenefit baseBenefit = 1);
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
 
 /// Pattern to swap an `tensor.extract_slice` with its producer when the
 /// producer implements the `TilingInterface`. The pattern itself does not
@@ -32,6 +30,23 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
+//===----------------------------------------------------------------------===//
+// Populate functions.
+//===----------------------------------------------------------------------===//
+
+/// Collects a set of patterns to rewrite ops within the tensor dialect.
+void populateExpandOpsPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for folding tensor aliasing ops into consumer load/store
+/// ops into `patterns`.
+void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
+
+/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
+/// to separate the cases where we don't need padding (all pad sizes are
+/// actually zeros) and where we indeed need padding.
+void populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                  PatternBenefit baseBenefit = 1);
+
 /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
 /// into one. These patterns are in in this separate entry point because the
 /// bufferization is sensitive over IR structure, particularly those
index cc7c794..75a268c 100644 (file)
@@ -249,11 +249,11 @@ public:
 
   /// Returns a new AffineMap with the same number of dims and symbols and one
   /// less result at `pos`, dropped.
-  AffineMap dropResult(int64_t pos) { return dropResults({pos}); }
+  AffineMap dropResult(int64_t pos) const { return dropResults({pos}); }
 
   // Returns a new AffineMap with the same number of dims and symbols, but all
-  // positions in `positions` dropped from results.
-  AffineMap dropResults(ArrayRef<int64_t> positions) {
+  // results in `positions` dropped.
+  AffineMap dropResults(ArrayRef<int64_t> positions) const {
     SmallVector<int64_t> reverse_sorted_positions = llvm::to_vector(positions);
     llvm::sort(reverse_sorted_positions, std::greater<int64_t>());
 
@@ -263,9 +263,13 @@ public:
     return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
   }
 
+  // Returns a new AffineMap with the same number of dims and symbols, but all
+  // results in `positions` dropped.
+  AffineMap dropResults(const llvm::SmallBitVector &positions) const;
+
   /// Returns a new AffineMap with the same number of dims and symbols and an
   /// extra result inserted at `pos`.
-  AffineMap insertResult(AffineExpr expr, unsigned pos) {
+  AffineMap insertResult(AffineExpr expr, unsigned pos) const {
     auto exprs = llvm::to_vector<4>(getResults());
     exprs.insert(exprs.begin() + pos, expr);
     return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
@@ -583,6 +587,12 @@ llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
 // by any of the maps in the input array `maps`.
 llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps);
 
+/// Expand `map` to operate on `rank` dims while projecting out the dims in
+/// `projectedDimensions`. This amounts to composing `map` with
+/// `id(rank).dropResults(projectedDimensions)`.
+AffineMap expandDimsToRank(AffineMap map, int64_t rank,
+                           const llvm::SmallBitVector &projectedDimensions);
+
 inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
   return os;
index c506239..f53edce 100644 (file)
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 
@@ -74,3 +76,33 @@ LogicalResult mlir::mergeOffsetsSizesAndStrides(
       droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
       combinedOffsets, combinedSizes, combinedStrides);
 }
+
+void mlir::resolveSourceIndicesOffsetsAndStrides(
+    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedStrides,
+    const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
+    SmallVectorImpl<Value> &sourceIndices) {
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+
+  // For each dimension that is rank-reduced, add a zero to the indices.
+  int64_t indicesDim = 0;
+  SmallVector<OpFoldResult> indices;
+  for (auto dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
+    OpFoldResult ofr =
+        (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++];
+    indices.push_back(ofr);
+  }
+
+  sourceIndices.resize(indices.size());
+  sourceIndices.clear();
+  for (auto [offset, index, stride] :
+       llvm::zip_equal(mixedOffsets, indices, mixedStrides)) {
+    AffineExpr off, idx, str;
+    bindSymbols(rewriter.getContext(), off, idx, str);
+    OpFoldResult ofr = makeComposedFoldedAffineApply(
+        rewriter, loc, AffineMap::get(0, 3, off + idx * str),
+        {offset, index, stride});
+    sourceIndices.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  }
+}
index c1c3478..c850348 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -19,7 +20,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -150,70 +153,6 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
   return success();
 }
 
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-///          memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
-                            memref::SubViewOp subViewOp, ValueRange indices,
-                            SmallVectorImpl<Value> &sourceIndices) {
-  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
-  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
-  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
-
-  SmallVector<Value> useIndices;
-  // Check if this is rank-reducing case. Then for every unit-dim size add a
-  // zero to the indices.
-  int64_t resultDim = 0;
-  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
-  for (auto dim : llvm::seq<int64_t>(0, subViewOp.getSourceType().getRank())) {
-    if (unusedDims.test(dim))
-      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-    else
-      useIndices.push_back(indices[resultDim++]);
-  }
-  if (useIndices.size() != mixedOffsets.size())
-    return failure();
-  sourceIndices.resize(useIndices.size());
-  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
-    SmallVector<OpFoldResult> dynamicOperands;
-    AffineExpr expr = rewriter.getAffineDimExpr(0);
-    int64_t numSymbols = 0;
-    dynamicOperands.push_back(useIndices[index]);
-
-    // Multiply the stride;
-    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
-      expr = expr * attr.cast<IntegerAttr>().getInt();
-    } else {
-      dynamicOperands.push_back(mixedStrides[index].get<Value>());
-      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
-    }
-
-    // Add the offset.
-    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
-      expr = expr + attr.cast<IntegerAttr>().getInt();
-    } else {
-      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
-      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
-    }
-    Location loc = subViewOp.getLoc();
-    OpFoldResult ofr = makeComposedFoldedAffineApply(
-        rewriter, loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
-    sourceIndices[index] = getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
@@ -236,25 +175,6 @@ static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
   return op.getDstMemref();
 }
 
-/// Given the permutation map of the original
-/// `vector.transfer_read`/`vector.transfer_write` operations compute the
-/// permutation map to use after the subview is folded with it.
-static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
-                                           memref::SubViewOp subViewOp,
-                                           AffineMap currPermutationMap) {
-  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
-  SmallVector<AffineExpr> exprs;
-  int64_t sourceRank = subViewOp.getSourceType().getRank();
-  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
-    if (unusedDims.test(dim))
-      continue;
-    exprs.push_back(getAffineDimExpr(dim, context));
-  }
-  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
-  return AffineMapAttr::get(
-      currPermutationMap.compose(resultDimToSourceDimMap));
-}
-
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
@@ -390,6 +310,42 @@ calculateExpandedAccessIndices(AffineMap affineMap,
   return expandedIndices;
 }
 
+template <typename XferOp>
+static LogicalResult
+preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
+                               memref::SubViewOp subviewOp) {
+  static_assert(
+      !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
+      "must be a vector transfer op");
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
+  if (xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
+  if (!subviewOp.hasUnitStride()) {
+    return rewriter.notifyMatchFailure(
+        xferOp, "non-1 stride subview, need to track strides in folded memref");
+  }
+  return success();
+}
+
+static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
+                                                Operation *op,
+                                                memref::SubViewOp subviewOp) {
+  return success();
+}
+
+static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
+                                                vector::TransferReadOp readOp,
+                                                memref::SubViewOp subviewOp) {
+  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
+}
+
+static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
+                                                vector::TransferWriteOp writeOp,
+                                                memref::SubViewOp subviewOp) {
+  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
+}
+
 template <typename OpTy>
 LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     OpTy loadOp, PatternRewriter &rewriter) const {
@@ -397,7 +353,12 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
 
   if (!subViewOp)
-    return failure();
+    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
+
+  LogicalResult preconditionResult =
+      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
+  if (failed(preconditionResult))
+    return preconditionResult;
 
   SmallVector<Value> indices(loadOp.getIndices().begin(),
                              loadOp.getIndices().end());
@@ -410,9 +371,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
-                                         indices, sourceIndices)))
-    return failure();
+  resolveSourceIndicesOffsetsAndStrides(
+      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
+      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+      sourceIndices);
 
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](AffineLoadOp op) {
@@ -423,14 +385,13 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
       })
-      .Case([&](vector::TransferReadOp transferReadOp) {
+      .Case([&](vector::TransferReadOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-            transferReadOp, transferReadOp.getVectorType(),
-            subViewOp.getSource(), sourceIndices,
-            getPermutationMapAttr(rewriter.getContext(), subViewOp,
-                                  transferReadOp.getPermutationMap()),
-            transferReadOp.getPadding(),
-            /*mask=*/Value(), transferReadOp.getInBoundsAttr());
+            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
+            AffineMapAttr::get(expandDimsToRank(
+                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
+                subViewOp.getDroppedDims())),
+            op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -512,7 +473,12 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
 
   if (!subViewOp)
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
+
+  LogicalResult preconditionResult =
+      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
+  if (failed(preconditionResult))
+    return preconditionResult;
 
   SmallVector<Value> indices(storeOp.getIndices().begin(),
                              storeOp.getIndices().end());
@@ -525,9 +491,10 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
-                                         indices, sourceIndices)))
-    return failure();
+  resolveSourceIndicesOffsetsAndStrides(
+      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
+      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+      sourceIndices);
 
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](AffineStoreOp op) {
@@ -542,8 +509,9 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
       .Case([&](vector::TransferWriteOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
             op, op.getValue(), subViewOp.getSource(), sourceIndices,
-            getPermutationMapAttr(rewriter.getContext(), subViewOp,
-                                  op.getPermutationMap()),
+            AffineMapAttr::get(expandDimsToRank(
+                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
+                subViewOp.getDroppedDims())),
             op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
index 9d26e51..93db7da 100644 (file)
@@ -2396,6 +2396,26 @@ struct InsertSliceOpSourceCastInserter final
 };
 } // namespace
 
+llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
+  ArrayRef<int64_t> resultShape = getType().getShape();
+  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+  llvm::SmallBitVector droppedDims(mixedSizes.size());
+  unsigned shapePos = 0;
+  for (const auto &size : enumerate(mixedSizes)) {
+    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
+    // If the size is not 1, or if the current matched dimension of the result
+    // is the same static shape as the size value (which is 1), then the
+    // dimension is preserved.
+    if (!sizeVal || *sizeVal != 1 ||
+        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
+      shapePos++;
+      continue;
+    }
+    droppedDims.set(size.index());
+  }
+  return droppedDims;
+}
+
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
index 426b136..d27c457 100644 (file)
@@ -53,6 +53,6 @@ struct TensorBufferizePass
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
+std::unique_ptr<Pass> mlir::tensor::createTensorBufferizePass() {
   return std::make_unique<TensorBufferizePass>();
 }
index 5ed3d97..9f67807 100644 (file)
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldIntoPackAndUnpackPatterns.cpp
+  FoldTensorSubsetOps.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
   SplitPaddingPatterns.cpp
@@ -29,4 +30,5 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRTensorDialect
   MLIRTilingInterface
   MLIRTransforms
+  MLIRVectorDialect
 )
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
new file mode 100644 (file)
index 0000000..80ecb86
--- /dev/null
@@ -0,0 +1,173 @@
+//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold tensor subset ops with producer / consumers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace tensor {
+#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
+#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
+} // namespace tensor
+} // namespace mlir
+
+using namespace mlir;
+
+static Value getTensorOperand(vector::TransferReadOp op) {
+  return op.getSource();
+}
+
+static Value getTensorOperand(tensor::InsertSliceOp op) {
+  return op.getSource();
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Merge extract_slice operation with load/transferRead operation.
+class TransferReadOfExtractSliceOpFolder final
+    : public OpRewritePattern<vector::TransferReadOp> {
+public:
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merge insert_slice operation with store/transferWriteOp operation.
+class InsertSliceOfTransferWriteOpFolder final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+template <typename XferOp, typename ExtractOrInsertOp>
+static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
+    RewriterBase &rewriter, XferOp xferOp,
+    ExtractOrInsertOp extractOrInsertSliceOp) {
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
+  if (xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
+  if (!extractOrInsertSliceOp.hasUnitStride()) {
+    return rewriter.notifyMatchFailure(
+        xferOp, "non-1 stride insert/extract, requires keeping track of "
+                "strides, this may result in needing to insert "
+                "vector.insert_strided_slice/extract_strided_slice ops");
+  }
+  return success();
+}
+
+LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
+    vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
+  auto extractSliceOp =
+      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
+
+  LogicalResult preconditionResult =
+      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
+                                                     extractSliceOp);
+  if (failed(preconditionResult))
+    return preconditionResult;
+
+  SmallVector<Value> indices(readOp.getIndices().begin(),
+                             readOp.getIndices().end());
+  SmallVector<Value> sourceIndices;
+  resolveSourceIndicesOffsetsAndStrides(
+      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
+      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
+      indices, sourceIndices);
+
+  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+      readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
+      AffineMapAttr::get(expandDimsToRank(
+          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
+          extractSliceOp.getDroppedDims())),
+      readOp.getPadding(),
+      /*mask=*/Value(), readOp.getInBoundsAttr());
+
+  return success();
+}
+
+LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
+    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
+  auto writeOp = getTensorOperand(insertSliceOp)
+                     .template getDefiningOp<vector::TransferWriteOp>();
+  if (!writeOp)
+    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
+
+  LogicalResult preconditionResult =
+      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
+                                                     insertSliceOp);
+  if (failed(preconditionResult))
+    return preconditionResult;
+
+  SmallVector<Value> indices(writeOp.getIndices().begin(),
+                             writeOp.getIndices().end());
+  SmallVector<Value> sourceIndices;
+  resolveSourceIndicesOffsetsAndStrides(
+      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
+      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
+      sourceIndices);
+
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
+      AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
+                                          insertSliceOp.getDestType().getRank(),
+                                          insertSliceOp.getDroppedDims())),
+      writeOp.getInBoundsAttr());
+
+  return success();
+}
+
+void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
+  patterns.add<TransferReadOfExtractSliceOpFolder,
+               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
+}
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldTensorSubsetOpsPass final
+    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void FoldTensorSubsetOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  tensor::populateFoldTensorSubsetOpPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
+  return std::make_unique<FoldTensorSubsetOpsPass>();
+}
index 4169882..895d1b1 100644 (file)
@@ -18,6 +18,7 @@ using namespace mlir::tensor;
 
 namespace {
 /// Merges consecutive tensor.extract_slice ops into one.
+// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -41,6 +42,7 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
 };
 
 /// Merges consecutive tensor.insert_slice ops into one.
+// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
 template <typename OpTy>
 struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
index efc7842..b7848b1 100644 (file)
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTensorUtils
   LINK_LIBS PUBLIC
   MLIRAffineDialect
   MLIRArithDialect
+  MLIRArithUtils
   MLIRIR
   MLIRTensorDialect
 )
index a584725..4c09c54 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 
 using namespace mlir;
index 21daff6..ce7d184 100644 (file)
@@ -3733,6 +3733,8 @@ namespace {
 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
 ///     : tensor<?x?xf32>, vector<4x5xf32>
 /// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
 struct FoldExtractSliceIntoTransferRead
     : public OpRewritePattern<TransferReadOp> {
 public:
@@ -3883,9 +3885,13 @@ struct TransferReadAfterWriteToBroadcast
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results
-      .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
-          context);
+  // clang-format off
+  results.add <
+               // TODO: this is brittle and should be deprecated in favor of a
+               // more general pattern that applies on-demand.
+               FoldExtractSliceIntoTransferRead,
+               TransferReadAfterWriteToBroadcast>(context);
+  // clang-format on
 }
 
 //===----------------------------------------------------------------------===//
@@ -4235,6 +4241,8 @@ public:
 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
 ///     : vector<4x5xf32>, tensor<?x?xf32>
 /// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
 struct FoldInsertSliceIntoTransferWrite
     : public OpRewritePattern<tensor::InsertSliceOp> {
 public:
@@ -4417,8 +4425,13 @@ public:
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
+  // clang-format off
+  results.add<FoldWaw,
+              // TODO: this is brittle and should be deprecated in favor of a
+              // more general pattern that applies on-demand.
+              FoldInsertSliceIntoTransferWrite,
               SwapExtractSliceOfTransferWrite>(context);
+  // clang-format on
 }
 
 //===----------------------------------------------------------------------===//
index 39c8ab9..9ac181f 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "AffineMapDetail.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
+#include <iterator>
 #include <numeric>
 #include <optional>
 #include <type_traits>
@@ -467,6 +470,15 @@ AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
   return AffineMap::inferFromExprList(newResults).front();
 }
 
+AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
+  auto exprs = llvm::to_vector<4>(getResults());
+  // TODO: this is a pretty terrible API .. is there anything better?
+  for (auto pos = positions.find_last(); pos != -1;
+       pos = positions.find_prev(pos))
+    exprs.erase(exprs.begin() + pos);
+  return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
+}
+
 AffineMap AffineMap::compose(AffineMap map) const {
   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
   // Prepare `map` by concatenating the symbols and rewriting its exprs.
@@ -808,6 +820,14 @@ llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
   return numSymbolsBitVector;
 }
 
+AffineMap
+mlir::expandDimsToRank(AffineMap map, int64_t rank,
+                       const llvm::SmallBitVector &projectedDimensions) {
+  auto id = AffineMap::getMultiDimIdentityMap(rank, map.getContext());
+  AffineMap proj = id.dropResults(projectedDimensions);
+  return map.compose(proj);
+}
+
 //===----------------------------------------------------------------------===//
 // MutableAffineMap.
 //===----------------------------------------------------------------------===//
index bcbad20..a29f86e 100644 (file)
@@ -6,7 +6,7 @@ func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1
   return %1 : f32
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0  + s1 * 3)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
 //      CHECK: func @fold_static_stride_subview_with_load
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -25,7 +25,7 @@ func.func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg
   %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
   return %1 : f32
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
 //      CHECK: func @fold_dynamic_stride_subview_with_load
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -34,8 +34,8 @@ func.func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]]
 //      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]]]
 
 // -----
@@ -66,7 +66,7 @@ func.func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %ar
   memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
   return
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
 //      CHECK: func @fold_dynamic_stride_subview_with_store
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -75,8 +75,8 @@ func.func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %ar
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]]
 //      CHECK:   memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
 
 // -----
@@ -85,7 +85,7 @@ func.func @fold_subview_with_transfer_read_0d(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
     -> vector<f32> {
   %f1 = arith.constant 1.0 : f32
-  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.transfer_read %0[], %f1 : memref<f32, strided<[], offset: ?>>, vector<f32>
   return %1 : vector<f32>
 }
@@ -100,22 +100,14 @@ func.func @fold_subview_with_transfer_read_0d(
 
 func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
   %f1 = arith.constant 1.0 : f32
+
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
   %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
   return %1 : vector<4xf32>
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
 //      CHECK: func @fold_subview_with_transfer_read
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
-//      CHECK:   vector.transfer_read %[[ARG0]][%[[I1]], %[[I2]]]
+// Can't fold this atm since we don't emit the proper vector.extract_strided_slice.
+//   CHECK: memref.subview
 
 // -----
 
@@ -123,7 +115,7 @@ func.func @fold_static_stride_subview_with_transfer_write_0d(
     %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index,
     %v : vector<f32>) {
   %f1 = arith.constant 1.0 : f32
-  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.transfer_write %v, %0[] {in_bounds = []} : vector<f32>, memref<f32, strided<[], offset: ?>>
   return
 }
@@ -143,18 +135,9 @@ func.func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf
   vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>
   return
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
 //      CHECK: func @fold_static_stride_subview_with_transfer_write
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
-//      CHECK:   vector.transfer_write %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+// Can't fold this atm since we don't emit the proper vector.extract_strided_slice.
+//   CHECK: memref.subview
 
 // -----
 
@@ -168,7 +151,7 @@ func.func @fold_rank_reducing_subview_with_load
   %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>>
   return %1 : f32
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
 //      CHECK: func @fold_rank_reducing_subview_with_load
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?x?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -187,10 +170,10 @@ func.func @fold_rank_reducing_subview_with_load
 // CHECK-SAME:   %[[ARG14:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG15:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG7]], %[[ARG1]], %[[ARG13]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG9]], %[[ARG3]], %[[ARG14]]]
-//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG10]], %[[ARG4]], %[[ARG15]]]
-//  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG11]], %[[ARG5]], %[[ARG16]]]
+//  CHECK-DAG:   %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
+//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
+//  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
 //      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
new file mode 100644 (file)
index 0000000..93a0d77
--- /dev/null
@@ -0,0 +1,262 @@
+// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file %s | FileCheck %s
+
+func.func @fold_vector_transfer_read_with_rank_reduced_extract_slice(
+    %arg0 : tensor<?x?x?xf32>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index) -> vector<4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1]
+      : tensor<?x?x?xf32> to
+        tensor<?x?xf32>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]}
+      : tensor<?x?xf32>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_vector_transfer_read_with_rank_reduced_extract_slice
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[$MAP1]]()[%[[ARG1]], %[[ARG5]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]], %{{.*}} : tensor<?x?x?xf32
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_read_from_rank_reducing_extract_slice_failure
+func.func @transfer_read_from_rank_reducing_extract_slice_failure(
+    %src: tensor<1x8x8x8xf32>,
+    %i1: index, %i2: index, %i3: index, %i4: index) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %f0 = arith.constant 0.000000e+00 : f32
+
+  // Can't fold this atm since we don' emit the proper vector.extract_strided_slice.
+//   CHECK: tensor.extract_slice
+  %0 = tensor.extract_slice %src[0, %i1, %i2, %i3] [1, 4, 1, 4] [2, 3, 4, 5] : tensor<1x8x8x8xf32> to tensor<1x4x4xf32>
+  %1 = vector.transfer_read %0[%c1, %i4, %c2], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s1]]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+// -----
+
+func.func @fold_extract_slice_with_transfer_read_0d(
+  %arg0 : tensor<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
+    -> vector<f32> {
+  %f1 = arith.constant 1.0 : f32
+  %0 = tensor.extract_slice %arg0[%arg1, %arg2][1, 1][1, 1] : tensor<12x32xf32> to tensor<f32>
+  %1 = vector.transfer_read %0[], %f1 : tensor<f32>, vector<f32>
+  return %1 : vector<f32>
+}
+//      CHECK: func @fold_extract_slice_with_transfer_read_0d
+// CHECK-SAME:   %[[T:[a-zA-Z0-9_]+]]: tensor<12x32xf32>
+// CHECK-SAME:   %[[SZ0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[SZ1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ST1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   vector.transfer_read %[[T]][%[[SZ0]], %[[SZ1]]]
+
+// -----
+
+//   CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s1]]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+  return %1 : vector<6xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[$ADD_3:.+]] = affine_map<()[s0] -> (s0 + 3)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$ADD_3]]()[%[[s1]]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+//   CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_swappy_rank_reducing(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+
+//   CHECK-NOT:   extract_slice
+//       CHECK:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s2]]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[s1]], %[[add]]]
+//  CHECK-SAME:     permutation_map = #[[$d0d2]]
+//  CHECK-SAME:     tensor<?x?x?xf32>, vector<5x6xf32>
+  %0 = tensor.extract_slice %t[5, %s1, %s2] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+//       CHECK: func @fold_vector_transfer_write_with_rank_reduced_insert_slice
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
+    %arg0 : tensor<?x?x?xf32>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index,
+    %st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+
+//   CHECK-NOT:    insert_slice
+//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?x?xf32
+  %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
+      : vector<4xf32>, tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+//       CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice(
+    %arg0 : tensor<?x?x?xf32>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index,
+    %st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+
+  //   CHECK-NOT: insert_slice
+  //   CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
+  //   CHECK-DAG:  %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+  //   CHECK-DAG:  %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+  //   CHECK-DAG:  vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
+  //  CHECK-SAME:    {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, tensor<?x?x?xf32
+  %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
+      : vector<4xf32>, tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+  %c0 = arith.constant 0 : index
+
+  //   CHECK-NOT: insert_slice
+//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+//       CHECK:   return %[[r]]
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+  return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+func.func @insert_slice_of_transfer_write_swappy_rank_extending(
+    %t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, 
+    %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = arith.constant 0 : index
+
+//   CHECK-NOT:   insert_slice
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]]
+//  CHECK-SAME:    {in_bounds = [true, true], permutation_map = #[[$d0d2]]} : vector<5x6xf32>, tensor<?x?x12xf32>
+//       CHECK:   return %[[r]]
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+//       CHECK:   return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}
index 4071d92..8538c3d 100644 (file)
@@ -5607,6 +5607,7 @@ cc_library(
     deps = [
         ":AffineDialect",
         ":ArithDialect",
+        ":ArithUtils",
         ":DialectUtils",
         ":TensorDialect",
         "//llvm:Support",
@@ -5663,6 +5664,7 @@ cc_library(
         ":TensorPassIncGen",
         ":TilingInterface",
         ":Transforms",
+        ":VectorDialect",
         "//llvm:Support",
     ],
 )