[mlir][Transform] Fix dropReverseMapping early exit condition
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 13 Oct 2022 15:13:25 +0000 (08:13 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 13 Oct 2022 15:30:45 +0000 (08:30 -0700)
Previously, the erasure would not trigger and result in surprising behavior.

Differential Revision: https://reviews.llvm.org/D135881

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

index 2810444ea864d26b5b1aca445e18334e2c6c2367..9b85af35783e72a138ef8f9c23d77c0b31657169 100644 (file)
@@ -92,7 +92,7 @@ transform::TransformState::setPayloadOps(Value value,
 void transform::TransformState::dropReverseMapping(Mappings &mappings,
                                                    Operation *op, Value value) {
   auto it = mappings.reverse.find(op);
-  if (it != mappings.reverse.end())
+  if (it == mappings.reverse.end())
     return;
 
   llvm::erase_value(it->getSecond(), value);
index 735491a05aa682715e70425db13998c372ee7a38..c7d02d2bb9341a75f7bbf0645a48b0b81f8ca69e 100644 (file)
@@ -895,3 +895,28 @@ transform.with_pdl_patterns {
     transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
   }
 }
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 : !pdl.operation failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
+    // here, the handles nested under are {%arg0, %arg1, %0}
+    // expected-remark @below {{3 handles nested under}}
+    transform.test_report_number_of_tracked_handles_nested_under %arg1
+    // expected-remark @below {{erased}}
+    transform.test_emit_remark_and_erase_operand %0, "erased"
+    // here, the handles nested under are only {%arg0, %arg1}
+    // expected-remark @below {{2 handles nested under}}
+    transform.test_report_number_of_tracked_handles_nested_under %arg1
+  }
+
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operation "test.some_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+}
+
+"test.some_op"() : () -> ()
index 6a39a2e2df6297f5effeace60b974b88a763f656..b890af57f8d008137a83b0b4a7e6c1eb6f2464c0 100644 (file)
@@ -328,6 +328,26 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  int64_t count = 0;
+  for (Operation *op : state.getPayloadOps(getTarget())) {
+    op->walk([&](Operation *nested) {
+      SmallVector<Value> handles;
+      (void)state.getHandlesForPayloadOp(nested, handles);
+      count += handles.size();
+    });
+  }
+  emitRemark() << count << " handles nested under";
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
index aaec014ac660afd8a20cd6acba8e38b3514b807f..9ca267565b3a1504fbd0ae6914930bd0e5806746 100644 (file)
@@ -253,4 +253,13 @@ def TestCopyPayloadOp
   let assemblyFormat = "$handle attr-dict";
 }
 
+def TestReportNumberOfTrackedHandlesNestedUnder
+  : Op<Transform_Dialect, "test_report_number_of_tracked_handles_nested_under",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins PDL_Operation:$target);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD