#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
SideEffects::DefaultResource::get());
}
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+/// : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+/// : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = addi %a, %e : index
+/// %p1 = addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+/// : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+struct FoldExtractSliceIntoTransferRead
+ : public OpRewritePattern<TransferReadOp> {
+public:
+ using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransferReadOp xferOp,
+ PatternRewriter &rewriter) const override {
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
+ if (!xferOp.permutation_map().isIdentity())
+ return failure();
+ if (xferOp.mask())
+ return failure();
+ auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ return failure();
+ if (!extractOp.hasUnitStride())
+ return failure();
+
+ int64_t rankReduced =
+ extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+ SmallVector<Value> newIndices;
+ // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+ // indices first.
+ for (int64_t i = 0; i < rankReduced; ++i) {
+ OpFoldResult offset = extractOp.getMixedOffsets()[i];
+ newIndices.push_back(getValueOrCreateConstantIndexOp(
+ rewriter, extractOp.getLoc(), offset));
+ }
+ for (auto it : llvm::enumerate(xferOp.indices())) {
+ OpFoldResult offset =
+ extractOp.getMixedOffsets()[it.index() + rankReduced];
+ newIndices.push_back(
+ rewriter.create<AddIOp>(xferOp->getLoc(), it.value(),
+ getValueOrCreateConstantIndexOp(
+ rewriter, extractOp.getLoc(), offset)));
+ }
+ SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+ rewriter.replaceOpWithNewOp<TransferReadOp>(xferOp, xferOp.getVectorType(),
+ extractOp.source(), newIndices,
+ xferOp.padding(), inBounds);
+
+ return success();
+ }
+};
+} // namespace
+
+void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldExtractSliceIntoTransferRead>(context);
+}
+
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
return failure();
}
};
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+/// : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+/// : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+/// : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+struct FoldInsertSliceIntoTransferWrite
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+ PatternRewriter &rewriter) const override {
+ if (!insertOp.hasUnitStride())
+ return failure();
+ auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
+ if (!xferOp)
+ return failure();
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
+ if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+ return failure();
+ if (xferOp.mask())
+ return failure();
+ // Fold only if the TransferWriteOp completely overwrites the `source` with
+ // a vector. I.e., the result of the TransferWriteOp is a new tensor who's
+ // content is the data of the vector.
+ if (!llvm::equal(xferOp.getVectorType().getShape(),
+ xferOp.getShapedType().getShape()))
+ return failure();
+ if (!xferOp.permutation_map().isIdentity())
+ return failure();
+
+ SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+ rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+ SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+ rewriter.replaceOpWithNewOp<TransferWriteOp>(
+ insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds);
+ return success();
+ }
+};
} // namespace
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<foldWAW>(context);
+ results.add<foldWAW, FoldInsertSliceIntoTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = constant 8 : index
+// CHECK: %[[add:.*]] = addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %cst = constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = constant 3 : index
+// CHECK-DAG: %[[c5:.*]] = constant 5 : index
+// CHECK-DAG: %[[c10:.*]] = constant 10 : index
+// CHECK: %[[add:.*]] = addi %[[s1]], %[[c3]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %cst = constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK: %[[c3:.*]] = constant 3 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+// CHECK: return %[[r]]
+func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+ %c0 = constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+ return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = constant 3 : index
+// CHECK-DAG: %[[c4:.*]] = constant 4 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+// CHECK: return %[[r]]
+func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}