"$target attr-dict `:` functional-type(operands, results)";
}
+def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+ let summary = "Get handle to the consumers of this operation's result number";
+ let description = [{
+ The handle defined by this Transform op corresponds to all operations that
+ consume the SSA value defined by the `target` and `result_number`
+ arguments.
+ This operation applies to a single payload operation, otherwise it
+ definitely fails.
+ The return handle points to the consuming operations operations, which can
+ be empty.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ I64Attr:$result_number);
+ let results = (outs TransformHandleTypeInterface:$consumers);
+ let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+ "functional-type(operands, results)";
+}
+
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
}
//===----------------------------------------------------------------------===//
+// GetConsumersOfResult
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::GetConsumersOfResult::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ int64_t resultNumber = getResultNumber();
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+ if (payloadOps.empty()) {
+ results.set(getResult().cast<OpResult>(), {});
+ return DiagnosedSilenceableFailure::success();
+ }
+ if (payloadOps.size() != 1)
+ return emitDefiniteFailure()
+ << "handle must be mapped to exactly one payload op";
+
+ Operation *target = payloadOps.front();
+ if (target->getNumResults() <= resultNumber)
+ return emitDefiniteFailure() << "result number overflow";
+ results.set(getResult().cast<OpResult>(),
+ llvm::to_vector(target->getResult(resultNumber).getUsers()));
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
// GetProducerOfOperand
//===----------------------------------------------------------------------===//
// -----
+func.func @get_consumer(%arg0: index, %arg1: index) {
+ %0 = arith.muli %arg0, %arg1 : index
+ // expected-remark @below {{found addi}}
+ arith.addi %0, %arg1 : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+ %addi = get_consumers_of_result %muli[0] : (!pdl.operation) -> !pdl.operation
+ transform.test_print_remark_at_operand %addi, "found addi" : !pdl.operation
+}
+
+// -----
+
+func.func @get_consumer_fail_1(%arg0: index, %arg1: index) {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.muli %arg0, %arg1 : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+ // expected-error @below {{handle must be mapped to exactly one payload op}}
+ %bbarg = get_consumers_of_result %muli[0] : (!pdl.operation) -> !pdl.operation
+
+}
+
+// -----
+
+func.func @get_consumer_fail_2(%arg0: index, %arg1: index) {
+ %0 = arith.muli %arg0, %arg1 : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+ // expected-error @below {{result number overflow}}
+ %bbarg = get_consumers_of_result %muli[1] : (!pdl.operation) -> !pdl.operation
+
+}
+
+// -----
+
func.func @split_handles(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index