[mlir][MemRef][Transform] Don't apply multibuffer on "useless" allocs
authorQuentin Colombet <quentin.colombet@gmail.com>
Fri, 10 Feb 2023 16:21:38 +0000 (17:21 +0100)
committerQuentin Colombet <quentin.colombet@gmail.com>
Mon, 13 Feb 2023 13:19:10 +0000 (14:19 +0100)
`alloc`s that have users outside of loops are guaranteed to fail in
`multibuffer`.

Instead of exposing ourselves to that failure in the transform dialect,
filter out the `alloc`s that fall in this category.

To be able to do this filtering we have to change the `multibuffer`
transform op from `TransformEachOpTrait` to a plain `TransformOp`. This is
because `TransformEachOpTrait` expects that every successful `applyToOne`
returns a non-empty result.

Couple of notes:
- I changed the assembly syntax to make sure we only get `alloc` ops as
  input. (And added a test case to make sure we reject invalid inputs.)
- `multibuffer` can still fail pretty easily when you know its limitations.
  See the updated `op failed to multibuffer` test case for instance.
  Longer term, instead of leaking/coupling the actual implementation (in
  this case the checks normally done in `memref::multiBuffer`) with the
  transform dialect (the added check in `::apply`), we may want to refactor
  how we structure the underlying implementation. E.g., we could imagine a
  `canApply` method for all the implementations that we want to hook up in
  the transform dialect.
  This has some implications on how not to duplicate work between
  `canApply` and the actual implementation but I thought I throw that here
  to have us think about it :).

Differential Revision: https://reviews.llvm.org/D143747

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/test/Dialect/MemRef/transform-ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index f16fe8a..c22cec8 100644 (file)
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
+
 def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait]> {
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Multibuffers an allocation";
   let description = [{
      Transformation to do multi-buffering/array expansion to remove
@@ -33,19 +36,13 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
   }];
 
   let arguments =
-      (ins PDL_Operation:$target,
+      (ins Transform_MemRefAllocOp:$target,
            ConfinedAttr<I64Attr, [IntPositive]>:$factor);
 
   let results = (outs PDL_Operation:$transformed);
 
-  let assemblyFormat = "$target attr-dict";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        memref::AllocOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type(operands, results)";
 }
 
 #endif // MEMREF_TRANSFORM_OPS
index 11e8d25..b98db40 100644 (file)
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
   MLIRArithDialect
   MLIRIR
   MLIRPDLDialect
+  MLIRLoopLikeInterface
   MLIRMemRefDialect
   MLIRMemRefTransforms
   MLIRTransformDialect
index 1b53686..85356d2 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 
 using namespace mlir;
 
@@ -21,15 +22,33 @@ using namespace mlir;
 // MemRefMultiBufferOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne(
-    memref::AllocOp target, transform::ApplyToEachResultList &results,
+DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
+    transform::TransformResults &transformResults,
     transform::TransformState &state) {
-  auto newBuffer = memref::multiBuffer(target, getFactor());
-  if (failed(newBuffer))
-    return emitSilenceableFailure(target->getLoc())
-           << "op failed to multibuffer";
+  SmallVector<Operation *> results;
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  for (auto *op : payloadOps) {
+    bool canApplyMultiBuffer = true;
+    auto target = cast<memref::AllocOp>(op);
+    // Skip allocations not used in a loop.
+    for (Operation *user : target->getUsers()) {
+      auto loop = user->getParentOfType<LoopLikeOpInterface>();
+      if (!loop) {
+        canApplyMultiBuffer = false;
+        break;
+      }
+    }
+    if (!canApplyMultiBuffer)
+      continue;
 
-  results.push_back(*newBuffer);
+    auto newBuffer = memref::multiBuffer(target, getFactor());
+    if (failed(newBuffer))
+      return emitSilenceableFailure(target->getLoc())
+             << "op failed to multibuffer";
+
+    results.push_back(*newBuffer);
+  }
+  transformResults.set(getResult().cast<OpResult>(), results);
   return DiagnosedSilenceableFailure::success();
 }
 
index ddb6407..df25fec 100644 (file)
@@ -30,26 +30,102 @@ func.func @multi_buffer(%in: memref<16xf32>) {
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
   // Verify that the returned handle is usable.
   transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
 }
 
 // -----
 
-// Trying to use multibuffer on alloc that are used outside of loops is
-// going to fail.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer_on_affine_loop
+func.func @multi_buffer_on_affine_loop(%in: memref<16xf32>) {
+  // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+  // expected-remark @below {{transformed}}
+  %tmp = memref.alloc() : memref<4xf32>
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  // CHECK: affine.for %[[IV:.*]] = 0
+  affine.for %i0 = 0 to 16 step 4 {
+    // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]])
+    // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+    // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+    memref.copy %1, %tmp :  memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+    "some_use"(%tmp) : (memref<4xf32>) ->()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
+  // Verify that the returned handle is usable.
+  transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
+}
+
+// -----
+
+// Trying to use multibuffer on allocs that are used in different loops
+// with none dominating the other is going to fail.
 // Check that we emit a proper error for that.
-func.func @multi_buffer_uses_outside_of_loop(%in: memref<16xf32>) {
+func.func @multi_buffer_uses_with_no_loop_dominator(%in: memref<16xf32>, %cond: i1) {
   // expected-error @below {{op failed to multibuffer}}
   %tmp = memref.alloc() : memref<4xf32>
 
-  "some_outside_loop_use"(%tmp) : (memref<4xf32>) -> ()
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c16 = arith.constant 16 : index
+  scf.if %cond {
+    scf.for %i0 = %c0 to %c16 step %c4 {
+      %var = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+      memref.copy %var, %tmp :  memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+      "some_use"(%tmp) : (memref<4xf32>) ->()
+    }
+  }
+
+  scf.for %i0 = %c0 to %c16 step %c4 {
+    %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+    memref.copy %1, %tmp :  memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+    "some_use"(%tmp) : (memref<4xf32>) ->()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
+}
+
+// -----
+
+// Make sure the multibuffer operation is typed so that it only supports
+// memref.alloc.
+// Check that we emit an error if we try to match something else.
+func.func @multi_buffer_reject_alloca(%in: memref<16xf32>, %cond: i1) {
+  %tmp = memref.alloca() : memref<4xf32>
 
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
   %c16 = arith.constant 16 : index
+  scf.if %cond {
+    scf.for %i0 = %c0 to %c16 step %c4 {
+      %var = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+      memref.copy %var, %tmp :  memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+      "some_use"(%tmp) : (memref<4xf32>) ->()
+    }
+  }
 
   scf.for %i0 = %c0 to %c16 step %c4 {
     %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
@@ -62,6 +138,50 @@ func.func @multi_buffer_uses_outside_of_loop(%in: memref<16xf32>) {
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+  %0 = transform.structured.match ops{["memref.alloca"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloca">
+  // expected-error @below {{'transform.memref.multibuffer' op operand #0 must be Transform IR handle to memref.alloc operations, but got '!transform.op<"memref.alloca">'}}
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloca">) -> !pdl.operation
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer_one_alloc_with_use_outside_of_loop
+// Make sure we manage to apply multi_buffer to the memref that is used in
+// the loop (%tmp) and don't error out for the one that is not (%tmp2).
+func.func @multi_buffer_one_alloc_with_use_outside_of_loop(%in: memref<16xf32>) {
+  // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+  // expected-remark @below {{transformed}}
+  %tmp = memref.alloc() : memref<4xf32>
+  %tmp2 = memref.alloc() : memref<4xf32>
+
+  "some_use_outside_of_loop"(%tmp2) : (memref<4xf32>) -> ()
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c16 = arith.constant 16 : index
+
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]]
+  scf.for %i0 = %c0 to %c16 step %c4 {
+    // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]])
+    // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+    // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+    memref.copy %1, %tmp :  memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+    "some_use"(%tmp) : (memref<4xf32>) ->()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
+  // Verify that the returned handle is usable.
+  transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
 }
index 7bf3791..75db8e9 100644 (file)
@@ -9954,6 +9954,7 @@ cc_library(
         ":AffineDialect",
         ":ArithDialect",
         ":IR",
+        ":LoopLikeInterface",
         ":MemRefDialect",
         ":MemRefTransformOpsIncGen",
         ":MemRefTransforms",