From 4cf936d01f36adaa268fdc1411f27be197434cb2 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 21 Feb 2023 08:57:16 +0100 Subject: [PATCH] [mlir][transform] Add transform.get_defining_op op This op is the inverse of `transform.get_result`. Differential Revision: https://reviews.llvm.org/D144409 --- .../mlir/Dialect/Transform/IR/TransformOps.td | 17 +++++++++++ mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 21 ++++++++++++++ mlir/test/Dialect/Transform/test-interpreter.mlir | 33 ++++++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 7a74617..8865865 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -210,6 +210,23 @@ def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result", "functional-type(operands, results)"; } +def GetDefiningOp : TransformDialectOp<"get_defining_op", + [DeclareOpInterfaceMethods, + NavigationTransformOpTrait, MemoryEffectsOpInterface]> { + let summary = "Get handle to the defining op of a value"; + let description = [{ + The handle defined by this Transform op corresponds to the defining op of + the targeted value. + + This transform fails silently if the targeted value is a block argument. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$result); + let assemblyFormat = "$target attr-dict `:` " + "functional-type(operands, results)"; +} + def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand", [DeclareOpInterfaceMethods, NavigationTransformOpTrait, MemoryEffectsOpInterface]> { diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index a7b370c..6632206 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -448,6 +448,27 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results, } //===----------------------------------------------------------------------===// +// GetDefiningOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::GetDefiningOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SmallVector definingOps; + for (Value v : state.getPayloadValues(getTarget())) { + if (v.isa()) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "cannot get defining op of block argument"; + diag.attachNote(v.getLoc()) << "target value"; + return diag; + } + definingOps.push_back(v.getDefiningOp()); + } + results.set(getResult().cast(), definingOps); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// // GetProducerOfOperand //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index f605686..3a7f420 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1222,3 +1222,36 @@ transform.sequence failures(propagate) { %result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value } + +// ----- + +func.func @get_result_of_op(%arg0: index, %arg1: index) -> index { + // expected-remark @below {{matched}} + %r = arith.addi %arg0, %arg1 : index + return %r : index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %result = transform.get_result %addi[0] : (!transform.any_op) -> !transform.any_value + %op = transform.get_defining_op %result : (!transform.any_value) -> !transform.any_op + transform.test_print_remark_at_operand %op, "matched" : !transform.any_op +} + +// ----- + +// expected-note @below {{target value}} +func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index { + %r = arith.addi %arg0, %arg1 : index + return %r : index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %bbarg = test_produce_value_handle_to_argument_of_parent_block %addi, 0 : (!transform.any_op) -> !transform.any_value + // expected-error @below {{cannot get defining op of block argument}} + %op = transform.get_defining_op %bbarg : (!transform.any_value) -> !transform.any_op + transform.test_print_remark_at_operand %op, "matched" : !transform.any_op +} -- 2.7.4