Add `mlirModuleFromOperation` to C API
authorAdam Paszke <apaszke@google.com>
Mon, 17 May 2021 10:14:02 +0000 (10:14 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 17 May 2021 10:14:16 +0000 (10:14 +0000)
At the moment `MlirModule`s can be converted to `MlirOperation`s, but not
the other way around (at least not without going around the C API). This
makes it impossible to e.g. run passes over a `ModuleOp` created through
`mlirOperationCreate`.

Reviewed By: nicolasvasilache, mehdi_amini

Differential Revision: https://reviews.llvm.org/D102497

mlir/include/mlir-c/IR.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c

index 1b24316..638eea9 100644 (file)
@@ -209,6 +209,10 @@ MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module);
 /// Views the module as a generic operation.
 MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
 
+/// Views the generic operation as a module.
+/// The returned module is null when the input operation was not a ModuleOp.
+MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
+
 //===----------------------------------------------------------------------===//
 // Operation state.
 //===----------------------------------------------------------------------===//
index 4e21835..ebabd68 100644 (file)
@@ -181,6 +181,10 @@ MlirOperation mlirModuleGetOperation(MlirModule module) {
   return wrap(unwrap(module).getOperation());
 }
 
+MlirModule mlirModuleFromOperation(MlirOperation op) {
+  return wrap(dyn_cast<ModuleOp>(unwrap(op)));
+}
+
 //===----------------------------------------------------------------------===//
 // Operation state API.
 //===----------------------------------------------------------------------===//
index 7176cbb..42dfa53 100644 (file)
@@ -319,6 +319,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   MlirOperation parentOperation = operation;
   block = mlirRegionGetFirstBlock(region);
   operation = mlirBlockGetFirstOperation(block);
+  assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
 
   // Verify that parent operation and block report correctly.
   fprintf(stderr, "Parent operation eq: %d\n",
@@ -460,6 +461,7 @@ static int constructAndTraverseIr(MlirContext ctx) {
 
   MlirModule moduleOp = makeAndDumpAdd(ctx, location);
   MlirOperation module = mlirModuleGetOperation(moduleOp);
+  assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));
 
   int errcode = collectStats(module);
   if (errcode)