Fix use-after-free in SymbolTable::replaceAllSymbolUses
authorMehdi Amini <joker.eph@gmail.com>
Tue, 2 Aug 2022 22:18:36 +0000 (22:18 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 2 Aug 2022 22:30:17 +0000 (22:30 +0000)
In some cases the recursion will grow the `visited` hash table and
invalidate the cached iterator.
(caught with ASAN)

Differential Revision: https://reviews.llvm.org/D131027

mlir/lib/IR/SubElementInterfaces.cpp

index f8526dc..a362479 100644 (file)
@@ -117,31 +117,39 @@ static void updateSubElementImpl(
 
   // Check for an existing mapping for this element, and walk it if we haven't
   // yet.
-  T &mappedElement = visited[element];
-  if (!mappedElement) {
+  T *mappedElement = &visited[element];
+  if (!*mappedElement) {
     WalkResult result = WalkResult::advance();
-    std::tie(mappedElement, result) = walkFn(element);
+    std::tie(*mappedElement, result) = walkFn(element);
 
     // Try walking this element.
-    if (result.wasInterrupted() || !mappedElement) {
+    if (result.wasInterrupted() || !*mappedElement) {
       changed = failure();
       return;
     }
 
     // Handle replacing sub-elements if this element is also a container.
     if (!result.wasSkipped()) {
-      if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
-        if (!(mappedElement = replaceSubElementFn(interface))) {
+      if (auto interface = mappedElement->template dyn_cast<InterfaceT>()) {
+        // Cache the size of the `visited` map since it may grow when calling
+        // `replaceSubElementFn` and we would need to fetch again the (now
+        // invalidated) reference to `mappedElement`.
+        size_t visitedSize = visited.size();
+        auto replacedElement = replaceSubElementFn(interface);
+        if (!replacedElement) {
           changed = failure();
           return;
         }
+        if (visitedSize != visited.size())
+          mappedElement = &visited[element];
+        *mappedElement = replacedElement;
       }
     }
   }
 
   // Update to the mapped element.
-  if (mappedElement != element) {
-    newElements.back() = mappedElement;
+  if (*mappedElement != element) {
+    newElements.back() = *mappedElement;
     changed = true;
   }
 }