From 74a8a1e038022fb4ca9b8e444489e910f16a9741 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 5 Apr 2023 15:35:02 +0200 Subject: [PATCH] [mlir] Fix a use after free when loading dependent dialects The way dependent dialects are implemented is by recursively calling loadDialect in the constructor. This means we have to reload from the dialect table because the constructor might have rehashed that table. The steps for loading a dialect are: 1. Insert a nullptr into loadedDialects. This indicates the dialect is loading. 2. Call ctor(). This recursively loads dependent dialects. 3. Insert the new dialect into the table. We had a conflict between steps 2 and 3 here. You have to be extremely unlucky though as rehashing is rare and operator[] does no generation checking on DenseMap. Changing that to an iterator would've uncovered this issue immediately. --- mlir/lib/IR/MLIRContext.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index daa4a6a..e64babf 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -438,9 +438,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, function_ref<std::unique_ptr<Dialect>()> ctor) { auto &impl = getImpl(); // Get the correct insertion position sorted by namespace. - auto dialectIt = impl.loadedDialects.find(dialectNamespace); + auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr); - if (dialectIt == impl.loadedDialects.end()) { + if (dialectIt.second) { LLVM_DEBUG(llvm::dbgs() << "Load new dialect in Context " << dialectNamespace << "\n"); #ifndef NDEBUG @@ -452,9 +452,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, "missing `dependentDialects` in a pass for example."); #endif // NDEBUG // loadedDialects entry is initialized to nullptr, indicating that the - // dialect is currently being loaded.
- std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; - dialect = ctor(); + // dialect is currently being loaded. Re-lookup the address in + // loadedDialects because the table might have been rehashed by recursive + // dialect loading in ctor(). + std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] = + ctor(); assert(dialect && "dialect ctor failed"); // Refresh all the identifiers dialect field, this catches cases where a @@ -473,7 +475,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, } #ifndef NDEBUG - if (dialectIt->second == nullptr) + if (dialectIt.first->second == nullptr) llvm::report_fatal_error( "Loading (and getting) a dialect (" + dialectNamespace + ") while the same dialect is still loading: use loadDialect instead " @@ -481,7 +483,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, #endif // NDEBUG // Abort if dialect with namespace has already been registered. - std::unique_ptr<Dialect> &dialect = dialectIt.first->second; + std::unique_ptr<Dialect> &dialect = dialectIt.first->second; if (dialect->getTypeID() != dialectID) llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + "' has already been registered"); -- 2.7.4