These patterns follow FoldMemRefAliasOps, which is further refactored for reuse.
In the process, fix the previously incorrect handling of strides in FoldMemRefAliasOps for vector.transfer ops.
The new opt-in patterns generalize the existing canonicalizations on vector.transfer ops.
In the future, the blanket canonicalizations will be retired; they are kept for now to minimize porting disruptions.
Differential Revision: https://reviews.llvm.org/D146624
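For reference, a minimal sketch of the rewrite the new opt-in patterns perform, distilled from the fold-tensor-subset-ops.mlir test added below (values such as %t, %s1, %c3, %c4 mirror that test):

```mlir
// Before: the transfer_read goes through a unit-stride extract_slice.
%0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1]
    : tensor<?x?xf32> to tensor<10x?xf32>
%1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]}
    : tensor<10x?xf32>, vector<5x6xf32>

// After -fold-tensor-subset-ops: the slice is folded away and the indices are
// rebased onto the original tensor (5 + 3 = 8 and %s1 + 4).
%c8 = arith.constant 8 : index
%i1 = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%s1]
%1 = vector.transfer_read %t[%c8, %i1], %cst {in_bounds = [true, true]}
    : tensor<?x?xf32>, vector<5x6xf32>
```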
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
+class RewriterBase;
/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
/// when combining a producer slice **into** a consumer slice.
/// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
/// - Combined sizes = consumer_sizes
/// - Combined strides = producer_strides * consumer_strides
+// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
LogicalResult
mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> producerOffsets,
/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
/// when combining a `producer` slice op **into** a `consumer` slice op.
+// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
LogicalResult
mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
OffsetSizeAndStrideOpInterface producer,
SmallVector<OpFoldResult> &combinedSizes,
SmallVector<OpFoldResult> &combinedStrides);
+/// Given the `indicesVals` of a load/store-like operation operating on the
+/// result of an op with offsets and strides, fill `sourceIndices` with the
+/// indices resolved onto the underlying source.
+///
+/// For example, using `memref.load` and `memref.subview` as an illustration:
+///
+/// ```
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.subview %0[%arg0, %arg1][...][%stride1, %stride2] :
+/// memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+/// ```
+///
+/// could be folded into:
+///
+/// ```
+/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
+/// memref<12x42xf32>
+/// ```
+void resolveSourceIndicesOffsetsAndStrides(
+ RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedStrides,
+ const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
+ SmallVectorImpl<Value> &sourceIndices);
+
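For the rank-reduced case, the helper inserts a zero index for every dropped source dimension before applying the offset + index * stride computation. A sketch in the same informal notation as the comment above (the middle subview dimension has unit size and is dropped):

```mlir
%0 = memref.subview %src[%o0, %o1, %o2] [4, 1, 4] [%s0, %s1, %s2]
    : memref<?x?x?xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
%1 = memref.load %0[%i0, %i1] : memref<4x4xf32, strided<[?, ?], offset: ?>>

// resolves to indices on %src; the dropped dim gets %o1 + 0 * %s1 = %o1:
%1 = load %src[%o0 + %i0 * %s0][%o1][%o2 + %i1 * %s2] : memref<?x?x?xf32>
```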
} // namespace mlir
#endif // MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
return {rank, rank, rank};
}
+ /// Return the dimensions of the dest that have no counterpart in the source,
+ /// i.e. the dimensions along which the insertion is rank-extending.
+ llvm::SmallBitVector getDroppedDims();
+
/// Return the number of leading operands before the `offsets`, `sizes` and
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
#include "mlir/Pass/Pass.h"
namespace mlir {
+namespace tensor {
-#define GEN_PASS_DECL
-#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
-/// Creates an instance of `tensor` dialect bufferization pass.
+/// Creates an instance of the `tensor` subset folding pass.
+std::unique_ptr<Pass> createFoldTensorSubsetOpsPass();
+
+/// Creates an instance of the `tensor` dialect bufferization pass.
std::unique_ptr<Pass> createTensorBufferizePass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
-namespace tensor {
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
-} // namespace tensor
+} // namespace tensor
} // namespace mlir
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
include "mlir/Pass/PassBase.td"
+def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> {
+ let summary = "Fold tensor subset ops into producer/consumer ops";
+ let description = [{
+ The pass folds tensor subset ops into producer/consumer ops.
+
+ At the moment, the following foldings occur when possible:
+ - tensor.extract_slice into vector.transfer_read
+ - vector.transfer_write into tensor.insert_slice
+
+ }];
+ let constructor = "mlir::tensor::createFoldTensorSubsetOpsPass()";
+ let dependentDialects = [
+ "AffineDialect", "tensor::TensorDialect", "vector::VectorDialect"
+ ];
+}
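As a sketch of the second folding listed in the description (mirroring the insert_slice_of_transfer_write test added below):

```mlir
// Before: the write lands in a small tensor that is then inserted with
// unit strides.
%0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]}
    : vector<5x6xf32>, tensor<5x6xf32>
%1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1]
    : tensor<5x6xf32> into tensor<?x12xf32>

// After the pass: the write targets the insert_slice destination directly.
%c3 = arith.constant 3 : index
%1 = vector.transfer_write %v, %t1[%c3, %s] {in_bounds = [true, true]}
    : vector<5x6xf32>, tensor<?x12xf32>
```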
+
def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> {
let summary = "Bufferize the `tensor` dialect";
- let constructor = "mlir::createTensorBufferizePass()";
+ let constructor = "mlir::tensor::createTensorBufferizePass()";
}
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
namespace tensor {
-/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
-/// to separate the cases where we don't need padding (all pad sizes are
-/// actually zeros) and where we indeed need padding.
-void populateSplitPaddingPatterns(RewritePatternSet &patterns,
- PatternBenefit baseBenefit = 1);
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
/// Pattern to swap a `tensor.extract_slice` with its producer when the
/// producer implements the `TilingInterface`. The pattern itself does not
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
+//===----------------------------------------------------------------------===//
+// Populate functions.
+//===----------------------------------------------------------------------===//
+
+/// Collects a set of patterns to rewrite ops within the tensor dialect.
+void populateExpandOpsPatterns(RewritePatternSet &patterns);
+
+/// Appends to `patterns` the patterns that fold tensor aliasing ops into
+/// consumer load/store ops.
+void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
+
+/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
+/// to separate the cases where we don't need padding (all pad sizes are
+/// actually zeros) and where we indeed need padding.
+void populateSplitPaddingPatterns(RewritePatternSet &patterns,
+ PatternBenefit baseBenefit = 1);
+
/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
/// into one. These patterns are in this separate entry point because the
/// bufferization is sensitive to IR structure, particularly those
/// Returns a new AffineMap with the same number of dims and symbols and one
/// less result at `pos`, dropped.
- AffineMap dropResult(int64_t pos) { return dropResults({pos}); }
+ AffineMap dropResult(int64_t pos) const { return dropResults({pos}); }
// Returns a new AffineMap with the same number of dims and symbols, but all
- // positions in `positions` dropped from results.
- AffineMap dropResults(ArrayRef<int64_t> positions) {
+ // results in `positions` dropped.
+ AffineMap dropResults(ArrayRef<int64_t> positions) const {
SmallVector<int64_t> reverse_sorted_positions = llvm::to_vector(positions);
llvm::sort(reverse_sorted_positions, std::greater<int64_t>());

auto exprs = llvm::to_vector<4>(getResults());
for (int64_t pos : reverse_sorted_positions)
  exprs.erase(exprs.begin() + pos);
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
}
+ // Returns a new AffineMap with the same number of dims and symbols, but all
+ // results in `positions` dropped.
+ AffineMap dropResults(const llvm::SmallBitVector &positions) const;
+
/// Returns a new AffineMap with the same number of dims and symbols and an
/// extra result inserted at `pos`.
- AffineMap insertResult(AffineExpr expr, unsigned pos) {
+ AffineMap insertResult(AffineExpr expr, unsigned pos) const {
auto exprs = llvm::to_vector<4>(getResults());
exprs.insert(exprs.begin() + pos, expr);
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
// by any of the maps in the input array `maps`.
llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps);
+/// Expand `map` to operate on `rank` dims while projecting out the dims in
+/// `projectedDimensions`. This amounts to composing `map` with
+/// `id(rank).dropResults(projectedDimensions)`.
+AffineMap expandDimsToRank(AffineMap map, int64_t rank,
+ const llvm::SmallBitVector &projectedDimensions);
+
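A small sketch of what expandDimsToRank computes, written as affine map aliases (the names #map, #proj and #expanded are purely illustrative):

```mlir
// Map operating on the (rank-reduced) 2-D space.
#map = affine_map<(d0, d1) -> (d1, d0)>
// id(3) with result 1 dropped, i.e. projectedDimensions = {1}.
#proj = affine_map<(d0, d1, d2) -> (d0, d2)>
// expandDimsToRank(#map, 3, {1}) == #map.compose(#proj).
#expanded = affine_map<(d0, d1, d2) -> (d2, d0)>
```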
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
return os;
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
combinedOffsets, combinedSizes, combinedStrides);
}
+
+void mlir::resolveSourceIndicesOffsetsAndStrides(
+ RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedStrides,
+ const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
+ SmallVectorImpl<Value> &sourceIndices) {
+ OpFoldResult zero = rewriter.getIndexAttr(0);
+
+ // For each dimension that is rank-reduced, add a zero to the indices.
+ int64_t indicesDim = 0;
+ SmallVector<OpFoldResult> indices;
+ for (auto dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
+ OpFoldResult ofr =
+ (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++];
+ indices.push_back(ofr);
+ }
+
+ sourceIndices.resize(indices.size());
+ sourceIndices.clear();
+ for (auto [offset, index, stride] :
+ llvm::zip_equal(mixedOffsets, indices, mixedStrides)) {
+ AffineExpr off, idx, str;
+ bindSymbols(rewriter.getContext(), off, idx, str);
+ OpFoldResult ofr = makeComposedFoldedAffineApply(
+ rewriter, loc, AffineMap::get(0, 3, off + idx * str),
+ {offset, index, stride});
+ sourceIndices.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ }
+}
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
return success();
}
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-/// memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
- memref::SubViewOp subViewOp, ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
- SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
- SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
-
- SmallVector<Value> useIndices;
- // Check if this is rank-reducing case. Then for every unit-dim size add a
- // zero to the indices.
- int64_t resultDim = 0;
- llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
- for (auto dim : llvm::seq<int64_t>(0, subViewOp.getSourceType().getRank())) {
- if (unusedDims.test(dim))
- useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
- else
- useIndices.push_back(indices[resultDim++]);
- }
- if (useIndices.size() != mixedOffsets.size())
- return failure();
- sourceIndices.resize(useIndices.size());
- for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
- SmallVector<OpFoldResult> dynamicOperands;
- AffineExpr expr = rewriter.getAffineDimExpr(0);
- int64_t numSymbols = 0;
- dynamicOperands.push_back(useIndices[index]);
-
- // Multiply the stride;
- if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
- expr = expr * attr.cast<IntegerAttr>().getInt();
- } else {
- dynamicOperands.push_back(mixedStrides[index].get<Value>());
- expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
- }
-
- // Add the offset.
- if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
- expr = expr + attr.cast<IntegerAttr>().getInt();
- } else {
- dynamicOperands.push_back(mixedOffsets[index].get<Value>());
- expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
- }
- Location loc = subViewOp.getLoc();
- OpFoldResult ofr = makeComposedFoldedAffineApply(
- rewriter, loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
- sourceIndices[index] = getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
- }
- return success();
-}
-
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
return op.getDstMemref();
}
-/// Given the permutation map of the original
-/// `vector.transfer_read`/`vector.transfer_write` operations compute the
-/// permutation map to use after the subview is folded with it.
-static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
- memref::SubViewOp subViewOp,
- AffineMap currPermutationMap) {
- llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
- SmallVector<AffineExpr> exprs;
- int64_t sourceRank = subViewOp.getSourceType().getRank();
- for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
- if (unusedDims.test(dim))
- continue;
- exprs.push_back(getAffineDimExpr(dim, context));
- }
- auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
- return AffineMapAttr::get(
- currPermutationMap.compose(resultDimToSourceDimMap));
-}
-
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
return expandedIndices;
}
+template <typename XferOp>
+static LogicalResult
+preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
+ memref::SubViewOp subviewOp) {
+  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value,
+                "must be a vector transfer op");
+ if (xferOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
+ if (xferOp.getMask())
+ return rewriter.notifyMatchFailure(xferOp, "masked transfer");
+ if (!subviewOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(
+ xferOp, "non-1 stride subview, need to track strides in folded memref");
+ }
+ return success();
+}
+
+static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
+ Operation *op,
+ memref::SubViewOp subviewOp) {
+ return success();
+}
+
+static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
+ vector::TransferReadOp readOp,
+ memref::SubViewOp subviewOp) {
+ return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
+}
+
+static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
+ vector::TransferWriteOp writeOp,
+ memref::SubViewOp subviewOp) {
+ return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
+}
+
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
OpTy loadOp, PatternRewriter &rewriter) const {
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp)
- return failure();
+ return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
+
+ LogicalResult preconditionResult =
+ preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
+ if (failed(preconditionResult))
+ return preconditionResult;
SmallVector<Value> indices(loadOp.getIndices().begin(),
loadOp.getIndices().end());
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
- if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
- indices, sourceIndices)))
- return failure();
+ resolveSourceIndicesOffsetsAndStrides(
+ rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+ sourceIndices);
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](AffineLoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
})
- .Case([&](vector::TransferReadOp transferReadOp) {
+ .Case([&](vector::TransferReadOp op) {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- transferReadOp, transferReadOp.getVectorType(),
- subViewOp.getSource(), sourceIndices,
- getPermutationMapAttr(rewriter.getContext(), subViewOp,
- transferReadOp.getPermutationMap()),
- transferReadOp.getPadding(),
- /*mask=*/Value(), transferReadOp.getInBoundsAttr());
+ op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
+ AffineMapAttr::get(expandDimsToRank(
+ op.getPermutationMap(), subViewOp.getSourceType().getRank(),
+ subViewOp.getDroppedDims())),
+ op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp)
- return failure();
+ return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
+
+ LogicalResult preconditionResult =
+ preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
+ if (failed(preconditionResult))
+ return preconditionResult;
SmallVector<Value> indices(storeOp.getIndices().begin(),
storeOp.getIndices().end());
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
- if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
- indices, sourceIndices)))
- return failure();
+ resolveSourceIndicesOffsetsAndStrides(
+ rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+ sourceIndices);
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](AffineStoreOp op) {
.Case([&](vector::TransferWriteOp op) {
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, op.getValue(), subViewOp.getSource(), sourceIndices,
- getPermutationMapAttr(rewriter.getContext(), subViewOp,
- op.getPermutationMap()),
+ AffineMapAttr::get(expandDimsToRank(
+ op.getPermutationMap(), subViewOp.getSourceType().getRank(),
+ subViewOp.getDroppedDims())),
op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
};
} // namespace
+llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
+ ArrayRef<int64_t> resultShape = getType().getShape();
+ SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+ llvm::SmallBitVector droppedDims(mixedSizes.size());
+ unsigned shapePos = 0;
+ for (const auto &size : enumerate(mixedSizes)) {
+ std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
+ // If the size is not 1, or if the current matched dimension of the result
+ // is the same static shape as the size value (which is 1), then the
+ // dimension is preserved.
+ if (!sizeVal || *sizeVal != 1 ||
+ (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
+ shapePos++;
+ continue;
+ }
+ droppedDims.set(size.index());
+ }
+ return droppedDims;
+}
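As an illustration, for the rank-extending insert_slice in the insert_slice_of_transfer_write_rank_extending test added with this patch, the unit size at position 0 has no matching source dimension, so getDroppedDims() reports destination dim 0 as dropped:

```mlir
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1]
    : tensor<5x6xf32> into tensor<?x?x12xf32>
```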
+
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
};
} // namespace
-std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
+std::unique_ptr<Pass> mlir::tensor::createTensorBufferizePass() {
return std::make_unique<TensorBufferizePass>();
}
EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
FoldIntoPackAndUnpackPatterns.cpp
+ FoldTensorSubsetOps.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
ReshapePatterns.cpp
SplitPaddingPatterns.cpp
MLIRTensorDialect
MLIRTilingInterface
MLIRTransforms
+ MLIRVectorDialect
)
--- /dev/null
+//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold tensor subset ops with producer / consumers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace tensor {
+#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
+#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
+} // namespace tensor
+} // namespace mlir
+
+using namespace mlir;
+
+static Value getTensorOperand(vector::TransferReadOp op) {
+ return op.getSource();
+}
+
+static Value getTensorOperand(tensor::InsertSliceOp op) {
+ return op.getSource();
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Merge a tensor.extract_slice into a consuming vector.transfer_read op.
+class TransferReadOfExtractSliceOpFolder final
+ : public OpRewritePattern<vector::TransferReadOp> {
+public:
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merge a vector.transfer_write into a consuming tensor.insert_slice op.
+class InsertSliceOfTransferWriteOpFolder final
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+template <typename XferOp, typename ExtractOrInsertOp>
+static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
+ RewriterBase &rewriter, XferOp xferOp,
+ ExtractOrInsertOp extractOrInsertSliceOp) {
+ if (xferOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
+ if (xferOp.getMask())
+ return rewriter.notifyMatchFailure(xferOp, "masked transfer");
+ if (!extractOrInsertSliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(
+ xferOp, "non-1 stride insert/extract, requires keeping track of "
+ "strides, this may result in needing to insert "
+ "vector.insert_strided_slice/extract_strided_slice ops");
+ }
+ return success();
+}
+
+LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
+ vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
+ auto extractSliceOp =
+ getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp)
+ return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
+
+ LogicalResult preconditionResult =
+ preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
+ extractSliceOp);
+ if (failed(preconditionResult))
+ return preconditionResult;
+
+ SmallVector<Value> indices(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ SmallVector<Value> sourceIndices;
+ resolveSourceIndicesOffsetsAndStrides(
+ rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
+ extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
+ indices, sourceIndices);
+
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
+ AffineMapAttr::get(expandDimsToRank(
+ readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
+ extractSliceOp.getDroppedDims())),
+ readOp.getPadding(),
+ /*mask=*/Value(), readOp.getInBoundsAttr());
+
+ return success();
+}
+
+LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
+ tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
+ auto writeOp = getTensorOperand(insertSliceOp)
+ .template getDefiningOp<vector::TransferWriteOp>();
+ if (!writeOp)
+ return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
+
+ LogicalResult preconditionResult =
+ preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
+ insertSliceOp);
+ if (failed(preconditionResult))
+ return preconditionResult;
+
+ SmallVector<Value> indices(writeOp.getIndices().begin(),
+ writeOp.getIndices().end());
+ SmallVector<Value> sourceIndices;
+ resolveSourceIndicesOffsetsAndStrides(
+ rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
+ insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
+ sourceIndices);
+
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
+ AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
+ insertSliceOp.getDestType().getRank(),
+ insertSliceOp.getDroppedDims())),
+ writeOp.getInBoundsAttr());
+
+ return success();
+}
+
+void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
+ patterns.add<TransferReadOfExtractSliceOpFolder,
+ InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
+}
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct FoldTensorSubsetOpsPass final
+ : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void FoldTensorSubsetOpsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ tensor::populateFoldTensorSubsetOpPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
+ return std::make_unique<FoldTensorSubsetOpsPass>();
+}
namespace {
/// Merges consecutive tensor.extract_slice ops into one.
+// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
};
/// Merges consecutive tensor.insert_slice ops into one.
+// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
template <typename OpTy>
struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
+ MLIRArithUtils
MLIRIR
MLIRTensorDialect
)
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
using namespace mlir;
/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
/// : tensor<?x?xf32>, vector<4x5xf32>
/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
struct FoldExtractSliceIntoTransferRead
: public OpRewritePattern<TransferReadOp> {
public:
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
- context);
+ // clang-format off
+  results.add<
+ // TODO: this is brittle and should be deprecated in favor of a
+ // more general pattern that applies on-demand.
+ FoldExtractSliceIntoTransferRead,
+ TransferReadAfterWriteToBroadcast>(context);
+ // clang-format on
}
//===----------------------------------------------------------------------===//
/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
/// : vector<4x5xf32>, tensor<?x?xf32>
/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
struct FoldInsertSliceIntoTransferWrite
: public OpRewritePattern<tensor::InsertSliceOp> {
public:
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
+ // clang-format off
+ results.add<FoldWaw,
+ // TODO: this is brittle and should be deprecated in favor of a
+ // more general pattern that applies on-demand.
+ FoldInsertSliceIntoTransferWrite,
SwapExtractSliceOfTransferWrite>(context);
+ // clang-format on
}
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "AffineMapDetail.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
+#include <iterator>
#include <numeric>
#include <optional>
#include <type_traits>
return AffineMap::inferFromExprList(newResults).front();
}
+AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
+ auto exprs = llvm::to_vector<4>(getResults());
+  // TODO: this is a pretty terrible API; is there anything better?
+ for (auto pos = positions.find_last(); pos != -1;
+ pos = positions.find_prev(pos))
+ exprs.erase(exprs.begin() + pos);
+ return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
+}
+
AffineMap AffineMap::compose(AffineMap map) const {
assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
// Prepare `map` by concatenating the symbols and rewriting its exprs.
return numSymbolsBitVector;
}
+AffineMap
+mlir::expandDimsToRank(AffineMap map, int64_t rank,
+ const llvm::SmallBitVector &projectedDimensions) {
+ auto id = AffineMap::getMultiDimIdentityMap(rank, map.getContext());
+ AffineMap proj = id.dropResults(projectedDimensions);
+ return map.compose(proj);
+}
+
//===----------------------------------------------------------------------===//
// MutableAffineMap.
//===----------------------------------------------------------------------===//
return %1 : f32
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
// CHECK: func @fold_static_stride_subview_with_load
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
%1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
return %1 : f32
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
// CHECK: func @fold_dynamic_stride_subview_with_load
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]]
// CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]]]
// -----
memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
return
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
// CHECK: func @fold_dynamic_stride_subview_with_store
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]]
// CHECK: memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
// -----
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
-> vector<f32> {
%f1 = arith.constant 1.0 : f32
- %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.transfer_read %0[], %f1 : memref<f32, strided<[], offset: ?>>, vector<f32>
return %1 : vector<f32>
}
func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
%f1 = arith.constant 1.0 : f32
+
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
%1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
return %1 : vector<4xf32>
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
// CHECK: func @fold_subview_with_transfer_read
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
-// CHECK: vector.transfer_read %[[ARG0]][%[[I1]], %[[I2]]]
+// Can't fold this atm since we don't emit the proper vector.extract_strided_slice.
+// CHECK: memref.subview
// -----
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index,
%v : vector<f32>) {
%f1 = arith.constant 1.0 : f32
- %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.transfer_write %v, %0[] {in_bounds = []} : vector<f32>, memref<f32, strided<[], offset: ?>>
return
}
vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>
return
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
// CHECK: func @fold_static_stride_subview_with_transfer_write
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
-// CHECK: vector.transfer_write %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+// Can't fold this atm since we don't emit the proper vector.extract_strided_slice.
+// CHECK: memref.subview
// -----
%1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>>
return %1 : f32
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
// CHECK: func @fold_rank_reducing_subview_with_load
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG14:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG15:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG16:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG7]], %[[ARG1]], %[[ARG13]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG9]], %[[ARG3]], %[[ARG14]]]
-// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG10]], %[[ARG4]], %[[ARG15]]]
-// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG11]], %[[ARG5]], %[[ARG16]]]
+// CHECK-DAG: %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
+// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
+// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
// CHECK: memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
// -----
--- /dev/null
+// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file %s | FileCheck %s
+
+func.func @fold_vector_transfer_read_with_rank_reduced_extract_slice(
+ %arg0 : tensor<?x?x?xf32>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index) -> vector<4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1]
+ : tensor<?x?x?xf32> to
+ tensor<?x?xf32>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]}
+ : tensor<?x?xf32>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_vector_transfer_read_with_rank_reduced_extract_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[$MAP1]]()[%[[ARG1]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]], %{{.*}} : tensor<?x?x?xf32
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_read_from_rank_reducing_extract_slice_failure
+func.func @transfer_read_from_rank_reducing_extract_slice_failure(
+ %src: tensor<1x8x8x8xf32>,
+ %i1: index, %i2: index, %i3: index, %i4: index) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %f0 = arith.constant 0.000000e+00 : f32
+
+  // Can't fold this atm since we don't emit the proper vector.extract_strided_slice.
+// CHECK: tensor.extract_slice
+ %0 = tensor.extract_slice %src[0, %i1, %i2, %i3] [1, 4, 1, 4] [2, 3, 4, 5] : tensor<1x8x8x8xf32> to tensor<1x4x4xf32>
+ %1 = vector.transfer_read %0[%c1, %i4, %c2], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s1]]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+// -----
+
+func.func @fold_extract_slice_with_transfer_read_0d(
+ %arg0 : tensor<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
+ -> vector<f32> {
+ %f1 = arith.constant 1.0 : f32
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2][1, 1][1, 1] : tensor<12x32xf32> to tensor<f32>
+ %1 = vector.transfer_read %0[], %f1 : tensor<f32>, vector<f32>
+ return %1 : vector<f32>
+}
+// CHECK: func @fold_extract_slice_with_transfer_read_0d
+// CHECK-SAME: %[[T:[a-zA-Z0-9_]+]]: tensor<12x32xf32>
+// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: index
+// CHECK: vector.transfer_read %[[T]][%[[SZ0]], %[[SZ1]]]
+
+// -----
+
+// CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s1]]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+ return %1 : vector<6xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$ADD_3:.+]] = affine_map<()[s0] -> (s0 + 3)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: %[[add:.*]] = affine.apply #[[$ADD_3]]()[%[[s1]]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_swappy_rank_reducing(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+
+// CHECK-NOT: extract_slice
+// CHECK: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s2]]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[s1]], %[[add]]]
+// CHECK-SAME: permutation_map = #[[$d0d2]]
+// CHECK-SAME: tensor<?x?x?xf32>, vector<5x6xf32>
+ %0 = tensor.extract_slice %t[5, %s1, %s2] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK: func @fold_vector_transfer_write_with_rank_reduced_insert_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
+ %arg0 : tensor<?x?x?xf32>,
+ %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index,
+ %st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+
+// CHECK-NOT: insert_slice
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?x?xf32
+ %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
+ : vector<4xf32>, tensor<?x?xf32>
+ %1 = tensor.insert_slice %0 into %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice(
+ %arg0 : tensor<?x?x?xf32>,
+ %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index,
+ %st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+
+ // CHECK-NOT: insert_slice
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+ // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+ // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
+ // CHECK-SAME: {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, tensor<?x?x?xf32
+ %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
+ : vector<4xf32>, tensor<?x?xf32>
+ %1 = tensor.insert_slice %0 into %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+ %c0 = arith.constant 0 : index
+
+ // CHECK-NOT: insert_slice
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+// CHECK: return %[[r]]
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+ return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+func.func @insert_slice_of_transfer_write_swappy_rank_extending(
+ %t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>,
+ %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = arith.constant 0 : index
+
+// CHECK-NOT: insert_slice
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]]
+// CHECK-SAME: {in_bounds = [true, true], permutation_map = #[[$d0d2]]} : vector<5x6xf32>, tensor<?x?x12xf32>
+// CHECK: return %[[r]]
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}
deps = [
":AffineDialect",
":ArithDialect",
+ ":ArithUtils",
":DialectUtils",
":TensorDialect",
"//llvm:Support",
":TensorPassIncGen",
":TilingInterface",
":Transforms",
+ ":VectorDialect",
"//llvm:Support",
],
)