[mlir][Python] Add missing capsule->module and Context.create_module.
authorStella Laurenzo <stellaraccident@gmail.com>
Tue, 13 Oct 2020 04:19:13 +0000 (21:19 -0700)
committerStella Laurenzo <stellaraccident@gmail.com>
Tue, 13 Oct 2020 20:10:33 +0000 (13:10 -0700)
* Extends Context/Operation interning to cover Module as well.
* Implements Module.context, Attribute.context, Type.context, and Location.context back-references (facilitated testing and also on the TODO list).
* Adds method to create an empty Module.
* Discovered missing in npcomp.

Differential Revision: https://reviews.llvm.org/D89294

mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/test/Bindings/Python/ir_attributes.py
mlir/test/Bindings/Python/ir_location.py
mlir/test/Bindings/Python/ir_module.py
mlir/test/Bindings/Python/ir_operation.py
mlir/test/Bindings/Python/ir_types.py

index 24b2a8b..acb168c 100644 (file)
@@ -86,6 +86,16 @@ inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
   return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
 }
 
+/** Extracts an MlirModule from a capsule as produced from
+ * mlirPythonModuleToCapsule. If the capsule is not of the right type, then
+ * a null module is returned (as checked via mlirModuleIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_MODULE);
+  MlirModule module = {ptr};
+  return module;
+}
+
 #ifdef __cplusplus
 }
 #endif
index 36e25ee..8f525e8 100644 (file)
@@ -497,6 +497,8 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
 py::object PyMlirContext::createOperation(
     std::string name, PyLocation location,
     llvm::Optional<std::vector<PyType *>> results,
@@ -582,15 +584,49 @@ py::object PyMlirContext::createOperation(
 // PyModule
 //------------------------------------------------------------------------------
 
-PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
-  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-  // Note that the default return value policy on cast is automatic_reference,
-  // which does not take ownership (delete will not be called).
-  // Just be explicit.
-  py::object pyRef =
-      py::cast(unownedModule, py::return_value_policy::take_ownership);
-  unownedModule->handle = pyRef;
-  return PyModuleRef(unownedModule, std::move(pyRef));
+PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
+    : BaseContextObject(std::move(contextRef)), module(module) {}
+
+PyModule::~PyModule() {
+  py::gil_scoped_acquire acquire;
+  auto &liveModules = getContext()->liveModules;
+  assert(liveModules.count(module.ptr) == 1 &&
+         "destroying module not in live map");
+  liveModules.erase(module.ptr);
+  mlirModuleDestroy(module);
+}
+
+PyModuleRef PyModule::forModule(MlirModule module) {
+  MlirContext context = mlirModuleGetContext(module);
+  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
+
+  py::gil_scoped_acquire acquire;
+  auto &liveModules = contextRef->liveModules;
+  auto it = liveModules.find(module.ptr);
+  if (it == liveModules.end()) {
+    // Create.
+    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+    // Note that the default return value policy on cast is automatic_reference,
+    // which does not take ownership (delete will not be called).
+    // Just be explicit.
+    py::object pyRef =
+        py::cast(unownedModule, py::return_value_policy::take_ownership);
+    unownedModule->handle = pyRef;
+    liveModules[module.ptr] =
+        std::make_pair(unownedModule->handle, unownedModule);
+    return PyModuleRef(unownedModule, std::move(pyRef));
+  }
+  // Use existing.
+  PyModule *existing = it->second.second;
+  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+  return PyModuleRef(existing, std::move(pyRef));
+}
+
+py::object PyModule::createFromCapsule(py::object capsule) {
+  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
+  if (mlirModuleIsNull(rawModule))
+    throw py::error_already_set();
+  return forModule(rawModule).releaseObject();
 }
 
 py::object PyModule::getCapsule() {
@@ -1461,6 +1497,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
              return ref.releaseObject();
            })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -1489,10 +1526,17 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                   PyExc_ValueError,
                   "Unable to parse module assembly (see diagnostics)");
             }
-            return PyModule::create(self.getRef(), module).releaseObject();
+            return PyModule::forModule(module).releaseObject();
           },
           kContextParseDocstring)
       .def(
+          "create_module",
+          [](PyMlirContext &self, PyLocation &loc) {
+            MlirModule module = mlirModuleCreateEmpty(loc.loc);
+            return PyModule::forModule(module).releaseObject();
+          },
+          py::arg("loc"), "Creates an empty module")
+      .def(
           "parse_attr",
           [](PyMlirContext &self, std::string attrSpec) {
             MlirAttribute type =
@@ -1538,16 +1582,26 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           kContextGetFileLocationDocstring, py::arg("filename"),
           py::arg("line"), py::arg("col"));
 
-  py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
-    PyPrintAccumulator printAccum;
-    mlirLocationPrint(self.loc, printAccum.getCallback(),
-                      printAccum.getUserData());
-    return printAccum.join();
-  });
+  py::class_<PyLocation>(m, "Location")
+      .def_property_readonly(
+          "context",
+          [](PyLocation &self) { return self.getContext().getObject(); },
+          "Context that owns the Location")
+      .def("__repr__", [](PyLocation &self) {
+        PyPrintAccumulator printAccum;
+        mlirLocationPrint(self.loc, printAccum.getCallback(),
+                          printAccum.getUserData());
+        return printAccum.join();
+      });
 
   // Mapping of Module
   py::class_<PyModule>(m, "Module")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def_property_readonly(
+          "context",
+          [](PyModule &self) { return self.getContext().getObject(); },
+          "Context that created the Module")
       .def_property_readonly(
           "operation",
           [](PyModule &self) {
@@ -1577,6 +1631,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of Operation.
   py::class_<PyOperation>(m, "Operation")
       .def_property_readonly(
+          "context",
+          [](PyOperation &self) { return self.getContext().getObject(); },
+          "Context that owns the Operation")
+      .def_property_readonly(
           "regions",
           [](PyOperation &self) { return PyRegionList(self.getRef()); })
       .def("__iter__",
@@ -1657,6 +1715,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Mapping of Type.
   py::class_<PyAttribute>(m, "Attribute")
+      .def_property_readonly(
+          "context",
+          [](PyAttribute &self) { return self.getContext().getObject(); },
+          "Context that owns the Attribute")
       .def(
           "get_named",
           [](PyAttribute &self, std::string name) {
@@ -1737,6 +1799,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Mapping of Type.
   py::class_<PyType>(m, "Type")
+      .def_property_readonly(
+          "context", [](PyType &self) { return self.getContext().getObject(); },
+          "Context that owns the Type")
       .def("__eq__",
            [](PyType &self, py::object &other) {
              try {
index e67142e..c175018 100644 (file)
@@ -113,7 +113,8 @@ public:
 
   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
   /// Note that PyMlirContext instances are uniqued, so the returned object
-  /// may be a pre-existing object.
+  /// may be a pre-existing object. Ownership of the underlying MlirContext
+  /// is taken by calling this function.
   static pybind11::object createFromCapsule(pybind11::object capsule);
 
   /// Gets the count of live context objects. Used for testing.
@@ -123,6 +124,10 @@ public:
   /// Used for testing.
   size_t getLiveOperationCount();
 
+  /// Gets the count of live modules associated with this context.
+  /// Used for testing.
+  size_t getLiveModuleCount();
+
   /// Creates an operation. See corresponding python docstring.
   pybind11::object
   createOperation(std::string name, PyLocation location,
@@ -142,6 +147,14 @@ private:
   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
   static LiveContextMap &getLiveContexts();
 
+  // Interns all live modules associated with this context. Modules tracked
+  // in this map are valid. When a module is invalidated, it is removed
+  // from this map, and while it still exists as an instance, any
+  // attempt to access it will raise an error.
+  using LiveModuleMap =
+      llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
+  LiveModuleMap liveModules;
+
   // Interns all live operations associated with this context. Operations
   // tracked in this map are valid. When an operation is invalidated, it is
   // removed from this map, and while it still exists as an instance, any
@@ -151,6 +164,7 @@ private:
   LiveOperationMap liveOperations;
 
   MlirContext context;
+  friend class PyModule;
   friend class PyOperation;
 };
 
@@ -186,13 +200,12 @@ class PyModule;
 using PyModuleRef = PyObjectRef<PyModule>;
 class PyModule : public BaseContextObject {
 public:
-  /// Creates a reference to the module
-  static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module);
+  /// Returns a PyModule reference for the given MlirModule. This may return
+  /// a pre-existing or new object.
+  static PyModuleRef forModule(MlirModule module);
   PyModule(PyModule &) = delete;
-  ~PyModule() {
-    if (module.ptr)
-      mlirModuleDestroy(module);
-  }
+  PyModule(PyMlirContext &&) = delete;
+  ~PyModule();
 
   /// Gets the backing MlirModule.
   MlirModule get() { return module; }
@@ -209,9 +222,14 @@ public:
   /// instances, which is not currently done.
   pybind11::object getCapsule();
 
+  /// Creates a PyModule from the MlirModule wrapped by a capsule.
+  /// Note that PyModule instances are uniqued, so the returned object
+  /// may be a pre-existing object. Ownership of the underlying MlirModule
+  /// is taken by calling this function.
+  static pybind11::object createFromCapsule(pybind11::object capsule);
+
 private:
-  PyModule(PyMlirContextRef contextRef, MlirModule module)
-      : BaseContextObject(std::move(contextRef)), module(module) {}
+  PyModule(PyMlirContextRef contextRef, MlirModule module);
   MlirModule module;
   pybind11::handle handle;
 };
index dfdc819..bf99a76 100644 (file)
@@ -14,6 +14,7 @@ def run(f):
 def testParsePrint():
   ctx = mlir.ir.Context()
   t = ctx.parse_attr('"hello"')
+  assert t.context is ctx
   ctx = None
   gc.collect()
   # CHECK: "hello"
index ac42c61..f7e9924 100644 (file)
@@ -14,6 +14,7 @@ def run(f):
 def testUnknown():
   ctx = mlir.ir.Context()
   loc = ctx.get_unknown_location()
+  assert loc.context is ctx
   ctx = None
   gc.collect()
   # CHECK: unknown str: loc(unknown)
index d85a415..5f34038 100644 (file)
@@ -16,6 +16,7 @@ def run(f):
 def testParseSuccess():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  assert module.context is ctx
   print("CLEAR CONTEXT")
   ctx = None  # Ensure that module captures the context.
   gc.collect()
@@ -40,6 +41,21 @@ def testParseError():
 run(testParseError)
 
 
+# Verify successful parse.
+# CHECK-LABEL: TEST: testCreateEmpty
+# CHECK: module {
+def testCreateEmpty():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_unknown_location()
+  module = ctx.create_module(loc)
+  print("CLEAR CONTEXT")
+  ctx = None  # Ensure that module captures the context.
+  gc.collect()
+  print(str(module))
+
+run(testCreateEmpty)
+
+
 # Verify round-trip of ASM that contains unicode.
 # Note that this does not test that the print path converts unicode properly
 # because MLIR asm always normalizes it to the hex encoding.
@@ -61,6 +77,7 @@ run(testRoundtripUnicode)
 def testModuleOperation():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  assert ctx._get_live_module_count() == 1
   op1 = module.operation
   assert ctx._get_live_operation_count() == 1
   # CHECK: module @successfulParse
@@ -82,6 +99,7 @@ def testModuleOperation():
   gc.collect()
   print("LIVE OPERATIONS:", ctx._get_live_operation_count())
   assert ctx._get_live_operation_count() == 0
+  assert ctx._get_live_module_count() == 0
 
 run(testModuleOperation)
 
@@ -90,7 +108,19 @@ run(testModuleOperation)
 def testModuleCapsule():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  assert ctx._get_live_module_count() == 1
   # CHECK: "mlir.ir.Module._CAPIPtr"
-  print(module._CAPIPtr)
+  module_capsule = module._CAPIPtr
+  print(module_capsule)
+  module_dup = mlir.ir.Module._CAPICreate(module_capsule)
+  assert module is module_dup
+  assert module_dup.context is ctx
+  # Gc and verify destructed.
+  module = None
+  module_capsule = None
+  module_dup = None
+  gc.collect()
+  assert ctx._get_live_module_count() == 0
+
 
 run(testModuleCapsule)
index 881398e..37b8305 100644 (file)
@@ -23,6 +23,7 @@ def testTraverseOpRegionBlockIterators():
     }
   """)
   op = module.operation
+  assert op.context is ctx
   # Get the block using iterators off of the named collections.
   regions = list(op.regions)
   blocks = list(regions[0].blocks)
index d8ae77f..5a9c5a1 100644 (file)
@@ -14,6 +14,7 @@ def run(f):
 def testParsePrint():
   ctx = mlir.ir.Context()
   t = ctx.parse_type("i32")
+  assert t.context is ctx
   ctx = None
   gc.collect()
   # CHECK: i32