[mlir] Fix a use after free when loading dependent dialects
authorBenjamin Kramer <benny.kra@googlemail.com>
Wed, 5 Apr 2023 13:35:02 +0000 (15:35 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Wed, 5 Apr 2023 13:44:29 +0000 (15:44 +0200)
The way dependent dialects are implemented is by recursively calling
loadDialect in the constructor. This means we have to reload from the
dialect table because the constructor might have rehashed that table.

The steps for loading a dialect are
  1. Insert a nullptr into loadedDialects. This indicates the dialect is
     loading
  2. Call ctor(). This recursively loads dependent dialects
  3. Insert the new dialect into the table.

We had a conflict between steps 2 and 3 here. You have to be extremely
unlucky though as rehashing is rare and operator[] does no generation
checking on DenseMap. Changing that to an iterator would've uncovered
this issue immediately.

mlir/lib/IR/MLIRContext.cpp

index daa4a6a..e64babf 100644 (file)
@@ -438,9 +438,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
                               function_ref<std::unique_ptr<Dialect>()> ctor) {
   auto &impl = getImpl();
   // Get the correct insertion position sorted by namespace.
-  auto dialectIt = impl.loadedDialects.find(dialectNamespace);
+  auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);
 
-  if (dialectIt == impl.loadedDialects.end()) {
+  if (dialectIt.second) {
     LLVM_DEBUG(llvm::dbgs()
                << "Load new dialect in Context " << dialectNamespace << "\n");
 #ifndef NDEBUG
@@ -452,9 +452,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
           "missing `dependentDialects` in a pass for example.");
 #endif // NDEBUG
     // loadedDialects entry is initialized to nullptr, indicating that the
-    // dialect is currently being loaded.
-    std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
-    dialect = ctor();
+    // dialect is currently being loaded. Re-lookup the address in
+    // loadedDialects because the table might have been rehashed by recursive
+    // dialect loading in ctor().
+    std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] =
+        ctor();
     assert(dialect && "dialect ctor failed");
 
     // Refresh all the identifiers dialect field, this catches cases where a
@@ -473,7 +475,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   }
 
 #ifndef NDEBUG
-  if (dialectIt->second == nullptr)
+  if (dialectIt.first->second == nullptr)
     llvm::report_fatal_error(
         "Loading (and getting) a dialect (" + dialectNamespace +
         ") while the same dialect is still loading: use loadDialect instead "
@@ -481,7 +483,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
 #endif // NDEBUG
 
   // Abort if dialect with namespace has already been registered.
-  std::unique_ptr<Dialect> &dialect = dialectIt->second;
+  std::unique_ptr<Dialect> &dialect = dialectIt.first->second;
   if (dialect->getTypeID() != dialectID)
     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                              "' has already been registered");