Revert "[MLIR][python bindings] implement `replace_all_uses_with` on `PyValue`"
authormax <maksim.levental@gmail.com>
Tue, 25 Apr 2023 20:32:14 +0000 (15:32 -0500)
committermax <maksim.levental@gmail.com>
Tue, 25 Apr 2023 20:45:17 +0000 (15:45 -0500)
This reverts commit 3bab7cb089d92cc7025ebc57ef3a74d3ce94ecd8 because it breaks sanitizers.

Differential Revision: https://reviews.llvm.org/D149188

mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/value.py

index b45b955..84d226b 100644 (file)
@@ -755,12 +755,6 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 /// operand if there are no uses.
 MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 
-/// Replace all uses of 'of' value with the 'with' value, updating anything in
-/// the IR that uses 'of' to use the other value instead.  When this returns
-/// there are zero uses of 'of'.
-MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
-                                                      MlirValue with);
-
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
index 81c5cd2..f3fd386 100644 (file)
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
+//#include "mlir-c/Registration.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -152,11 +154,6 @@ position in the argument list. If the value is an operation result, this is
 equivalent to printing the operation that produced it.
 )";
 
-static const char kValueReplaceAllUsesWithDocstring[] =
-    R"(Replace all uses of value with the new value, updating anything in
-the IR that uses 'self' to use the other value instead.
-)";
-
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
@@ -3319,18 +3316,10 @@ void mlir::python::populateIRCore(py::module &m) {
             return printAccum.join();
           },
           kValueDunderStrDocstring)
-      .def_property_readonly("type",
-                             [](PyValue &self) {
-                               return PyType(
-                                   self.getParentOperation()->getContext(),
-                                   mlirValueGetType(self.get()));
-                             })
-      .def(
-          "replace_all_uses_with",
-          [](PyValue &self, PyValue &with) {
-            mlirValueReplaceAllUsesOfWith(self.get(), with.get());
-          },
-          kValueReplaceAllUsesWithDocstring);
+      .def_property_readonly("type", [](PyValue &self) {
+        return PyType(self.getParentOperation()->getContext(),
+                      mlirValueGetType(self.get()));
+      });
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
   PyOpOperand::bind(m);
index 0bbcb30..051559a 100644 (file)
@@ -751,10 +751,6 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
   return wrap(opOperand);
 }
 
-void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
-  unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
-}
-
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
index bcde60a..5f205c4 100644 (file)
 #include <stdlib.h>
 #include <string.h>
 
-MlirValue makeConstantLiteral(MlirContext ctx, const char *literalStr,
-                              const char *typeStr) {
-  MlirLocation loc = mlirLocationUnknownGet(ctx);
-  char attrStr[50];
-  sprintf(attrStr, "%s : %s", literalStr, typeStr);
-  MlirAttribute literal =
-      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(attrStr));
-  MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
-      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), literal);
-  MlirOperationState constState = mlirOperationStateGet(
-      mlirStringRefCreateFromCString("arith.constant"), loc);
-  MlirType type =
-      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(typeStr));
-  mlirOperationStateAddResults(&constState, 1, &type);
-  mlirOperationStateAddAttributes(&constState, 1, &valueAttr);
-  MlirOperation constOp = mlirOperationCreate(&constState);
-  return mlirOperationGetResult(constOp, 0);
-}
-
 static void registerAllUpstreamDialects(MlirContext ctx) {
   MlirDialectRegistry registry = mlirDialectRegistryCreate();
   mlirRegisterAllDialects(registry);
@@ -134,17 +115,26 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
   MlirOperation func = mlirOperationCreate(&funcState);
   mlirBlockInsertOwnedOperation(moduleBody, 0, func);
 
-  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
-  MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
+  MlirType indexType =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), location);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
   mlirBlockAppendOwnedOperation(funcBody, constZero);
 
   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
+  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
   MlirValue dimOperands[] = {funcArg0, constZeroValue};
   MlirOperationState dimState = mlirOperationStateGet(
       mlirStringRefCreateFromCString("memref.dim"), location);
   mlirOperationStateAddOperands(&dimState, 2, dimOperands);
-  MlirType indexType =
-      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
   mlirOperationStateAddResults(&dimState, 1, &indexType);
   MlirOperation dim = mlirOperationCreate(&dimState);
   mlirBlockAppendOwnedOperation(funcBody, dim);
@@ -163,11 +153,11 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
       mlirStringRefCreateFromCString("arith.constant"), location);
   mlirOperationStateAddResults(&constOneState, 1, &indexType);
   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
-  MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
-  MlirOperation constOne = mlirOpResultGetOwner(constOneValue);
+  MlirOperation constOne = mlirOperationCreate(&constOneState);
   mlirBlockAppendOwnedOperation(funcBody, constOne);
 
   MlirValue dimValue = mlirOperationGetResult(dim, 0);
+  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
   MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
   MlirOperationState loopState = mlirOperationStateGet(
       mlirStringRefCreateFromCString("scf.for"), location);
@@ -830,6 +820,11 @@ static int printBuiltinTypes(MlirContext ctx) {
   return 0;
 }
 
+void callbackSetFixedLengthString(const char *data, intptr_t len,
+                                  void *userData) {
+  strncpy(userData, data, len);
+}
+
 bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
   if (strlen(lhs) != rhs.length) {
     return false;
@@ -1799,10 +1794,32 @@ int testOperands(void) {
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
   MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
 
   // Create some constants to use as operands.
-  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
-  MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
+
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexOneLiteral);
+  MlirOperationState constOneState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), loc);
+  mlirOperationStateAddResults(&constOneState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
+  MlirOperation constOne = mlirOperationCreate(&constOneState);
+  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
 
   // Create the operation under test.
   mlirContextSetAllowUnregisteredDialects(ctx, true);
@@ -1856,50 +1873,9 @@ int testOperands(void) {
     return 3;
   }
 
-  MlirOperationState op2State =
-      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
-  MlirValue initialOperands2[] = {constOneValue};
-  mlirOperationStateAddOperands(&op2State, 1, initialOperands2);
-  MlirOperation op2 = mlirOperationCreate(&op2State);
-
-  MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue);
-  fprintf(stderr, "First use owner: ");
-  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: First use owner: "dummy.op2"
-
-  use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue));
-  fprintf(stderr, "Second use owner: ");
-  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: Second use owner: "dummy.op"
-
-  MlirValue constTwoValue = makeConstantLiteral(ctx, "2", "index");
-  mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue);
-
-  use3 = mlirValueGetFirstUse(constOneValue);
-  if (!mlirOpOperandIsNull(use3)) {
-    fprintf(stderr, "ERROR: Use should be null\n");
-    return 4;
-  }
-
-  MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue);
-  fprintf(stderr, "First replacement use owner: ");
-  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: First replacement use owner: "dummy.op"
-
-  use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue));
-  fprintf(stderr, "Second replacement use owner: ");
-  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: Second replacement use owner: "dummy.op2"
-
   mlirOperationDestroy(op);
-  mlirOperationDestroy(op2);
-  mlirOperationDestroy(mlirOpResultGetOwner(constZeroValue));
-  mlirOperationDestroy(mlirOpResultGetOwner(constOneValue));
-  mlirOperationDestroy(mlirOpResultGetOwner(constTwoValue));
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(constOne);
   mlirContextDestroy(ctx);
 
   return 0;
@@ -1914,10 +1890,19 @@ int testClone(void) {
   registerAllUpstreamDialects(ctx);
 
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
   MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
 
-  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
-  MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
 
   MlirAttribute indexOneLiteral =
       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
@@ -1995,10 +1980,19 @@ int testTypeID(MlirContext ctx) {
   }
 
   MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
 
   // Create a registered operation, which should have a type id.
-  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
-  MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
 
   if (!mlirOperationVerify(constZero)) {
     fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
index 90fe64a..98f55de 100644 (file)
@@ -111,29 +111,3 @@ def testValueUses():
     assert use.owner in [op1, op2]
     print(f"Use owner: {use.owner}")
     print(f"Use operand_number: {use.operand_number}")
-
-# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
-@run
-def testValueReplaceAllUsesWith():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signless(32)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      value = Operation.create("custom.op1", results=[i32]).results[0]
-      op1 = Operation.create("custom.op2", operands=[value])
-      op2 = Operation.create("custom.op2", operands=[value])
-      value2 = Operation.create("custom.op3", results=[i32]).results[0]
-      value.replace_all_uses_with(value2)
-
-  assert len(list(value.uses)) == 0
-
-  # CHECK: Use owner: "custom.op2"
-  # CHECK: Use operand_number: 0
-  # CHECK: Use owner: "custom.op2"
-  # CHECK: Use operand_number: 0
-  for use in value2.uses:
-    assert use.owner in [op1, op2]
-    print(f"Use owner: {use.owner}")
-    print(f"Use operand_number: {use.operand_number}")