[mlir] make `fuse_into_containing_op` preserve the containing op handle
authorAlex Zinenko <zinenko@google.com>
Fri, 26 May 2023 11:26:52 +0000 (11:26 +0000)
committerAlex Zinenko <zinenko@google.com>
Fri, 26 May 2023 16:01:40 +0000 (16:01 +0000)
This partially undoes the intent of https://reviews.llvm.org/D151418 by
cheating its way to keep the "containing op" (aka loop) handle read-only
in fusion. It is crucial to do so for composability of tiling and
fusion. Specfically, after the "containing op" handle started being
consumed, it became impossible to perform additional tiling after fusion
except tiling the last-fused op:

  %tiled1, %loop1 = tile %op
  %producer1, %loop2 = fuse %producer into %loop1
  // invalid, because %tiled1 is invalidated by consuming %loop1
  // that points to its parent
  tile %tiled1

or

  %tiled1, %loop1 = tile %op
  %tiled2, %loop2 = tile %tiled1
  %p2 = fuse %producer into %loop1
  // invalid, because %loop2 is invalidated by consuming %loop1
  // that points to its parent
  fuse %p2 into %loop2

The approach here makes creative use of the state extension mechanism to
update the payload operation associted with the operand handle. Further
investigation is necessary to understand if is consistent with the
overall execution model of the transform dialect, but it is crucial to
restore composability ASAP.

Reviewed By: springerm, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D151555

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-ops.mlir

index 52699db..f18f24d 100644 (file)
@@ -34,6 +34,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -663,6 +664,36 @@ bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
   return true;
 }
 
+namespace {
+/// Unsafely exposes an internal protected method of TransformState::Extension
+/// as public.
+///
+/// MUST NOT be used directly.
+class UnsafeOpReplacementStateExtension : public TransformState::Extension {
+public:
+  UnsafeOpReplacementStateExtension(TransformState &state)
+      : TransformState::Extension(state) {}
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      UnsafeOpReplacementStateExtension)
+
+  LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
+    return replacePayloadOp(op, replacement);
+  }
+};
+} // namespace
+
+/// Replaces `payload` with `replacement` in all handles stored in the state.
+/// MUST NOT be used except for the case immediately below.
+static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
+                                                      Operation *payload,
+                                                      Operation *replacement) {
+  UnsafeOpReplacementStateExtension extension(state);
+  // This may return failure if the payload is not associated with any handle,
+  // ignore that.
+  (void)extension.doReplacePayloadOp(payload, replacement);
+}
+
 DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
@@ -757,6 +788,14 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
+  // Update handles associated with the containing op so we don't need to
+  // invalidate them. This is a hack to support better composability between
+  // tiling and fusion while a proper mechanism is being investigated.
+  //
+  // DO NOT replicate this elsewhere unless you understand what you are doing.
+  forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
+                                            containingOp);
+
   results.set(cast<OpResult>(getFusedOp()), fusedOps);
   results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
   return DiagnosedSilenceableFailure::success();
@@ -765,7 +804,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 void transform::FuseIntoContainingOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getProducerOp(), effects);
-  consumesHandle(getContainingOp(), effects);
+  onlyReadsHandle(getContainingOp(), effects);
   producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
index dd85008..8d1c280 100644 (file)
@@ -35,3 +35,15 @@ transform.sequence failures(propagate) {
   // CHECK: transform.structured.scalarize
   %0 = transform.structured.scalarize %arg0 : (!transform.any_op) -> !transform.any_op
 }
+
+// Check that the second argument of `fuse_into_containing_op` is not consumed
+// (if it had been, we would have seen a diagnostic about multiple consumers).
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  %loop = transform.structured.match ops{["scf.forall"]} in %arg0
+    : (!transform.any_op) -> !transform.any_op
+  %0:2 = transform.structured.fuse_into_containing_op %arg1 into %loop
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %1:2 = transform.structured.fuse_into_containing_op %arg2 into %loop
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+}