[mlir] [linalg] Only promote selected buffers.
authorPierre Oechsel <pierre.oechsel@gmail.com>
Tue, 21 Apr 2020 09:43:28 +0000 (11:43 +0200)
committerAlex Zinenko <zinenko@google.com>
Tue, 21 Apr 2020 09:50:08 +0000 (11:50 +0200)
The promotion transformation is promoting all input and output buffers of the transformed op. The user might want to only promote some of these buffers.

Differential Revision: https://reviews.llvm.org/D78498

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td

index 7fa33e4..2eaed14 100644 (file)
@@ -114,4 +114,9 @@ def PreconditionPromoteSubviewsLinalgOp : CPred<
   "succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
 def PromoteSubviewsLinalgOp : NativeCodeCall<
   "promoteSubviewsLinalgOp($_builder, op)">;
+
+class PromoteSelectedSubviewsLinalgOp<list<int> operands, string marker=""> :
+  NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" #
+    StrJoinInt<operands>.result # "}, \"" # marker # "\")">;
+
 #endif // LINALG_TRANSFORMS
index c65909e..e7a8925 100644 (file)
@@ -121,6 +121,14 @@ LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
 SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
                                               Operation *op);
 
+/// Similar to `promoteSubviewsLinalgOp` but only tries to promote
+/// the views corresponding to the operands specified in
+/// `operandIndicesToPromote`.
+/// If linalgMarker is specified and the transformation is successfull
+/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
+SmallVector<Value, 0> promoteSelectedSubviewsLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op,
+    ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker = "");
 } // namespace linalg
 } // namespace mlir
 
index 5b3618d..e96ee27 100644 (file)
@@ -338,6 +338,24 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
   assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
          "DRR failure case must be a precondition");
 
+  LinalgOp linOp = cast<LinalgOp>(op);
+  SmallVector<int64_t, 4> toPromote;
+  int64_t nBuffers = linOp.getNumInputsAndOutputBuffers();
+  toPromote.reserve(nBuffers);
+  for (int64_t i = 0; i < nBuffers; ++i)
+    toPromote.push_back(i);
+  return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote);
+}
+
+SmallVector<Value, 0> mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op,
+    ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
+                    << *op << ":\n");
+
+  assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
+         "DRR failure case must be a precondition");
+
   if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
     // TODO(ntv): add a level of indirection to linalg.generic.
     if (convOp.padding())
@@ -348,11 +366,16 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
   assert(linOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   SetVector<Value> subViews;
-  for (auto it : linOp.getInputsAndOutputBuffers())
-    if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
+  for (int64_t index : operandIndicesToPromote)
+    if (auto sv =
+            dyn_cast_or_null<SubViewOp>(linOp.getBuffer(index).getDefiningOp()))
       subViews.insert(sv);
+
   if (!subViews.empty()) {
-    promoteSubViewOperands(rewriter, linOp, subViews);
+    auto newOp = promoteSubViewOperands(rewriter, linOp, subViews);
+    if (!linalgMarker.empty())
+      newOp.setAttr(LinalgTransforms::kLinalgTransformMarker,
+                    rewriter.getStringAttr(linalgMarker));
     return {};
   }
   llvm_unreachable("DRR failure case must be a precondition");
index 7f76819..3e8230c 100644 (file)
@@ -395,3 +395,53 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK      :         linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK      :         linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK      :         linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+
+func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                             %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                             %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  %c2000 = constant 2000 : index
+  %c3000 = constant 3000 : index
+  %c4000 = constant 4000 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  loop.for %arg3 = %c0 to %0 step %c2000 {
+    loop.for %arg4 = %c0 to %2 step %c3000 {
+      loop.for %arg5 = %c0 to %1 step %c4000 {
+        %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
+             memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
+             memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
+             memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} :
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>
+      }
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @promote_first_subview_matmul
+// CHECK:   loop.for {{.*}} = %c0 to {{.*}} step %c2000 {
+// CHECK:     loop.for {{.*}} = %c0 to {{.*}} step %c3000 {
+// CHECK:       loop.for {{.*}} = %c0 to {{.*}} step %c4000 {
+// CHECK:         %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK:         %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK:         %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK:         %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK:         %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK:         %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map:.*]]>
+// CHECK-NOT:     %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK-NOT:     %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK-NOT:     %[[l0:.*]] = subview %[[v1]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
+// CHECK-NOT:     %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK-NOT:     %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK-NOT:     %[[l0:.*]] = subview %[[v2]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
+// CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK-NOT:     linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK-NOT:     linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>^
+// CHECK:         linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref<?x?xf32>, memref<?x?xf32, #[[map]]>, memref<?x?xf32, #[[map]]>
index a55cdbf..8444f4c 100644 (file)
@@ -149,4 +149,12 @@ def : Pat<(MatmulOp:$op $_, $_, $_),
               HasLinalgTransformMarker<"_promote_views_">]>>
            )]>;
 
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
+          [(Constraint<And<[
+              PreconditionPromoteSubviewsLinalgOp,
+              HasOperandsOfType<"SubViewOp">,
+              HasLinalgTransformMarker<"_promote_first_view_">]>>
+           )]>;
+
 #endif // TEST_LINALG_TRANSFORMS_PATTERNS