[MLIR][python bindings] Reimplement `replace_all_uses_with` on `PyValue`
authormax <maksim.levental@gmail.com>
Wed, 26 Apr 2023 14:55:27 +0000 (09:55 -0500)
committermax <maksim.levental@gmail.com>
Wed, 26 Apr 2023 19:04:33 +0000 (14:04 -0500)
Differential Revision: https://reviews.llvm.org/D149261

mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/value.py

index 84d226b..b45b955 100644 (file)
@@ -755,6 +755,12 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 /// operand if there are no uses.
 MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 
+/// Replace all uses of 'of' value with the 'with' value, updating anything in
+/// the IR that uses 'of' to use the other value instead.  When this returns
+/// there are zero uses of 'of'.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
+                                                      MlirValue with);
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
index f3fd386..81c5cd2 100644 (file)
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
-#include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
-//#include "mlir-c/Registration.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -154,6 +152,11 @@ position in the argument list. If the value is an operation result, this is
 equivalent to printing the operation that produced it.
 )";
 
+static const char kValueReplaceAllUsesWithDocstring[] =
+    R"(Replace all uses of value with the new value, updating anything in
+the IR that uses 'self' to use the other value instead.
+)";
+
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
@@ -3316,10 +3319,18 @@ void mlir::python::populateIRCore(py::module &m) {
             return printAccum.join();
           },
           kValueDunderStrDocstring)
-      .def_property_readonly("type", [](PyValue &self) {
-        return PyType(self.getParentOperation()->getContext(),
-                      mlirValueGetType(self.get()));
-      });
+      .def_property_readonly("type",
+                             [](PyValue &self) {
+                               return PyType(
+                                   self.getParentOperation()->getContext(),
+                                   mlirValueGetType(self.get()));
+                             })
+      .def(
+          "replace_all_uses_with",
+          [](PyValue &self, PyValue &with) {
+            mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+          },
+          kValueReplaceAllUsesWithDocstring);
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
   PyOpOperand::bind(m);
index 051559a..0bbcb30 100644 (file)
@@ -751,6 +751,10 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
   return wrap(opOperand);
 }
 
+void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
+  unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
+}
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
index 5f205c4..ca2e036 100644 (file)
@@ -1873,9 +1873,61 @@ int testOperands(void) {
     return 3;
   }
 
+  MlirOperationState op2State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
+  MlirValue initialOperands2[] = {constOneValue};
+  mlirOperationStateAddOperands(&op2State, 1, initialOperands2);
+  MlirOperation op2 = mlirOperationCreate(&op2State);
+
+  MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue);
+  fprintf(stderr, "First use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: First use owner: "dummy.op2"
+
+  use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue));
+  fprintf(stderr, "Second use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: Second use owner: "dummy.op"
+
+  MlirAttribute indexTwoLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("2 : index"));
+  MlirNamedAttribute indexTwoValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexTwoLiteral);
+  MlirOperationState constTwoState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), loc);
+  mlirOperationStateAddResults(&constTwoState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constTwoState, 1, &indexTwoValueAttr);
+  MlirOperation constTwo = mlirOperationCreate(&constTwoState);
+  MlirValue constTwoValue = mlirOperationGetResult(constTwo, 0);
+
+  mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue);
+
+  use3 = mlirValueGetFirstUse(constOneValue);
+  if (!mlirOpOperandIsNull(use3)) {
+    fprintf(stderr, "ERROR: Use should be null\n");
+    return 4;
+  }
+
+  MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue);
+  fprintf(stderr, "First replacement use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: First replacement use owner: "dummy.op"
+
+  use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue));
+  fprintf(stderr, "Second replacement use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: Second replacement use owner: "dummy.op2"
+
   mlirOperationDestroy(op);
+  mlirOperationDestroy(op2);
   mlirOperationDestroy(constZero);
   mlirOperationDestroy(constOne);
+  mlirOperationDestroy(constTwo);
   mlirContextDestroy(ctx);
 
   return 0;
index 98f55de..90fe64a 100644 (file)
@@ -111,3 +111,29 @@ def testValueUses():
     assert use.owner in [op1, op2]
     print(f"Use owner: {use.owner}")
     print(f"Use operand_number: {use.operand_number}")
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
+@run
+def testValueReplaceAllUsesWith():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      op1 = Operation.create("custom.op2", operands=[value])
+      op2 = Operation.create("custom.op2", operands=[value])
+      value2 = Operation.create("custom.op3", results=[i32]).results[0]
+      value.replace_all_uses_with(value2)
+
+  assert len(list(value.uses)) == 0
+
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  for use in value2.uses:
+    assert use.owner in [op1, op2]
+    print(f"Use owner: {use.owner}")
+    print(f"Use operand_number: {use.operand_number}")