Add a `mlirModuleGetBody()` accessor to the C API and bind it in Python
authorMehdi Amini <joker.eph@gmail.com>
Wed, 28 Oct 2020 05:57:17 +0000 (05:57 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 28 Oct 2020 17:53:52 +0000 (17:53 +0000)
Getting the body of a Module is a common need which justifies a
dedicated accessor instead of forcing users to go through the
region->blocks->front unwrapping manually.

Differential Revision: https://reviews.llvm.org/D90287

mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/dialects.py
mlir/test/CAPI/ir.c

index a08fe77..af0ab1f 100644 (file)
@@ -175,6 +175,9 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 /** Gets the context that a module was created with. */
 MlirContext mlirModuleGetContext(MlirModule module);
 
+/** Gets the body of the module, i.e. the only block it contains. */
+MlirBlock mlirModuleGetBody(MlirModule module);
+
 /** Checks whether a module is null. */
 static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
 
index 2fba7fa..4a46d91 100644 (file)
@@ -2234,6 +2234,16 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                 .releaseObject();
           },
           "Accesses the module as an operation")
+      .def_property_readonly(
+          "body",
+          [](PyModule &self) {
+            PyOperationRef module_op = PyOperation::forOperation(
+                self.getContext(), mlirModuleGetOperation(self.get()),
+                self.getRef().releaseObject());
+            PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
+            return returnBlock;
+          },
+          "Return the block for this module")
       .def(
           "dump",
           [](PyModule &self) {
index fdc40bc..f3c91d1 100644 (file)
@@ -148,6 +148,10 @@ MlirContext mlirModuleGetContext(MlirModule module) {
   return wrap(unwrap(module).getContext());
 }
 
+MlirBlock mlirModuleGetBody(MlirModule module) {
+  return wrap(unwrap(module).getBody());
+}
+
 void mlirModuleDestroy(MlirModule module) {
   // Transfer ownership to an OwningModuleRef so that its destructor is called.
   OwningModuleRef(unwrap(module));
index bc88e86..ef95163 100644 (file)
@@ -73,7 +73,7 @@ def testCustomOpView():
   f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
   m = ctx.create_module(loc)
-  m_block = m.operation.regions[0].blocks[0]
+  m_block = m.body
   # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
   ip = [0]
   def createInput():
index 1d382c3..87c8b64 100644 (file)
@@ -67,9 +67,7 @@ void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
 
 MlirModule makeAdd(MlirContext ctx, MlirLocation location) {
   MlirModule moduleOp = mlirModuleCreateEmpty(location);
-  MlirOperation module = mlirModuleGetOperation(moduleOp);
-  MlirRegion moduleBodyRegion = mlirOperationGetRegion(module, 0);
-  MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion);
+  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
 
   MlirType memrefType = mlirTypeParseGet(ctx, "memref<?xf32>");
   MlirType funcBodyArgTypes[] = {memrefType, memrefType};