From 27451a05ed4d13294182ec7e999a9d4f90bc0d12 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 30 Sep 2021 09:25:40 +0900 Subject: [PATCH] [mlir][vector] Fold transfer ops and tensor.extract/insert_slice. * Fold vector.transfer_read and tensor.extract_slice. * Fold vector.transfer_write and tensor.insert_slice. Differential Revision: https://reviews.llvm.org/D110627 --- .../include/mlir/Dialect/StandardOps/Utils/Utils.h | 6 + mlir/include/mlir/Dialect/Vector/VectorOps.td | 1 + mlir/lib/Dialect/StandardOps/Utils/Utils.cpp | 9 ++ mlir/lib/Dialect/Vector/VectorOps.cpp | 121 ++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 66 +++++++++++ 5 files changed, 202 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h index 1d2adc6..835841e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -78,6 +78,12 @@ public: Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr); +/// Similar to the other overload, but converts multiple OpFoldResults into +/// Values. +SmallVector +getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, + ArrayRef valueOrAttrVec); + /// Helper struct to build simple arithmetic quantities with minimal type /// inference support. 
struct ArithBuilder { diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 911a9c6..f24da79 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1277,6 +1277,7 @@ def Vector_TransferReadOp : "ArrayAttr":$inBounds)> ]; + let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp index 3f66738..d52b8dc 100644 --- a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp +++ b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp @@ -58,6 +58,15 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, return b.create(loc, attr.getValue().getSExtValue()); } +SmallVector +mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, + ArrayRef valueOrAttrVec) { + return llvm::to_vector<4>( + llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { + return getValueOrCreateConstantIndexOp(b, loc, value); + })); +} + Value ArithBuilder::_and(Value lhs, Value rhs) { return b.create(loc, lhs, rhs); } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 0b8f0ef..4ce8011 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/VectorUtils.h" @@ -2649,6 +2650,74 @@ void TransferReadOp::getEffects( SideEffects::DefaultResource::get()); } +namespace { +/// Fold transfer_reads of a tensor.extract_slice op. 
E.g.: +/// +/// ``` +/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] +/// : tensor to tensor +/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} +/// : tensor, vector<4x5xf32> +/// ``` +/// is rewritten to: +/// ``` +/// %p0 = addi %a, %e : index +/// %p1 = addi %b, %f : index +/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} +/// : tensor, vector<4x5xf32> +/// ``` +struct FoldExtractSliceIntoTransferRead + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TransferReadOp xferOp, + PatternRewriter &rewriter) const override { + if (xferOp.hasOutOfBoundsDim()) + return failure(); + if (!xferOp.permutation_map().isIdentity()) + return failure(); + if (xferOp.mask()) + return failure(); + auto extractOp = xferOp.source().getDefiningOp(); + if (!extractOp) + return failure(); + if (!extractOp.hasUnitStride()) + return failure(); + + int64_t rankReduced = + extractOp.getSourceType().getRank() - extractOp.getType().getRank(); + SmallVector newIndices; + // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced + // indices first. 
+ for (int64_t i = 0; i < rankReduced; ++i) { + OpFoldResult offset = extractOp.getMixedOffsets()[i]; + newIndices.push_back(getValueOrCreateConstantIndexOp( + rewriter, extractOp.getLoc(), offset)); + } + for (auto it : llvm::enumerate(xferOp.indices())) { + OpFoldResult offset = + extractOp.getMixedOffsets()[it.index() + rankReduced]; + newIndices.push_back( + rewriter.create(xferOp->getLoc(), it.value(), + getValueOrCreateConstantIndexOp( + rewriter, extractOp.getLoc(), offset))); + } + SmallVector inBounds(xferOp.getTransferRank(), true); + rewriter.replaceOpWithNewOp(xferOp, xferOp.getVectorType(), + extractOp.source(), newIndices, + xferOp.padding(), inBounds); + + return success(); + } +}; +} // namespace + +void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// @@ -2958,11 +3027,61 @@ public: return failure(); } }; + +/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write +/// could directly write to the insert_slice's destination. 
E.g.: +/// +/// ``` +/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} +/// : vector<4x5xf32>, tensor<4x5xf32> +/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] +/// : tensor<4x5xf32> into tensor +/// ``` +/// is rewritten to: +/// ``` +/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]} +/// : vector<4x5xf32>, tensor +/// ``` +struct FoldInsertSliceIntoTransferWrite + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, + PatternRewriter &rewriter) const override { + if (!insertOp.hasUnitStride()) + return failure(); + auto xferOp = insertOp.source().getDefiningOp(); + if (!xferOp) + return failure(); + if (xferOp.hasOutOfBoundsDim()) + return failure(); + if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank()) + return failure(); + if (xferOp.mask()) + return failure(); + // Fold only if the TransferWriteOp completely overwrites the `source` with + // a vector. I.e., the result of the TransferWriteOp is a new tensor whose + // content is the data of the vector. 
+ if (!llvm::equal(xferOp.getVectorType().getShape(), + xferOp.getShapedType().getShape())) + return failure(); + if (!xferOp.permutation_map().isIdentity()) + return failure(); + + SmallVector indices = getValueOrCreateConstantIndexOp( + rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); + SmallVector inBounds(xferOp.getTransferRank(), true); + rewriter.replaceOpWithNewOp( + insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds); + return success(); + } +}; } // namespace void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 2abcef9..8b3674e 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -960,3 +960,69 @@ func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>, %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> return %0, %1 : vector<4xf32>, vector<1x1x4xf32> } + +// ----- + +// CHECK-LABEL: func @transfer_read_of_extract_slice( +// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-DAG: %[[c4:.*]] = constant 4 : index +// CHECK-DAG: %[[c8:.*]] = constant 8 : index +// CHECK: %[[add:.*]] = addi %[[s1]], %[[c4]] +// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor, vector<5x6xf32> +// CHECK: return %[[r]] +func @transfer_read_of_extract_slice(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { + %c3 = constant 3 : index + %c4 = constant 4 : index + %cst = constant 0.0 : f32 + %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor to tensor<10x?xf32> + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32> + return %1 : 
vector<5x6xf32> +} + +// ----- + +// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing( +// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-DAG: %[[c3:.*]] = constant 3 : index +// CHECK-DAG: %[[c5:.*]] = constant 5 : index +// CHECK-DAG: %[[c10:.*]] = constant 10 : index +// CHECK: %[[add:.*]] = addi %[[s1]], %[[c3]] +// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor, vector<5x6xf32> +// CHECK: return %[[r]] +func @transfer_read_of_extract_slice_rank_reducing(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { + %c3 = constant 3 : index + %c4 = constant 4 : index + %cst = constant 0.0 : f32 + %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor to tensor + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor, vector<5x6xf32> + return %1 : vector<5x6xf32> +} + +// ----- + +// CHECK-LABEL: func @insert_slice_of_transfer_write( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index +// CHECK: %[[c3:.*]] = constant 3 : index +// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor +// CHECK: return %[[r]] +func @insert_slice_of_transfer_write(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { + %c0 = constant 0 : index + %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> + %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index +// CHECK-DAG: %[[c3:.*]] = constant 3 : index +// CHECK-DAG: %[[c4:.*]] = constant 4 : index +// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], 
%[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor +// CHECK: return %[[r]] +func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { + %c0 = constant 0 : index + %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> + %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor + return %1 : tensor +} -- 2.7.4