From 12dcb89dadf4f37f7781ce687ab06b202e1b8ba3 Mon Sep 17 00:00:00 2001 From: Pierre Oechsel Date: Tue, 21 Apr 2020 11:43:28 +0200 Subject: [PATCH] [mlir] [linalg] Only promote selected buffers. The promotion transformation is promoting all input and output buffers of the transformed op. The user might want to only promote some of these buffers. Differential Revision: https://reviews.llvm.org/D78498 --- .../Linalg/Transforms/LinalgTransformPatterns.td | 5 +++ .../Dialect/Linalg/Transforms/LinalgTransforms.h | 8 ++++ .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 29 +++++++++++-- mlir/test/Dialect/Linalg/transform-patterns.mlir | 50 ++++++++++++++++++++++ .../TestLinalgTransformPatterns.td | 8 ++++ 5 files changed, 97 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 7fa33e4..2eaed14 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -114,4 +114,9 @@ def PreconditionPromoteSubviewsLinalgOp : CPred< "succeeded(promoteSubviewsLinalgOpPrecondition(op))">; def PromoteSubviewsLinalgOp : NativeCodeCall< "promoteSubviewsLinalgOp($_builder, op)">; + +class PromoteSelectedSubviewsLinalgOp operands, string marker=""> : + NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" # + StrJoinInt.result # "}, \"" # marker # "\")">; + #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index c65909e..e7a8925 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -121,6 +121,14 @@ LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op); SmallVector promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op); +/// Similar to `promoteSubviewsLinalgOp` but only tries to promote +/// the views corresponding to the operands specified in +/// `operandIndicesToPromote`. +/// If linalgMarker is specified and the transformation is successfull +/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. +SmallVector promoteSelectedSubviewsLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, + ArrayRef operandIndicesToPromote, StringRef linalgMarker = ""); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 5b3618d..e96ee27 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -338,6 +338,24 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); + LinalgOp linOp = cast(op); + SmallVector toPromote; + int64_t nBuffers = linOp.getNumInputsAndOutputBuffers(); + toPromote.reserve(nBuffers); + for (int64_t i = 0; i < nBuffers; ++i) + toPromote.push_back(i); + return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote); +} + +SmallVector mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, + ArrayRef operandIndicesToPromote, StringRef linalgMarker) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " + << *op << ":\n"); + + assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && + "DRR failure case must be a precondition"); + if (auto convOp = dyn_cast(op)) { // TODO(ntv): add a level of indirection to linalg.generic. if (convOp.padding()) @@ -348,11 +366,16 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, assert(linOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); SetVector subViews; - for (auto it : linOp.getInputsAndOutputBuffers()) - if (auto sv = dyn_cast_or_null(it.getDefiningOp())) + for (int64_t index : operandIndicesToPromote) + if (auto sv = + dyn_cast_or_null(linOp.getBuffer(index).getDefiningOp())) subViews.insert(sv); + if (!subViews.empty()) { - promoteSubViewOperands(rewriter, linOp, subViews); + auto newOp = promoteSubViewOperands(rewriter, linOp, subViews); + if (!linalgMarker.empty()) + newOp.setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); return {}; } llvm_unreachable("DRR failure case must be a precondition"); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 7f76819..3e8230c 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -395,3 +395,53 @@ func @promote_subview_matmul(%arg0: memref, // CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref, memref // CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref, memref // CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref, memref, memref + +func @promote_first_subview_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c2000 = constant 2000 : index + %c3000 = constant 3000 : index + %c4000 = constant 4000 : index + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg0, 0 : memref + %1 = dim %arg0, 1 : memref + %2 = dim %arg1, 1 : memref + loop.for %arg3 = %c0 to %0 step %c2000 { + loop.for %arg4 = %c0 to %2 step %c3000 { + loop.for %arg5 = %c0 to %1 step %c4000 { + %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref to memref + %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref to memref + %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref to memref + linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} : + memref, + memref, + memref + } + } + } + return +} +// CHECK-LABEL: func @promote_first_subview_matmul +// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c2000 { +// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c3000 { +// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c4000 { +// CHECK: %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[a0:.*]] = alloc({{%.*}}) : memref +// CHECK: %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK-NOT: %[[a1:.*]] = alloc({{%.*}}) : memref +// CHECK-NOT: %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}] : memref to memref +// CHECK-NOT: %[[l0:.*]] = subview %[[v1]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK-NOT: %[[a2:.*]] = alloc({{%.*}}) : memref +// CHECK-NOT: %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}] : memref to memref +// CHECK-NOT: %[[l0:.*]] = subview %[[v2]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref +// CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref, memref +// CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref, memref^ +// CHECK: linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref, memref, memref diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index a55cdbf..8444f4c 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -149,4 +149,12 @@ def : Pat<(MatmulOp:$op $_, $_, $_), HasLinalgTransformMarker<"_promote_views_">]>> )]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">), + [(Constraint, + HasLinalgTransformMarker<"_promote_first_view_">]>> + )]>; + #endif // TEST_LINALG_TRANSFORMS_PATTERNS -- 2.7.4