size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
py::object PyMlirContext::createOperation(
std::string name, PyLocation location,
llvm::Optional<std::vector<PyType *>> results,
// PyModule
//------------------------------------------------------------------------------
-PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is automatic_reference,
- // which does not take ownership (delete will not be called).
- // Just be explicit.
- py::object pyRef =
- py::cast(unownedModule, py::return_value_policy::take_ownership);
- unownedModule->handle = pyRef;
- return PyModuleRef(unownedModule, std::move(pyRef));
+PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
+ : BaseContextObject(std::move(contextRef)), module(module) {}
+
+PyModule::~PyModule() {
+ py::gil_scoped_acquire acquire;
+ auto &liveModules = getContext()->liveModules;
+ assert(liveModules.count(module.ptr) == 1 &&
+ "destroying module not in live map");
+ liveModules.erase(module.ptr);
+ mlirModuleDestroy(module);
+}
+
+PyModuleRef PyModule::forModule(MlirModule module) {
+ MlirContext context = mlirModuleGetContext(module);
+ PyMlirContextRef contextRef = PyMlirContext::forContext(context);
+
+ py::gil_scoped_acquire acquire;
+ auto &liveModules = contextRef->liveModules;
+ auto it = liveModules.find(module.ptr);
+ if (it == liveModules.end()) {
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ py::object pyRef =
+ py::cast(unownedModule, py::return_value_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ liveModules[module.ptr] =
+ std::make_pair(unownedModule->handle, unownedModule);
+ return PyModuleRef(unownedModule, std::move(pyRef));
+ }
+ // Use existing.
+ PyModule *existing = it->second.second;
+ py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+ return PyModuleRef(existing, std::move(pyRef));
+}
+
+py::object PyModule::createFromCapsule(py::object capsule) {
+ MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
+ if (mlirModuleIsNull(rawModule))
+ throw py::error_already_set();
+ return forModule(rawModule).releaseObject();
}
py::object PyModule::getCapsule() {
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
- return PyModule::create(self.getRef(), module).releaseObject();
+ return PyModule::forModule(module).releaseObject();
},
kContextParseDocstring)
.def(
+ "create_module",
+ [](PyMlirContext &self, PyLocation &loc) {
+ MlirModule module = mlirModuleCreateEmpty(loc.loc);
+ return PyModule::forModule(module).releaseObject();
+ },
+ py::arg("loc"), "Creates an empty module")
+ .def(
"parse_attr",
[](PyMlirContext &self, std::string attrSpec) {
MlirAttribute type =
kContextGetFileLocationDocstring, py::arg("filename"),
py::arg("line"), py::arg("col"));
- py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self.loc, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- });
+ py::class_<PyLocation>(m, "Location")
+ .def_property_readonly(
+ "context",
+ [](PyLocation &self) { return self.getContext().getObject(); },
+ "Context that owns the Location")
+ .def("__repr__", [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self.loc, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ });
// Mapping of Module
py::class_<PyModule>(m, "Module")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+ .def_property_readonly(
+ "context",
+ [](PyModule &self) { return self.getContext().getObject(); },
+ "Context that created the Module")
.def_property_readonly(
"operation",
[](PyModule &self) {
// Mapping of Operation.
py::class_<PyOperation>(m, "Operation")
.def_property_readonly(
+ "context",
+ [](PyOperation &self) { return self.getContext().getObject(); },
+ "Context that owns the Operation")
+ .def_property_readonly(
"regions",
[](PyOperation &self) { return PyRegionList(self.getRef()); })
.def("__iter__",
// Mapping of Type.
py::class_<PyAttribute>(m, "Attribute")
+ .def_property_readonly(
+ "context",
+ [](PyAttribute &self) { return self.getContext().getObject(); },
+ "Context that owns the Attribute")
.def(
"get_named",
[](PyAttribute &self, std::string name) {
// Mapping of Type.
py::class_<PyType>(m, "Type")
+ .def_property_readonly(
+ "context", [](PyType &self) { return self.getContext().getObject(); },
+ "Context that owns the Type")
.def("__eq__",
[](PyType &self, py::object &other) {
try {
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
- /// may be a pre-existing object.
+ /// may be a pre-existing object. Ownership of the underlying MlirContext
+ /// is taken by calling this function.
static pybind11::object createFromCapsule(pybind11::object capsule);
/// Gets the count of live context objects. Used for testing.
/// Used for testing.
size_t getLiveOperationCount();
+ /// Gets the count of live modules associated with this context.
+ /// Used for testing.
+ size_t getLiveModuleCount();
+
/// Creates an operation. See corresponding python docstring.
pybind11::object
createOperation(std::string name, PyLocation location,
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static LiveContextMap &getLiveContexts();
+ // Interns all live modules associated with this context. Modules tracked
+ // in this map are valid. When a module is invalidated, it is removed
+ // from this map, and while it still exists as an instance, any
+ // attempt to access it will raise an error.
+ using LiveModuleMap =
+ llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
+ LiveModuleMap liveModules;
+
// Interns all live operations associated with this context. Operations
// tracked in this map are valid. When an operation is invalidated, it is
// removed from this map, and while it still exists as an instance, any
LiveOperationMap liveOperations;
MlirContext context;
+ friend class PyModule;
friend class PyOperation;
};
using PyModuleRef = PyObjectRef<PyModule>;
class PyModule : public BaseContextObject {
public:
- /// Creates a reference to the module
- static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module);
+ /// Returns a PyModule reference for the given MlirModule. This may return
+ /// a pre-existing or new object.
+ static PyModuleRef forModule(MlirModule module);
PyModule(PyModule &) = delete;
- ~PyModule() {
- if (module.ptr)
- mlirModuleDestroy(module);
- }
+ PyModule(PyMlirContext &&) = delete;
+ ~PyModule();
/// Gets the backing MlirModule.
MlirModule get() { return module; }
/// instances, which is not currently done.
pybind11::object getCapsule();
+ /// Creates a PyModule from the MlirModule wrapped by a capsule.
+ /// Note that PyModule instances are uniqued, so the returned object
+ /// may be a pre-existing object. Ownership of the underlying MlirModule
+ /// is taken by calling this function.
+ static pybind11::object createFromCapsule(pybind11::object capsule);
+
private:
- PyModule(PyMlirContextRef contextRef, MlirModule module)
- : BaseContextObject(std::move(contextRef)), module(module) {}
+ PyModule(PyMlirContextRef contextRef, MlirModule module);
MlirModule module;
pybind11::handle handle;
};
def testParseSuccess():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
+ assert module.context is ctx
print("CLEAR CONTEXT")
ctx = None # Ensure that module captures the context.
gc.collect()
run(testParseError)
+# Verify successful parse.
+# CHECK-LABEL: TEST: testCreateEmpty
+# CHECK: module {
+def testCreateEmpty():
+ ctx = mlir.ir.Context()
+ loc = ctx.get_unknown_location()
+ module = ctx.create_module(loc)
+ print("CLEAR CONTEXT")
+ ctx = None # Ensure that module captures the context.
+ gc.collect()
+ print(str(module))
+
+run(testCreateEmpty)
+
+
# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
def testModuleOperation():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
+ assert ctx._get_live_module_count() == 1
op1 = module.operation
assert ctx._get_live_operation_count() == 1
# CHECK: module @successfulParse
gc.collect()
print("LIVE OPERATIONS:", ctx._get_live_operation_count())
assert ctx._get_live_operation_count() == 0
+ assert ctx._get_live_module_count() == 0
run(testModuleOperation)
def testModuleCapsule():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
+ assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
- print(module._CAPIPtr)
+ module_capsule = module._CAPIPtr
+ print(module_capsule)
+ module_dup = mlir.ir.Module._CAPICreate(module_capsule)
+ assert module is module_dup
+ assert module_dup.context is ctx
+ # Gc and verify destructed.
+ module = None
+ module_capsule = None
+ module_dup = None
+ gc.collect()
+ assert ctx._get_live_module_count() == 0
+
run(testModuleCapsule)