[mlir][Transform] Add a transform.get_consumers_of_result navigation op
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 17 Jan 2023 14:25:21 +0000 (06:25 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 17 Jan 2023 14:58:38 +0000 (06:58 -0800)
Differential Revision: https://reviews.llvm.org/D141930

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir

index 72a819c..d72d38a 100644 (file)
@@ -189,6 +189,27 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
     "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to the consumers of this operation's result number";
+  let description = [{
+    The handle defined by this Transform op corresponds to all operations that
+    consume the SSA value defined by the `target` and `result_number`
+    arguments.
+    This operation applies to a single payload operation, otherwise it 
+    definitely fails.
+    The return handle points to the consuming operations operations, which can
+    be empty.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$result_number);
+  let results = (outs TransformHandleTypeInterface:$consumers);
+  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
 def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
index 0b3391e..f711419 100644 (file)
@@ -400,6 +400,31 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
 }
 
 //===----------------------------------------------------------------------===//
+// GetConsumersOfResult
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::GetConsumersOfResult::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  int64_t resultNumber = getResultNumber();
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  if (payloadOps.empty()) {
+    results.set(getResult().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (payloadOps.size() != 1)
+    return emitDefiniteFailure()
+           << "handle must be mapped to exactly one payload op";
+
+  Operation *target = payloadOps.front();
+  if (target->getNumResults() <= resultNumber)
+    return emitDefiniteFailure() << "result number overflow";
+  results.set(getResult().cast<OpResult>(),
+              llvm::to_vector(target->getResult(resultNumber).getUsers()));
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
 // GetProducerOfOperand
 //===----------------------------------------------------------------------===//
 
index da48fe2..30d1155 100644 (file)
@@ -775,6 +775,53 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @get_consumer(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 : index
+  // expected-remark @below {{found addi}}
+  arith.addi %0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  %addi = get_consumers_of_result %muli[0] : (!pdl.operation) -> !pdl.operation
+  transform.test_print_remark_at_operand %addi, "found addi" : !pdl.operation
+}
+
+// -----
+
+func.func @get_consumer_fail_1(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 : index
+  %1 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{handle must be mapped to exactly one payload op}}
+  %bbarg = get_consumers_of_result %muli[0] : (!pdl.operation) -> !pdl.operation
+
+}
+
+// -----
+
+func.func @get_consumer_fail_2(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{result number overflow}}
+  %bbarg = get_consumers_of_result %muli[1] : (!pdl.operation) -> !pdl.operation
+
+}
+
+// -----
+
 func.func @split_handles(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index