From 81233c70cbf6a343d4e86c5ec99dfd427b25e975 Mon Sep 17 00:00:00 2001 From: max Date: Sun, 7 May 2023 18:19:46 -0500 Subject: [PATCH] [MLIR][python bindings] Add `PyValue.print_as_operand` (`Value::printAsOperand`) Useful for easier debugging (no need to regex out all of the stuff around the id). Differential Revision: https://reviews.llvm.org/D149902 --- mlir/include/mlir-c/IR.h | 6 ++ mlir/include/mlir/IR/Value.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 17 ++++++ mlir/lib/CAPI/IR/IR.cpp | 8 +++ mlir/lib/IR/AsmPrinter.cpp | 28 +++++++-- mlir/test/python/ir/value.py | 95 +++++++++++++++++++++++++++++ 6 files changed, 150 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 90af14461e29..13a3cb0130ce 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -776,6 +776,12 @@ MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value); MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); +/// Prints a value as an operand (i.e., the ValueID). +MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, + MlirOpPrintingFlags flags, + MlirStringCallback callback, + void *userData); + /// Returns an op operand representing the first use of the value, or a null op /// operand if there are no uses. MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 64ced152839f..7a8aee29ca44 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -226,6 +226,7 @@ public: /// Print this value as if it were an operand. void printAsOperand(raw_ostream &os, AsmState &state); + void printAsOperand(raw_ostream &os, const OpPrintingFlags &flags); /// Methods for supporting PointerLikeTypeTraits. void *getAsOpaquePointer() const { return impl; } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f2e188e78c4a..7ffa464009fc 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -156,6 +156,10 @@ position in the argument list. If the value is an operation result, this is equivalent to printing the operation that produced it. )"; +static const char kGetNameAsOperand[] = + R"(Returns the string form of value as an operand (i.e., the ValueID). +)"; + static const char kValueReplaceAllUsesWithDocstring[] = R"(Replace all uses of value with the new value, updating anything in the IR that uses 'self' to use the other value instead. @@ -3336,6 +3340,19 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, kValueDunderStrDocstring) + .def( + "get_name", + [](PyValue &self, bool useLocalScope) { + PyPrintAccumulator printAccum; + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(), + printAccum.getUserData()); + mlirOpPrintingFlagsDestroy(flags); + return printAccum.join(); + }, + py::arg("use_local_scope") = false, kGetNameAsOperand) .def_property_readonly("type", [](PyValue &self) { return PyType( diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6ed32e1ce865..79386dedfdd9 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" @@ -767,6 +768,13 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback, unwrap(value).print(stream); } +void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags, + MlirStringCallback callback, void *userData) { + detail::CallbackOstream stream(callback, userData); + Value cppValue = unwrap(value); + cppValue.printAsOperand(stream, *unwrap(flags)); +} + MlirOpOperand mlirValueGetFirstUse(MlirValue value) { Value cppValue = unwrap(value); if (cppValue.use_empty()) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 3c525a6e7635..3afafd6d4cdf 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -44,8 +44,8 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/Threading.h" -#include #include +#include using namespace mlir; using namespace mlir::detail; @@ -3673,10 +3673,7 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) { os); } -void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { - // Find the operation to number from based upon the provided flags. - Operation *op = this; - bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); +static Operation *findParent(Operation *op, bool shouldUseLocalScope) { do { // If we are printing local scope, stop at the first operation that is // isolated from above. @@ -3689,7 +3686,28 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { break; op = parentOp; } while (true); + return op; +} + +void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) { + Operation *op; + if (auto result = dyn_cast()) { + op = result.getOwner(); + } else { + op = cast().getOwner()->getParentOp(); + if (!op) { + os << "<>"; + return; + } + } + op = findParent(op, flags.shouldUseLocalScope()); + AsmState state(op, flags); + printAsOperand(os, state); +} +void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { + // Find the operation to number from based upon the provided flags. + Operation *op = findParent(this, printerFlags.shouldUseLocalScope()); AsmState state(op, printerFlags); print(os, state); } diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 90fe64ac1762..66568c426216 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -2,6 +2,7 @@ import gc from mlir.ir import * +from mlir.dialects import func def run(f): @@ -90,6 +91,7 @@ def testValueHash(): assert hash(block.arguments[0]) == hash(op.operands[0]) assert hash(op.result) == hash(ret.operands[0]) + # CHECK-LABEL: TEST: testValueUses @run def testValueUses(): @@ -112,6 +114,7 @@ def testValueUses(): print(f"Use owner: {use.owner}") print(f"Use operand_number: {use.operand_number}") + # CHECK-LABEL: TEST: testValueReplaceAllUsesWith @run def testValueReplaceAllUsesWith(): @@ -137,3 +140,95 @@ def testValueReplaceAllUsesWith(): assert use.owner in [op1, op2] print(f"Use owner: {use.owner}") print(f"Use operand_number: {use.operand_number}") + + +# CHECK-LABEL: TEST: testValuePrintAsOperand +@run +def testValuePrintAsOperand(): + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + value = Operation.create("custom.op1", results=[i32]).results[0] + # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32) + print(value) + + value2 = Operation.create("custom.op2", results=[i32]).results[0] + # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32) + print(value2) + + f = func.FuncOp("test", ([i32, i32], [])) + entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32]) + + with InsertionPoint(entry_block1): + value3 = Operation.create("custom.op3", results=[i32]).results[0] + # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32) + print(value3) + value4 = Operation.create("custom.op4", results=[i32]).results[0] + # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32) + print(value4) + + f = func.FuncOp("test", ([i32, i32], [])) + entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32]) + with InsertionPoint(entry_block2): + value5 = Operation.create("custom.op5", results=[i32]).results[0] + # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32) + print(value5) + value6 = Operation.create("custom.op6", results=[i32]).results[0] + # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32) + print(value6) + + func.ReturnOp([]) + + func.ReturnOp([]) + + # CHECK: %[[VAL1]] + print(value.get_name()) + # CHECK: %[[VAL2]] + print(value2.get_name()) + # CHECK: %[[VAL3]] + print(value3.get_name()) + # CHECK: %[[VAL4]] + print(value4.get_name()) + + # CHECK: %0 + print(value3.get_name(use_local_scope=True)) + # CHECK: %1 + print(value4.get_name(use_local_scope=True)) + + # CHECK: %[[VAL5]] + print(value5.get_name()) + # CHECK: %[[VAL6]] + print(value6.get_name()) + + # CHECK: %[[ARG0:.*]] + print(entry_block1.arguments[0].get_name()) + # CHECK: %[[ARG1:.*]] + print(entry_block1.arguments[1].get_name()) + + # CHECK: %[[ARG2:.*]] + print(entry_block2.arguments[0].get_name()) + # CHECK: %[[ARG3:.*]] + print(entry_block2.arguments[1].get_name()) + + # CHECK: module { + # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32 + # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32 + # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) { + # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32 + # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32 + # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) { + # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32 + # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32 + # CHECK: return + # CHECK: } + # CHECK: return + # CHECK: } + # CHECK: } + print(module) + + value2.owner.detach_from_parent() + # CHECK: %0 + print(value2.get_name()) -- 2.34.1