From 4f0262c1640531dd431cf205f4b802b1fabb6489 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 2 Aug 2022 22:18:36 +0000 Subject: [PATCH] Fix use-after-free in SymbolTable::replaceAllSymbolUses In some cases the recursion will grow the `visited` hash table and invalidate the cached iterator. (caught with ASAN) Differential Revision: https://reviews.llvm.org/D131027 --- mlir/lib/IR/SubElementInterfaces.cpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp index f8526dc..a362479 100644 --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -117,31 +117,39 @@ static void updateSubElementImpl( // Check for an existing mapping for this element, and walk it if we haven't // yet. - T &mappedElement = visited[element]; - if (!mappedElement) { + T *mappedElement = &visited[element]; + if (!*mappedElement) { WalkResult result = WalkResult::advance(); - std::tie(mappedElement, result) = walkFn(element); + std::tie(*mappedElement, result) = walkFn(element); // Try walking this element. - if (result.wasInterrupted() || !mappedElement) { + if (result.wasInterrupted() || !*mappedElement) { changed = failure(); return; } // Handle replacing sub-elements if this element is also a container. if (!result.wasSkipped()) { - if (auto interface = mappedElement.template dyn_cast()) { - if (!(mappedElement = replaceSubElementFn(interface))) { + if (auto interface = mappedElement->template dyn_cast()) { + // Cache the size of the `visited` map since it may grow when calling + // `replaceSubElementFn` and we would need to fetch again the (now + // invalidated) reference to `mappedElement`. + size_t visitedSize = visited.size(); + auto replacedElement = replaceSubElementFn(interface); + if (!replacedElement) { changed = failure(); return; } + if (visitedSize != visited.size()) + mappedElement = &visited[element]; + *mappedElement = replacedElement; } } } // Update to the mapped element. - if (mappedElement != element) { - newElements.back() = mappedElement; + if (*mappedElement != element) { + newElements.back() = *mappedElement; changed = true; } } -- 2.7.4