[mlir] Add C and python API for is_registered_operation.
authorStella Laurenzo <stellaraccident@gmail.com>
Wed, 31 Mar 2021 05:19:10 +0000 (22:19 -0700)
committerStella Laurenzo <stellaraccident@gmail.com>
Wed, 31 Mar 2021 05:56:02 +0000 (22:56 -0700)
* Suggested to be broken out of D99578

Differential Revision: https://reviews.llvm.org/D99638

mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/dialects.py
mlir/test/CAPI/ir.c

index d807cd4..048bd46 100644 (file)
@@ -119,6 +119,13 @@ mlirContextGetNumLoadedDialects(MlirContext context);
 MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
                                                            MlirStringRef name);
 
+/// Returns whether the given fully-qualified operation (i.e.
+/// 'dialect.operation') is registered with the context. This will return true
+/// if the dialect is loaded and the operation is registered within the
+/// dialect.
+MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
+                                                         MlirStringRef name);
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//
index 0a4c5fc..5046eed 100644 (file)
@@ -1752,7 +1752,12 @@ void mlir::python::populateIRCore(py::module &m) {
           },
           [](PyMlirContext &self, bool value) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
-          });
+          })
+      .def("is_registered_operation",
+           [](PyMlirContext &self, std::string &name) {
+             return mlirContextIsRegisteredOperation(
+                 self.get(), MlirStringRef{name.data(), name.size()});
+           });
 
   //----------------------------------------------------------------------------
   // Mapping of PyDialectDescriptor
index 67032a4..14cde96 100644 (file)
@@ -60,6 +60,10 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
 }
 
+bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
+  return unwrap(context)->isOperationRegistered(unwrap(name));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//
index 41f4239..d5f5bee 100644 (file)
@@ -3,14 +3,17 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # CHECK-LABEL: TEST: testDialectDescriptor
+@run
 def testDialectDescriptor():
   ctx = Context()
   d = ctx.get_dialect_descriptor("std")
@@ -25,10 +28,9 @@ def testDialectDescriptor():
   else:
     assert False, "Expected exception"
 
-run(testDialectDescriptor)
-
 
 # CHECK-LABEL: TEST: testUserDialectClass
+@run
 def testUserDialectClass():
   ctx = Context()
   # Access using attribute.
@@ -60,14 +62,14 @@ def testUserDialectClass():
   # CHECK: <Dialect (class mlir.dialects._std_ops_gen._Dialect)>
   print(d)
 
-run(testUserDialectClass)
-
 
 # CHECK-LABEL: TEST: testCustomOpView
 # This test uses the standard dialect AddFOp as an example of a user op.
 # TODO: Op creation and access is still quite verbose: simplify this test as
 # additional capabilities come online.
+@run
 def testCustomOpView():
+
   def createInput():
     op = Operation.create("pytest_dummy.intinput", results=[f32])
     # TODO: Auto result cast from operation
@@ -95,4 +97,12 @@ def testCustomOpView():
   m.operation.print()
 
 
-run(testCustomOpView)
+# CHECK-LABEL: TEST: testIsRegisteredOperation
+@run
+def testIsRegisteredOperation():
+  ctx = Context()
+
+  # CHECK: std.cond_br: True
+  print(f"std.cond_br: {ctx.is_registered_operation('std.cond_br')}")
+  # CHECK: std.not_existing: False
+  print(f"std.not_existing: {ctx.is_registered_operation('std.not_existing')}")
index 40ef39b..5ce496c 100644 (file)
@@ -1442,6 +1442,22 @@ int registerOnlyStd() {
   fprintf(stderr, "@registration\n");
   // CHECK-LABEL: @registration
 
+  // CHECK: std.cond_br is_registered: 1
+  fprintf(stderr, "std.cond_br is_registered: %d\n",
+          mlirContextIsRegisteredOperation(
+              ctx, mlirStringRefCreateFromCString("std.cond_br")));
+
+  // CHECK: std.not_existing_op is_registered: 0
+  fprintf(stderr, "std.not_existing_op is_registered: %d\n",
+          mlirContextIsRegisteredOperation(
+              ctx, mlirStringRefCreateFromCString("std.not_existing_op")));
+
+  // CHECK: not_existing_dialect.not_existing_op is_registered: 0
+  fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
+          mlirContextIsRegisteredOperation(
+              ctx, mlirStringRefCreateFromCString(
+                       "not_existing_dialect.not_existing_op")));
+
   return 0;
 }