[mlir] optionally allow repeated handles in transform dialect
authorAlex Zinenko <zinenko@google.com>
Thu, 15 Dec 2022 18:48:09 +0000 (18:48 +0000)
committerAlex Zinenko <zinenko@google.com>
Mon, 19 Dec 2022 09:02:03 +0000 (09:02 +0000)
Some operations may be able to deal with handles pointing to the same
operation when the handle is consumed. For example, merge handles with
deduplication doesn't actually destroy payload operations and is
specifically intended to remove the situation with duplicates. Add a
method to the transform interface to allow ops to declare they can
support repeated handles.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D140124

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/expensive-checks.mlir

index 3378153..c760cb1 100644 (file)
@@ -48,6 +48,19 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
           "::mlir::transform::TransformResults &":$transformResults,
           "::mlir::transform::TransformState &":$state
     )>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Indicates whether the op instance allows its handle operands to be
+        associated with the same payload operations.
+      }],
+      /*returnType=*/"bool",
+      /*name=*/"allowsRepeatedHandleOperands",
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return false;
+      }]
+    >,
   ];
 
   let extraSharedClassDeclaration = [{
index f6788de..c813a64 100644 (file)
@@ -210,7 +210,7 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
 }
 
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
-    [DeclareOpInterfaceMethods<TransformOpInterface>,
+    [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      SameOperandsAndResultType]> {
   let summary = "Merges handles into one pointing to the union of payload ops";
index 9b136cc..10a4381 100644 (file)
@@ -189,7 +189,8 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
   for (OpOperand &target : transform->getOpOperands()) {
     // If the operand uses an invalidated handle, report it.
     auto it = invalidatedHandles.find(target.get());
-    if (it != invalidatedHandles.end())
+    if (!transform.allowsRepeatedHandleOperands() &&
+        it != invalidatedHandles.end())
       return it->getSecond()(transform->getLoc()), failure();
 
     // Invalidate handles pointing to the operations nested in the operation
@@ -201,6 +202,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
     if (llvm::any_of(effects, consumesTarget))
       recordHandleInvalidation(target);
   }
+
   return success();
 }
 
index f7a8ee1..629b1f6 100644 (file)
@@ -449,6 +449,11 @@ transform::MergeHandlesOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
+  // Handles may be the same if deduplicating is enabled.
+  return getDeduplicate();
+}
+
 void transform::MergeHandlesOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getHandles(), effects);
index 550ede6..abc09d7 100644 (file)
@@ -99,3 +99,17 @@ module {
     transform.test_consume_operand %1, %2
   }
 }
+
+// -----
+
+// Deduplication attribute allows "merge_handles" to take repeated operands.
+
+module {
+
+  transform.sequence failures(propagate) {
+  ^bb0(%0: !pdl.operation):
+    %1 = transform.test_copy_payload %0
+    %2 = transform.test_copy_payload %0
+    transform.merge_handles %1, %2 { deduplicate } : !pdl.operation
+  }
+}