[mlir] allow repeated payload in structured.fuse_into_containing
authorAlex Zinenko <zinenko@google.com>
Mon, 15 May 2023 12:28:21 +0000 (12:28 +0000)
committerAlex Zinenko <zinenko@google.com>
Mon, 15 May 2023 14:30:19 +0000 (14:30 +0000)
Structured fusion proceeds by iteratively finding the next suitable
producer to be fused into the loop. Therefore, it shouldn't matter if
the same producer is listed multiple times (e.g., it is used as multiple
operands). Adjust the implementation of the transform op to support this
case.

Also fix the checking code in the interpreter to actually respect the
TransformOpInterface indication that repeated payload is allowed, it
seems to have been accidentally dropped in one of the refactorings.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D150561

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

index cdeabb7..c7bc376 100644 (file)
@@ -143,7 +143,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
 
 def FuseIntoContainingOp :
     Op<Transform_Dialect, "structured.fuse_into_containing_op",
-      [DeclareOpInterfaceMethods<TransformOpInterface>,
+      [DeclareOpInterfaceMethods<TransformOpInterface,
+          ["allowsRepeatedHandleOperands"]>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Fuse a producer into a containing operation.";
 
@@ -160,7 +161,7 @@ def FuseIntoContainingOp :
     producer op handle may be associated with multiple payload ops. This
     transform fuses producers one-by-one, always picking an unspecified producer
     that has at least one use inside the containing op among the
-    producers.
+    producers. A producer can be listed multiple times in the handle.
 
     Note: If a producer has multiple uses inside the containing op, it is
     currently tiled and/or cloned multiple times into the containing op.
@@ -176,8 +177,8 @@ def FuseIntoContainingOp :
     containing op. I.e., "producers" that are not consumed within the containing
     op are rejected by this operation.
 
-    This operation reads and frees the producer handle.
-    This operation reads the containing op handle.
+    This operation consumes the producer handle.
+    This operation only reads the containing op handle.
   }];
 
   let arguments = (ins PDL_Operation:$producer_op,
index afef599..0703ca3 100644 (file)
@@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
   return fusedOp;
 }
 
+bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
+  // Allow repeated handles since we are fusing everything anyway.
+  return true;
+}
+
 DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
@@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 
   // Helper function to find the next producer that should be fused. Take any
   // producer that has a use inside the containing op.
-  SmallVector<Operation *> remainingProducers(producerOps.begin(),
-                                              producerOps.end());
+  SetVector<Operation *> remainingProducers(producerOps.begin(),
+                                            producerOps.end());
   auto getNextProducer = [&]() -> FailureOr<Operation *> {
     for (const auto &it : enumerate(remainingProducers)) {
       Operation *producerOp = it.value();
index bad1d74..5685187 100644 (file)
@@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG("--handle not consumed -> SKIP\n");
         continue;
       }
+      if (transform.allowsRepeatedHandleOperands()) {
+        FULL_LDBG("--op allows repeated handles -> SKIP\n");
+        continue;
+      }
       FULL_LDBG("--handle is consumed\n");
 
       Type operandType = operand.get().getType();
index d6b3ff3..537ee86 100644 (file)
@@ -247,3 +247,39 @@ module {
     transform.structured.fuse_into_containing_op %0 into %1
   }
 }
+
+// -----
+
+module {
+  // CHECK-LABEL: func.func @fuse_repeated
+  func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
+    %c0 = arith.constant 0.0 : f32
+    %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>
+
+    // CHECK: scf.forall
+    %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
+      %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+      %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+      // CHECK: %[[FUSED:.+]] = linalg.fill
+      // CHECK: elemwise_unary ins(%[[FUSED]]
+      %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
+      }
+    }
+
+    return %1 : tensor<2xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+
+    // Create a new handle that points to `linalg.fill` twice.
+    %2 = transform.merge_handles %0, %0 : !pdl.operation
+
+    // It shouldn't be a problem to fuse this handle.
+    transform.structured.fuse_into_containing_op %2 into %1
+  }
+}