Add a `mlirModuleGetBody()` accessor to the C API and bind it in Python
authorMehdi Amini <joker.eph@gmail.com>
Wed, 28 Oct 2020 05:57:17 +0000 (05:57 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 28 Oct 2020 17:53:52 +0000 (17:53 +0000)
Getting the body of a Module is a common need which justifies a
dedicated accessor instead of forcing users to go through the
region->blocks->front unwrapping manually.

Differential Revision: https://reviews.llvm.org/D90287

mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/dialects.py
mlir/test/CAPI/ir.c

index a08fe77da37cde8988f0ee3c23a5834739f73a04..af0ab1fdf34151dee29e170276a8a4658e56108d 100644 (file)
@@ -175,6 +175,9 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 /** Gets the context that a module was created with. */
 MlirContext mlirModuleGetContext(MlirModule module);
 
+/** Gets the body of the module, i.e. the only block it contains. */
+MlirBlock mlirModuleGetBody(MlirModule module);
+
 /** Checks whether a module is null. */
 static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
 
index 2fba7fa5e28359c92a68a988ba4ce0a3f227c839..4a46d9161d76386ae1cf3d8c1e1a21787ab03238 100644 (file)
@@ -2234,6 +2234,16 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                 .releaseObject();
           },
           "Accesses the module as an operation")
+      .def_property_readonly(
+          "body",
+          [](PyModule &self) {
+            PyOperationRef module_op = PyOperation::forOperation(
+                self.getContext(), mlirModuleGetOperation(self.get()),
+                self.getRef().releaseObject());
+            PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
+            return returnBlock;
+          },
+          "Return the block for this module")
       .def(
           "dump",
           [](PyModule &self) {
index fdc40bc6c4f18328dbdfd216e8b5e502824fbabb..f3c91d1fae24c8c624cb27d56f9f841f3bccbb06 100644 (file)
@@ -148,6 +148,10 @@ MlirContext mlirModuleGetContext(MlirModule module) {
   return wrap(unwrap(module).getContext());
 }
 
+MlirBlock mlirModuleGetBody(MlirModule module) {
+  return wrap(unwrap(module).getBody());
+}
+
 void mlirModuleDestroy(MlirModule module) {
   // Transfer ownership to an OwningModuleRef so that its destructor is called.
   OwningModuleRef(unwrap(module));
index bc88e8668f4d55a46c9bb8a985dfd0619ced35f4..ef95163c77432d7637b798ac15fabca9b2eddef6 100644 (file)
@@ -73,7 +73,7 @@ def testCustomOpView():
   f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
   m = ctx.create_module(loc)
-  m_block = m.operation.regions[0].blocks[0]
+  m_block = m.body
   # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
   ip = [0]
   def createInput():
index 1d382c32fb42f99f9d548fa6ea4945feb70b6a3f..87c8b647e6a0f69551a1dd2a2175af6bf3ba4568 100644 (file)
@@ -67,9 +67,7 @@ void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
 
 MlirModule makeAdd(MlirContext ctx, MlirLocation location) {
   MlirModule moduleOp = mlirModuleCreateEmpty(location);
-  MlirOperation module = mlirModuleGetOperation(moduleOp);
-  MlirRegion moduleBodyRegion = mlirOperationGetRegion(module, 0);
-  MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion);
+  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
 
   MlirType memrefType = mlirTypeParseGet(ctx, "memref<?xf32>");
   MlirType funcBodyArgTypes[] = {memrefType, memrefType};