From ecd9dc0499880f2a89e4e03e9ffd3b368fe7e7ff Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 19 Sep 2022 02:04:39 -0700 Subject: [PATCH] [mlir][Transform] Add a new navigation op to retrieve the producer of an operand Given an opOperand uniquely determined by the operation `%op` and the operand number `num`, the `transform.get_producer_of_operand %op[num]` returns the handle to the unique operation that produced the SSA value used as opOperand. The transform fails if the operand is a block argument. Differential Revision: https://reviews.llvm.org/D134171 --- .../mlir/Dialect/Transform/IR/TransformOps.td | 19 +++++++++++++ mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 30 ++++++++++++++++++++ mlir/test/Dialect/Transform/test-interpreter.mlir | 33 ++++++++++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 361973d..99408eb 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -169,6 +169,25 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent let assemblyFormat = "$target attr-dict"; } +def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand", + [DeclareOpInterfaceMethods, + NavigationTransformOpTrait, MemoryEffectsOpInterface]> { + let summary = "Get handle to the producer of this operation's operand number"; + let description = [{ + The handle defined by this Transform op corresponds to operation that + produces the SSA value defined by the `target` and `operand_number` + arguments. If the origin of the SSA value is not an operations (i.e. it is + a block argument), the transform silently fails. + The return handle points to only the subset of successfully produced + computational operations, which can be empty. + }]; + + let arguments = (ins PDL_Operation:$target, + I64Attr:$operand_number); + let results = (outs PDL_Operation:$parent); + let assemblyFormat = "$target `[` $operand_number `]` attr-dict"; +} + def MergeHandlesOp : TransformDialectOp<"merge_handles", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index ecab661..b9f9d44 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -387,6 +387,36 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( } //===----------------------------------------------------------------------===// +// GetProducerOfOperand +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::GetProducerOfOperand::apply(transform::TransformResults &results, + transform::TransformState &state) { + int64_t operandNumber = getOperandNumber(); + SmallVector producers; + for (Operation *target : state.getPayloadOps(getTarget())) { + Operation *producer = + target->getNumOperands() <= operandNumber + ? nullptr + : target->getOperand(operandNumber).getDefiningOp(); + if (!producer) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "could not find a producer for operand number: " << operandNumber + << " of " << *target; + diag.attachNote(target->getLoc()) << "target op"; + results.set(getResult().cast(), + SmallVector{}); + return diag; + } + producers.push_back(producer); + } + results.set(getResult().cast(), producers); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// // MergeHandlesOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 494c616..0a5688f 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -727,3 +727,36 @@ transform.with_pdl_patterns { transform.test_print_remark_at_operand %results, "transform applied" } } + +// ----- + +func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { + // expected-remark @below {{found muli}} + %0 = arith.muli %arg0, %arg1 : index + arith.addi %0, %arg1 : index + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %addi = transform.structured.match ops{["arith.addi"]} in %arg1 + %muli = get_producer_of_operand %addi[0] + transform.test_print_remark_at_operand %muli, "found muli" +} + +// ----- + +func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { + // expected-note @below {{target op}} + %0 = arith.muli %arg0, %arg1 : index + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %muli = transform.structured.match ops{["arith.muli"]} in %arg1 + // expected-error @below {{could not find a producer for operand number: 0 of}} + %bbarg = get_producer_of_operand %muli[0] + +} + -- 2.7.4