[MLIR] [Python] Add `owner` to PyValue and fix its parent reference
authorJohn Demme <john.demme@microsoft.com>
Thu, 15 Jul 2021 03:19:27 +0000 (20:19 -0700)
committerJohn Demme <john.demme@microsoft.com>
Thu, 15 Jul 2021 03:32:43 +0000 (20:32 -0700)
Adds `owner` python call to `mlir.ir.Value`.

Assuming that `PyValue.parentOperation` is intended to be the value's owner, this fixes the construction of it from `PyOpOperandList`.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D103853

mlir/lib/Bindings/Python/IRCore.cpp

index 3f08522..b5197d9 100644 (file)
@@ -1652,7 +1652,17 @@ public:
   }
 
   PyValue getElement(intptr_t pos) {
-    return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
+    MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+    MlirOperation owner;
+    if (mlirValueIsAOpResult(operand))
+      owner = mlirOpResultGetOwner(operand);
+    else if (mlirValueIsABlockArgument(operand))
+      owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
+    else
+      assert(false && "Value must be an block arg or op result.");
+    PyOperationRef pyOwner =
+        PyOperation::forOperation(operation->getContext(), owner);
+    return PyValue(pyOwner, operand);
   }
 
   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
@@ -2429,6 +2439,15 @@ void mlir::python::populateIRCore(py::module &m) {
       .def(
           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
           kDumpDocstring)
+      .def_property_readonly(
+          "owner",
+          [](PyValue &self) {
+            assert(mlirOperationEqual(self.getParentOperation()->get(),
+                                      mlirOpResultGetOwner(self.get())) &&
+                   "expected the owner of the value in Python to match that in "
+                   "the IR");
+            return self.getParentOperation().getObject();
+          })
       .def("__eq__",
            [](PyValue &self, PyValue &other) {
              return self.get().ptr == other.get().ptr;