[MLIR][python bindings] Add `PyValue.print_as_operand` (`Value::printAsOperand`)
authormax <maksim.levental@gmail.com>
Sun, 7 May 2023 23:19:46 +0000 (18:19 -0500)
committermax <maksim.levental@gmail.com>
Mon, 8 May 2023 15:41:35 +0000 (10:41 -0500)
Useful for easier debugging (no need to regex out all of the stuff around the id).

Differential Revision: https://reviews.llvm.org/D149902

mlir/include/mlir-c/IR.h
mlir/include/mlir/IR/Value.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/test/python/ir/value.py

index 90af14461e29e578ac0b14876c9d9871b5b9cb22..13a3cb0130cea275db4b55d380ab0326df041265 100644 (file)
@@ -776,6 +776,12 @@ MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value);
 MLIR_CAPI_EXPORTED void
 mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 
+/// Prints a value as an operand (i.e., the ValueID).
+MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
+                                                MlirOpPrintingFlags flags,
+                                                MlirStringCallback callback,
+                                                void *userData);
+
 /// Returns an op operand representing the first use of the value, or a null op
 /// operand if there are no uses.
 MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
index 64ced152839f03221bed8e8b706be9fbcff646b5..7a8aee29ca445eee40bb3ef9c117b11a6dbb6a46 100644 (file)
@@ -226,6 +226,7 @@ public:
 
   /// Print this value as if it were an operand.
   void printAsOperand(raw_ostream &os, AsmState &state);
+  void printAsOperand(raw_ostream &os, const OpPrintingFlags &flags);
 
   /// Methods for supporting PointerLikeTypeTraits.
   void *getAsOpaquePointer() const { return impl; }
index f2e188e78c4a7b1c401e9dcb0f4b4411a5ef9575..7ffa464009fc8e0b04cbc39d107806e18c11b3ee 100644 (file)
@@ -156,6 +156,10 @@ position in the argument list. If the value is an operation result, this is
 equivalent to printing the operation that produced it.
 )";
 
+static const char kGetNameAsOperand[] =
+    R"(Returns the string form of value as an operand (i.e., the ValueID).
+)";
+
 static const char kValueReplaceAllUsesWithDocstring[] =
     R"(Replace all uses of value with the new value, updating anything in
 the IR that uses 'self' to use the other value instead.
@@ -3336,6 +3340,19 @@ void mlir::python::populateIRCore(py::module &m) {
             return printAccum.join();
           },
           kValueDunderStrDocstring)
+      .def(
+          "get_name",
+          [](PyValue &self, bool useLocalScope) {
+            PyPrintAccumulator printAccum;
+            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+            if (useLocalScope)
+              mlirOpPrintingFlagsUseLocalScope(flags);
+            mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(),
+                                    printAccum.getUserData());
+            mlirOpPrintingFlagsDestroy(flags);
+            return printAccum.join();
+          },
+          py::arg("use_local_scope") = false, kGetNameAsOperand)
       .def_property_readonly("type",
                              [](PyValue &self) {
                                return PyType(
index 6ed32e1ce8656304df0844c68e1407f72291a1d7..79386dedfdd98dac2118e994ce3e25fb0bd1574e 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
@@ -767,6 +768,13 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
   unwrap(value).print(stream);
 }
 
+void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags,
+                             MlirStringCallback callback, void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  Value cppValue = unwrap(value);
+  cppValue.printAsOperand(stream, *unwrap(flags));
+}
+
 MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
   Value cppValue = unwrap(value);
   if (cppValue.use_empty())
index 3c525a6e76351fca42b042390f4d30539ef194ab..3afafd6d4cdf0e090d050558275a9ccaa3c54278 100644 (file)
@@ -44,8 +44,8 @@
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/Threading.h"
 
-#include <tuple>
 #include <optional>
+#include <tuple>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -3673,10 +3673,7 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) {
                                                  os);
 }
 
-void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
-  // Find the operation to number from based upon the provided flags.
-  Operation *op = this;
-  bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
+static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
   do {
     // If we are printing local scope, stop at the first operation that is
     // isolated from above.
@@ -3689,7 +3686,28 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
       break;
     op = parentOp;
   } while (true);
+  return op;
+}
+
+void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
+  Operation *op;
+  if (auto result = dyn_cast<OpResult>()) {
+    op = result.getOwner();
+  } else {
+    op = cast<BlockArgument>().getOwner()->getParentOp();
+    if (!op) {
+      os << "<<UNKNOWN SSA VALUE>>";
+      return;
+    }
+  }
+  op = findParent(op, flags.shouldUseLocalScope());
+  AsmState state(op, flags);
+  printAsOperand(os, state);
+}
 
+void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
+  // Find the operation to number from based upon the provided flags.
+  Operation *op = findParent(this, printerFlags.shouldUseLocalScope());
   AsmState state(op, printerFlags);
   print(os, state);
 }
index 90fe64ac1762a911c7deba061b72a5d71f0b215d..66568c426216abb3d35b4d9177ecd94ed834a5b1 100644 (file)
@@ -2,6 +2,7 @@
 
 import gc
 from mlir.ir import *
+from mlir.dialects import func
 
 
 def run(f):
@@ -90,6 +91,7 @@ def testValueHash():
   assert hash(block.arguments[0]) == hash(op.operands[0])
   assert hash(op.result) == hash(ret.operands[0])
 
+
 # CHECK-LABEL: TEST: testValueUses
 @run
 def testValueUses():
@@ -112,6 +114,7 @@ def testValueUses():
     print(f"Use owner: {use.owner}")
     print(f"Use operand_number: {use.operand_number}")
 
+
 # CHECK-LABEL: TEST: testValueReplaceAllUsesWith
 @run
 def testValueReplaceAllUsesWith():
@@ -137,3 +140,95 @@ def testValueReplaceAllUsesWith():
     assert use.owner in [op1, op2]
     print(f"Use owner: {use.owner}")
     print(f"Use operand_number: {use.operand_number}")
+
+
+# CHECK-LABEL: TEST: testValuePrintAsOperand
+@run
+def testValuePrintAsOperand():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+      print(value)
+
+      value2 = Operation.create("custom.op2", results=[i32]).results[0]
+      # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
+      print(value2)
+
+      f = func.FuncOp("test", ([i32, i32], []))
+      entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+
+      with InsertionPoint(entry_block1):
+        value3 = Operation.create("custom.op3", results=[i32]).results[0]
+        # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
+        print(value3)
+        value4 = Operation.create("custom.op4", results=[i32]).results[0]
+        # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
+        print(value4)
+
+        f = func.FuncOp("test", ([i32, i32], []))
+        entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+        with InsertionPoint(entry_block2):
+          value5 = Operation.create("custom.op5", results=[i32]).results[0]
+          # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
+          print(value5)
+          value6 = Operation.create("custom.op6", results=[i32]).results[0]
+          # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
+          print(value6)
+
+          func.ReturnOp([])
+
+        func.ReturnOp([])
+
+    # CHECK: %[[VAL1]]
+    print(value.get_name())
+    # CHECK: %[[VAL2]]
+    print(value2.get_name())
+    # CHECK: %[[VAL3]]
+    print(value3.get_name())
+    # CHECK: %[[VAL4]]
+    print(value4.get_name())
+
+    # CHECK: %0
+    print(value3.get_name(use_local_scope=True))
+    # CHECK: %1
+    print(value4.get_name(use_local_scope=True))
+
+    # CHECK: %[[VAL5]]
+    print(value5.get_name())
+    # CHECK: %[[VAL6]]
+    print(value6.get_name())
+
+    # CHECK: %[[ARG0:.*]]
+    print(entry_block1.arguments[0].get_name())
+    # CHECK: %[[ARG1:.*]]
+    print(entry_block1.arguments[1].get_name())
+
+    # CHECK: %[[ARG2:.*]]
+    print(entry_block2.arguments[0].get_name())
+    # CHECK: %[[ARG3:.*]]
+    print(entry_block2.arguments[1].get_name())
+
+    # CHECK: module {
+    # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
+    # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
+    # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
+    # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
+    # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
+    # CHECK:     func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
+    # CHECK:       %[[VAL5]] = "custom.op5"() : () -> i32
+    # CHECK:       %[[VAL6]] = "custom.op6"() : () -> i32
+    # CHECK:       return
+    # CHECK:     }
+    # CHECK:     return
+    # CHECK:   }
+    # CHECK: }
+    print(module)
+
+    value2.owner.detach_from_parent()
+    # CHECK: %0
+    print(value2.get_name())