[MLIR] Replace dialect registration hooks with dialect handle
authorGeorge <GeorgeLyon@users.noreply.github.com>
Tue, 9 Feb 2021 17:00:22 +0000 (09:00 -0800)
committerGeorge <GeorgeLyon@users.noreply.github.com>
Tue, 9 Feb 2021 17:02:16 +0000 (09:02 -0800)
Replace MlirDialectRegistrationHooks with MlirDialectHandle, which under-the-hood is an opaque pointer to MlirDialectRegistrationHooks. Then we expose the functionality previously directly on MlirDialectRegistrationHooks, as functions which take the opaque MlirDialectHandle struct. This makes the actual structure of the registration hooks an implementation detail, and happens to avoid this issue: https://llvm.discourse.group/t/strange-swift-issues-with-dialect-registration-hooks/2759/3

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D96229

mlir/include/mlir-c/Registration.h
mlir/include/mlir/CAPI/Registration.h
mlir/lib/CAPI/IR/CMakeLists.txt
mlir/lib/CAPI/IR/DialectHandle.cpp [new file with mode: 0644]
mlir/test/CAPI/ir.c

index 7fde05d..6c7a486 100644 (file)
@@ -23,47 +23,34 @@ extern "C" {
 // API name (i.e. "Standard", "Tensor", "Linalg") and namespace (i.e. "std",
 // "tensor", "linalg"). The following declarations are produced:
 //
-//   /// Registers the dialect with the given context. This allows the
-//   /// dialect to be loaded dynamically if needed when parsing. */
-//   void mlirContextRegister{NAME}Dialect(MlirContext);
-//
-//   /// Loads the dialect into the given context. The dialect does _not_
-//   /// have to be registered in advance.
-//   MlirDialect mlirContextLoad{NAME}Dialect(MlirContext context);
-//
-//   /// Returns the namespace of the Standard dialect, suitable for loading it.
-//   MlirStringRef mlir{NAME}DialectGetNamespace();
-//
 //   /// Gets the above hook methods in struct form for a dialect by namespace.
 //   /// This is intended to facilitate dynamic lookup and registration of
 //   /// dialects via a plugin facility based on shared library symbol lookup.
-//   const MlirDialectRegistrationHooks *mlirGetDialectHooks__{NAMESPACE}__();
+//   const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__();
 //
 // This is done via a common macro to facilitate future expansion to
 // registration schemes.
 //===----------------------------------------------------------------------===//
 
+struct MlirDialectHandle {
+  const void *ptr;
+};
+typedef struct MlirDialectHandle MlirDialectHandle;
+
 #define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace)                \
-  MLIR_CAPI_EXPORTED void mlirContextRegister##Name##Dialect(                  \
-      MlirContext context);                                                    \
-  MLIR_CAPI_EXPORTED MlirDialect mlirContextLoad##Name##Dialect(               \
-      MlirContext context);                                                    \
-  MLIR_CAPI_EXPORTED MlirStringRef mlir##Name##DialectGetNamespace();          \
-  MLIR_CAPI_EXPORTED const MlirDialectRegistrationHooks                        \
-      *mlirGetDialectHooks__##Namespace##__()
+  MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__()
 
-/// Hooks for dynamic discovery of dialects.
-typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
-typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
-typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
+/// Returns the namespace associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED
+MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
 
-/// Structure of dialect registration hooks.
-struct MlirDialectRegistrationHooks {
-  MlirContextRegisterDialectHook registerHook;
-  MlirContextLoadDialectHook loadHook;
-  MlirDialectGetNamespaceHook getNamespaceHook;
-};
-typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
+/// Registers the dialect associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
+                                                         MlirContext);
+
+/// Loads the dialect associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle,
+                                                            MlirContext);
 
 /// Registers all dialects known to core MLIR with the provided Context.
 /// This is needed before creating IR for these Dialects.
index da63afb..7601f9f 100644 (file)
 // of the dialect class.
 //===----------------------------------------------------------------------===//
 
+/// Hooks for dynamic discovery of dialects.
+typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
+typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
+typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
+
+/// Structure of dialect registration hooks.
+struct MlirDialectRegistrationHooks {
+  MlirContextRegisterDialectHook registerHook;
+  MlirContextLoadDialectHook loadHook;
+  MlirDialectGetNamespaceHook getNamespaceHook;
+};
+typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
+
 #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)      \
-  void mlirContextRegister##Name##Dialect(MlirContext context) {               \
+  static void mlirContextRegister##Name##Dialect(MlirContext context) {        \
     unwrap(context)->getDialectRegistry().insert<ClassName>();                 \
   }                                                                            \
-  MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {            \
+  static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {     \
     return wrap(unwrap(context)->getOrLoadDialect<ClassName>());               \
   }                                                                            \
-  MlirStringRef mlir##Name##DialectGetNamespace() {                            \
+  static MlirStringRef mlir##Name##DialectGetNamespace() {                     \
     return wrap(ClassName::getDialectNamespace());                             \
   }                                                                            \
-  const MlirDialectRegistrationHooks *mlirGetDialectHooks__##Namespace##__() { \
+  MlirDialectHandle mlirGetDialectHandle__##Namespace##__() {                  \
     static MlirDialectRegistrationHooks hooks = {                              \
         mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect,    \
         mlir##Name##DialectGetNamespace};                                      \
-    return &hooks;                                                             \
+    return MlirDialectHandle{&hooks};                                          \
   }
 
 #endif // MLIR_CAPI_REGISTRATION_H
index 893ccb6..486ba6e 100644 (file)
@@ -5,6 +5,7 @@ add_mlir_public_c_api_library(MLIRCAPIIR
   BuiltinAttributes.cpp
   BuiltinTypes.cpp
   Diagnostics.cpp
+  DialectHandle.cpp
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
diff --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp
new file mode 100644 (file)
index 0000000..fb97231
--- /dev/null
@@ -0,0 +1,28 @@
+//===- DialectHandle.cpp - C Interface for MLIR Dialect Operations -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Registration.h"
+
+static inline const MlirDialectRegistrationHooks *
+unwrap(MlirDialectHandle handle) {
+  return (const MlirDialectRegistrationHooks *)handle.ptr;
+}
+
+MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
+  return unwrap(handle)->getNamespaceHook();
+}
+
+void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
+                                      MlirContext ctx) {
+  unwrap(handle)->registerHook(ctx);
+}
+
+MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,
+                                         MlirContext ctx) {
+  return unwrap(handle)->loadHook(ctx);
+}
index 2f81d13..7576133 100644 (file)
@@ -1412,23 +1412,26 @@ int registerOnlyStd() {
   if (mlirContextGetNumLoadedDialects(ctx) != 1)
     return 1;
 
-  MlirDialect std =
-      mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
+
+  MlirDialect std = mlirContextGetOrLoadDialect(
+      ctx, mlirDialectHandleGetNamespace(stdHandle));
   if (!mlirDialectIsNull(std))
     return 2;
 
-  mlirContextRegisterStandardDialect(ctx);
+  mlirDialectHandleRegisterDialect(stdHandle, ctx);
 
-  std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  std = mlirContextGetOrLoadDialect(ctx,
+                                    mlirDialectHandleGetNamespace(stdHandle));
   if (mlirDialectIsNull(std))
     return 3;
 
-  MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
+  MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
   if (!mlirDialectEqual(std, alsoStd))
     return 4;
 
   MlirStringRef stdNs = mlirDialectGetNamespace(std);
-  MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
+  MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
   if (stdNs.length != alsoStdNs.length ||
       strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
     return 5;