[mlir][python] Add `destroy` method to PyOperation.
authorMike Urbach <mikeurbach@gmail.com>
Sat, 24 Apr 2021 02:54:04 +0000 (20:54 -0600)
committerMike Urbach <mikeurbach@gmail.com>
Thu, 29 Apr 2021 01:30:05 +0000 (19:30 -0600)
This adds a method to directly invoke `mlirOperationDestroy` on the
MlirOperation wrapped by a PyOperation.

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D101422

mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/test/Bindings/Python/ir_operation.py

index 781e9ae..160e35b 100644 (file)
@@ -753,6 +753,9 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
     : BaseContextObject(std::move(contextRef)), operation(operation) {}
 
 PyOperation::~PyOperation() {
+  // If the operation has already been invalidated there is nothing to do.
+  if (!valid)
+    return;
   auto &liveOperations = getContext()->liveOperations;
   assert(liveOperations.count(operation.ptr) == 1 &&
          "destroying operation not in live map");
@@ -869,6 +872,7 @@ py::object PyOperationBase::getAsm(bool binary,
 }
 
 PyOperationRef PyOperation::getParentOperation() {
+  checkValid();
   if (!isAttached())
     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
   MlirOperation operation = mlirOperationGetParentOperation(get());
@@ -878,6 +882,7 @@ PyOperationRef PyOperation::getParentOperation() {
 }
 
 PyBlock PyOperation::getBlock() {
+  checkValid();
   PyOperationRef parentOperation = getParentOperation();
   MlirBlock block = mlirOperationGetBlock(get());
   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
@@ -885,6 +890,7 @@ PyBlock PyOperation::getBlock() {
 }
 
 py::object PyOperation::getCapsule() {
+  checkValid();
   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
 }
 
@@ -1032,6 +1038,7 @@ py::object PyOperation::create(
 }
 
 py::object PyOperation::createOpView() {
+  checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
   MlirStringRef identStr = mlirIdentifierStr(ident);
   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
@@ -1041,6 +1048,18 @@ py::object PyOperation::createOpView() {
   return py::cast(PyOpView(getRef().getObject()));
 }
 
+void PyOperation::erase() {
+  checkValid();
+  // TODO: Fix memory hazards when erasing a tree of operations for which a deep
+  // Python reference to a child operation is live. All children should also
+  // have their `valid` bit set to false.
+  auto &liveOperations = getContext()->liveOperations;
+  if (liveOperations.count(operation.ptr))
+    liveOperations.erase(operation.ptr);
+  mlirOperationDestroy(operation);
+  valid = false;
+}
+
 //------------------------------------------------------------------------------
 // PyOpView
 //------------------------------------------------------------------------------
@@ -2094,11 +2113,13 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   kOperationCreateDocstring)
+      .def("erase", &PyOperation::erase)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
       .def_property_readonly("name",
                              [](PyOperation &self) {
+                               self.checkValid();
                                MlirOperation operation = self.get();
                                MlirStringRef name = mlirIdentifierStr(
                                    mlirOperationGetName(operation));
@@ -2106,7 +2127,10 @@ void mlir::python::populateIRCore(py::module &m) {
                              })
       .def_property_readonly(
           "context",
-          [](PyOperation &self) { return self.getContext().getObject(); },
+          [](PyOperation &self) {
+            self.checkValid();
+            return self.getContext().getObject();
+          },
           "Context that owns the Operation")
       .def_property_readonly("opview", &PyOperation::createOpView);
 
index 292080d..79c480e 100644 (file)
@@ -473,6 +473,10 @@ public:
   /// Creates an OpView suitable for this operation.
   pybind11::object createOpView();
 
+  /// Erases the underlying MlirOperation, removes its pointer from the
+  /// parent context's live operations map, and sets the valid bit false.
+  void erase();
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
index 746cd3e..83e4a4f 100644 (file)
@@ -646,3 +646,25 @@ def testCapsuleConversions():
     assert m2 is m
 
 run(testCapsuleConversions)
+
+# CHECK-LABEL: TEST: testOperationErase
+def testOperationErase():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    m = Module.create()
+    with InsertionPoint(m.body):
+      op = Operation.create("custom.op1")
+
+      # CHECK: "custom.op1"
+      print(m)
+
+      op.operation.erase()
+
+      # CHECK-NOT: "custom.op1"
+      print(m)
+
+      # Ensure we can create another operation
+      Operation.create("custom.op2")
+
+run(testOperationErase)