Emit strong definition for TypeID storage in Op/Type/Attributes definition
authorMehdi Amini <joker.eph@gmail.com>
Wed, 28 Jul 2021 05:22:45 +0000 (05:22 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 28 Jul 2021 23:58:39 +0000 (23:58 +0000)
By making an explicit template specialization for the TypeID provided by these classes,
the compiler will not emit an inline weak definition and rely on the linker to unique it.
Instead a single definition will be emitted in the C++ file alongside the implementation
for these classes. That will turn into a linker error what is now a hard-to-debug runtime
behavior where instances of the same class may be using a different TypeID inside of
different DSOs.

Recommit 660a56956c32b0bcd850fc12fa8ad0225a6bb880 after fixing gcc5
build.

Differential Revision: https://reviews.llvm.org/D105903

mlir/include/mlir/Support/TypeID.h
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 2e4d8b9..9e18a3c 100644 (file)
@@ -137,6 +137,32 @@ TypeID TypeID::get() {
 
 } // end namespace mlir
 
+// Declare/define an explicit specialization for TypeID: this forces the
+// compiler to emit a strong definition for a class and controls which
+// translation unit and shared object will actually have it.
+// This can be useful to turn to a link-time failure what would be in other
+// circumstances a hard-to-catch runtime bug when a TypeID is hidden in two
+// different shared libraries and instances of the same class only gets the same
+// TypeID inside a given DSO.
+#define DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME)                                   \
+  namespace mlir {                                                             \
+  namespace detail {                                                           \
+  template <>                                                                  \
+  LLVM_EXTERNAL_VISIBILITY TypeID TypeIDExported::get<CLASS_NAME>();           \
+  }                                                                            \
+  }
+
+#define DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)                                    \
+  namespace mlir {                                                             \
+  namespace detail {                                                           \
+  template <>                                                                  \
+  LLVM_EXTERNAL_VISIBILITY TypeID TypeIDExported::get<CLASS_NAME>() {          \
+    static TypeID::Storage instance;                                           \
+    return TypeID(&instance);                                                  \
+  }                                                                            \
+  }                                                                            \
+  }
+
 namespace llvm {
 template <> struct DenseMapInfo<mlir::TypeID> {
   static mlir::TypeID getEmptyKey() {
index c8910ea..5b1b803 100644 (file)
@@ -440,16 +440,24 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
   collectAllDefs(selectedDialect, defRecords, defs);
   if (defs.empty())
     return false;
+  {
+    NamespaceEmitter nsEmitter(os, defs.front().getDialect());
 
-  NamespaceEmitter nsEmitter(os, defs.front().getDialect());
+    // Declare all the def classes first (in case they reference each other).
+    for (const AttrOrTypeDef &def : defs)
+      os << "  class " << def.getCppClassName() << ";\n";
 
-  // Declare all the def classes first (in case they reference each other).
+    // Emit the declarations.
+    for (const AttrOrTypeDef &def : defs)
+      emitDefDecl(def);
+  }
+  // Emit the TypeID explicit specializations to have a single definition for
+  // each of these.
   for (const AttrOrTypeDef &def : defs)
-    os << "  class " << def.getCppClassName() << ";\n";
+    if (!def.getDialect().getCppNamespace().empty())
+      os << "DECLARE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
+         << "::" << def.getCppClassName() << ")\n";
 
-  // Emit the declarations.
-  for (const AttrOrTypeDef &def : defs)
-    emitDefDecl(def);
   return false;
 }
 
@@ -934,8 +942,13 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
 
   IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
   emitParsePrintDispatch(defs);
-  for (const AttrOrTypeDef &def : defs)
+  for (const AttrOrTypeDef &def : defs) {
     emitDefDef(def);
+    // Emit the TypeID explicit specializations to have a single symbol def.
+    if (!def.getDialect().getCppNamespace().empty())
+      os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
+         << "::" << def.getCppClassName() << ")\n";
+  }
 
   return false;
 }
index 2ebabc5..2e5b983 100644 (file)
@@ -198,38 +198,44 @@ static void emitDialectDecl(Dialect &dialect,
   }
 
   // Emit all nested namespaces.
-  NamespaceEmitter nsEmitter(os, dialect);
-
-  // Emit the start of the decl.
-  std::string cppName = dialect.getCppClassName();
-  os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
-                      dependentDialectRegistrations);
-
-  // Check for any attributes/types registered to this dialect.  If there are,
-  // add the hooks for parsing/printing.
-  if (!dialectAttrs.empty())
-    os << attrParserDecl;
-  if (!dialectTypes.empty())
-    os << typeParserDecl;
-
-  // Add the decls for the various features of the dialect.
-  if (dialect.hasCanonicalizer())
-    os << canonicalizerDecl;
-  if (dialect.hasConstantMaterializer())
-    os << constantMaterializerDecl;
-  if (dialect.hasOperationAttrVerify())
-    os << opAttrVerifierDecl;
-  if (dialect.hasRegionArgAttrVerify())
-    os << regionArgAttrVerifierDecl;
-  if (dialect.hasRegionResultAttrVerify())
-    os << regionResultAttrVerifierDecl;
-  if (dialect.hasOperationInterfaceFallback())
-    os << operationInterfaceFallbackDecl;
-  if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
-    os << *extraDecl;
-
-  // End the dialect decl.
-  os << "};\n";
+  {
+    NamespaceEmitter nsEmitter(os, dialect);
+
+    // Emit the start of the decl.
+    std::string cppName = dialect.getCppClassName();
+    os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
+                        dependentDialectRegistrations);
+
+    // Check for any attributes/types registered to this dialect.  If there are,
+    // add the hooks for parsing/printing.
+    if (!dialectAttrs.empty())
+      os << attrParserDecl;
+    if (!dialectTypes.empty())
+      os << typeParserDecl;
+
+    // Add the decls for the various features of the dialect.
+    if (dialect.hasCanonicalizer())
+      os << canonicalizerDecl;
+    if (dialect.hasConstantMaterializer())
+      os << constantMaterializerDecl;
+    if (dialect.hasOperationAttrVerify())
+      os << opAttrVerifierDecl;
+    if (dialect.hasRegionArgAttrVerify())
+      os << regionArgAttrVerifierDecl;
+    if (dialect.hasRegionResultAttrVerify())
+      os << regionResultAttrVerifierDecl;
+    if (dialect.hasOperationInterfaceFallback())
+      os << operationInterfaceFallbackDecl;
+    if (llvm::Optional<StringRef> extraDecl =
+            dialect.getExtraClassDeclaration())
+      os << *extraDecl;
+
+    // End the dialect decl.
+    os << "};\n";
+  }
+  if (!dialect.getCppNamespace().empty())
+    os << "DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
+       << "::" << dialect.getCppClassName() << ")\n";
 }
 
 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
@@ -263,6 +269,11 @@ static const char *const dialectDestructorStr = R"(
 )";
 
 static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+  // Emit the TypeID explicit specializations to have a single symbol def.
+  if (!dialect.getCppNamespace().empty())
+    os << "DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
+       << "::" << dialect.getCppClassName() << ")\n";
+
   // Emit all nested namespaces.
   NamespaceEmitter nsEmitter(os, dialect);
 
index 2bc9ea4..1c630e0 100644 (file)
@@ -650,7 +650,6 @@ OpEmitter::OpEmitter(const Operator &op,
   generateOpFormat(op, opClass);
   genSideEffectInterfaceMethods();
 }
-
 void OpEmitter::emitDecl(
     const Operator &op, raw_ostream &os,
     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
@@ -2576,15 +2575,29 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
                                                       emitDecl);
   for (auto *def : defs) {
     Operator op(*def);
-    NamespaceEmitter emitter(os, op.getCppNamespace());
     if (emitDecl) {
-      os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
-      OpOperandAdaptorEmitter::emitDecl(op, os);
-      OpEmitter::emitDecl(op, os, staticVerifierEmitter);
+      {
+        NamespaceEmitter emitter(os, op.getCppNamespace());
+        os << formatv(opCommentHeader, op.getQualCppClassName(),
+                      "declarations");
+        OpOperandAdaptorEmitter::emitDecl(op, os);
+        OpEmitter::emitDecl(op, os, staticVerifierEmitter);
+      }
+      // Emit the TypeID explicit specialization to have a single definition.
+      if (!op.getCppNamespace().empty())
+        os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
+           << "::" << op.getCppClassName() << ")\n\n";
     } else {
-      os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
-      OpOperandAdaptorEmitter::emitDef(op, os);
-      OpEmitter::emitDef(op, os, staticVerifierEmitter);
+      {
+        NamespaceEmitter emitter(os, op.getCppNamespace());
+        os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
+        OpOperandAdaptorEmitter::emitDef(op, os);
+        OpEmitter::emitDef(op, os, staticVerifierEmitter);
+      }
+      // Emit the TypeID explicit specialization to have a single definition.
+      if (!op.getCppNamespace().empty())
+        os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
+           << "::" << op.getCppClassName() << ")\n\n";
     }
   }
 }