[mlir] Add a new AttrTypeReplacer class to simplify sub element replacements
authorRiver Riddle <riddleriver@gmail.com>
Thu, 10 Nov 2022 04:59:40 +0000 (20:59 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Sat, 12 Nov 2022 22:38:45 +0000 (14:38 -0800)
We currently only have the SubElement interface API for attribute/type
replacement, but this suffers from several issues; namely that it doesn't
allow caching across multiple replacements (very common), and also
creates a somewhat awkward/limited API. The new AttrTypeReplacer class
allows for registering replacements using a much cleaner API, similarly to
the TypeConverter class, removes a lot of manual interaction with the
sub element interfaces, and also better enables large scale replacements.

Differential Revision: https://reviews.llvm.org/D137764

mlir/include/mlir/IR/SubElementInterfaces.h
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/IR/SubElementInterfaces.cpp
mlir/lib/IR/SymbolTable.cpp

index 2af7642..0162692 100644 (file)
 #include "mlir/IR/Visitors.h"
 
 namespace mlir {
-template <typename T>
-using SubElementReplFn = function_ref<T(T)>;
-template <typename T>
-using SubElementResultReplFn = function_ref<std::pair<T, WalkResult>(T)>;
+//===----------------------------------------------------------------------===//
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+/// This class provides a utility for replacing attributes/types, and their sub
+/// elements. Multiple replacement functions may be registered.
+class AttrTypeReplacer {
+public:
+  //===--------------------------------------------------------------------===//
+  // Application
+  //===--------------------------------------------------------------------===//
+
+  /// Replace the elements within the given operation. By default this includes
+  /// the attributes within the operation. If `replaceLocs` is true, this also
+  /// updates its location, the locations of any nested block arguments. If
+  /// `replaceTypes` is true, this also updates the result types of the
+  /// operation, and the types of any nested block arguments.
+  void replaceElementsIn(Operation *op, bool replaceLocs = false,
+                         bool replaceTypes = false);
+
+  /// Replace the given attribute/type, and recursively replace any sub
+  /// elements. Returns either the new attribute/type, or nullptr in the case of
+  /// failure.
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+  //===--------------------------------------------------------------------===//
+  // Registration
+  //===--------------------------------------------------------------------===//
+
+  /// A replacement mapping function, which returns either None (to signal the
+  /// element wasn't handled), or a pair of the replacement element and a
+  /// WalkResult.
+  template <typename T>
+  using ReplaceFnResult = Optional<std::pair<T, WalkResult>>;
+  template <typename T>
+  using ReplaceFn = std::function<ReplaceFnResult<T>(T)>;
+
+  /// Register a replacement function for mapping a given attribute or type. A
+  /// replacement function must be convertible to any of the following
+  /// forms(where `T` is a class derived from `Type` or `Attribute`, and `BaseT`
+  /// is either `Type` or `Attribute` respectively):
+  ///
+  ///   * Optional<BaseT>(T)
+  ///     - This either returns a valid Attribute/Type in the case of success,
+  ///       nullptr in the case of failure, or `llvm::None` to signify that
+  ///       additional replacement functions may be applied (i.e. this function
+  ///       doesn't handle that instance).
+  ///
+  ///   * Optional<std::pair<BaseT, WalkResult>>(T)
+  ///     - Similar to the above, but also allows specifying a WalkResult to
+  ///       control the replacement of sub elements of a given attribute or
+  ///       type. Returning a `skip` result, for example, will not recursively
+  ///       process the resultant attribute or type value.
+  ///
+  /// Note: When replacing, the mostly recently added replacement functions will
+  ///       be invoked first.
+  void addReplacement(ReplaceFn<Attribute> fn) {
+    attrReplacementFns.emplace_back(std::move(fn));
+  }
+  void addReplacement(ReplaceFn<Type> fn) {
+    typeReplacementFns.push_back(std::move(fn));
+  }
+
+  /// Register a replacement function that doesn't match the default signature,
+  /// either because it uses a derived parameter type, or it uses a simplified
+  /// result type.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<
+                std::decay_t<FnT>>::template arg_t<0>,
+            typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+                                                Attribute, Type>,
+            typename ResultT = std::invoke_result_t<FnT, T>>
+  std::enable_if_t<!std::is_same_v<T, BaseT> ||
+                   !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>>
+  addReplacement(FnT &&callback) {
+    addReplacement([callback = std::forward<FnT>(callback)](
+                       BaseT base) -> ReplaceFnResult<BaseT> {
+      if (auto derived = dyn_cast<T>(base)) {
+        if constexpr (std::is_convertible_v<ResultT, Optional<BaseT>>) {
+          Optional<BaseT> result = callback(derived);
+          return result ? std::make_pair(*result, WalkResult::advance())
+                        : ReplaceFnResult<BaseT>();
+        } else {
+          return callback(derived);
+        }
+      }
+      return ReplaceFnResult<BaseT>();
+    });
+  }
+
+private:
+  /// Internal implementation of the `replace` methods above.
+  template <typename InterfaceT, typename ReplaceFns, typename T>
+  T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap<T, T> &map);
+
+  /// Replace the sub elements of the given interface.
+  template <typename InterfaceT, typename T = typename InterfaceT::ValueType>
+  T replaceSubElements(InterfaceT interface, DenseMap<T, T> &interfaceMap);
+
+  /// The set of replacement functions that map sub elements.
+  std::vector<ReplaceFn<Attribute>> attrReplacementFns;
+  std::vector<ReplaceFn<Type>> typeReplacementFns;
+
+  /// The set of cached mappings for attributes/types.
+  DenseMap<Attribute, Attribute> attrMap;
+  DenseMap<Type, Type> typeMap;
+};
 
 //===----------------------------------------------------------------------===//
 /// AttrTypeSubElementHandler
@@ -291,7 +395,7 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
 } // namespace detail
 } // namespace mlir
 
-/// Include the definitions of the sub elemnt interfaces.
+/// Include the definitions of the sub element interfaces.
 #include "mlir/IR/SubElementAttrInterfaces.h.inc"
 #include "mlir/IR/SubElementTypeInterfaces.h.inc"
 
index abb5afc..7718feb 100644 (file)
@@ -66,25 +66,14 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
                          llvm::function_ref<void(mlir::Type)> walkTypesFn);
 
     /// Recursively replace all of the nested sub-attributes and sub-types using the
-    /// provided map functions. Returns nullptr in the case of failure.
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
-      mlir::SubElementReplFn<mlir::Type> replaceTypeFn
-    ) {
-      return replaceSubElements(
-        [&](Attribute attr) { return std::make_pair(replaceAttrFn(attr), WalkResult::advance()); },
-        [&](Type type) { return std::make_pair(replaceTypeFn(type), WalkResult::advance()); }
-      );
+    /// provided map functions. Returns nullptr in the case of failure. See
+    /// `AttrTypeReplacer` for information on the support replacement function types.
+    template <typename... ReplacementFns>
+    }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
+      AttrTypeReplacer replacer;
+      (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
+      return replacer.replace(*this);
     }
-    /// Recursively replace all of the nested sub-attributes and sub-types using the
-    /// provided map functions. This variant allows for the map function to return an
-    /// additional walk result. Returns nullptr in the case of failure.
-    }] # attrOrType # [{ replaceSubElements(
-      llvm::function_ref<
-        std::pair<mlir::Attribute, mlir::WalkResult>(mlir::Attribute)> replaceAttrFn,
-      llvm::function_ref<
-        std::pair<mlir::Type, mlir::WalkResult>(mlir::Type)> replaceTypeFn
-    );
   }];
   code extraTraitClassDeclaration = [{
     /// Walk all of the held sub-attributes and sub-types.
@@ -95,18 +84,13 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
     }
 
     /// Recursively replace all of the nested sub-attributes and sub-types using the
-    /// provided map functions. Returns nullptr in the case of failure.
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
-      mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
-      }] # interfaceName # " interface(" # derivedValue # [{);
-      return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
-    }
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn,
-      mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
-      }] # interfaceName # " interface(" # derivedValue # [{);
-      return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
+    /// provided map functions. Returns nullptr in the case of failure. See
+    /// `AttrTypeReplacer` for information on the support replacement function types.
+    template <typename... ReplacementFns>
+    }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) {
+      AttrTypeReplacer replacer;
+      (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), ...);
+      return replacer.replace(}] # derivedValue # [{);
     }
   }];
   code extraSharedClassDeclaration = [{
@@ -118,35 +102,6 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
     void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
       walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
     }
-
-    /// Recursively replace all of the nested sub-attributes using the provided
-    /// map function. Returns nullptr in the case of failure.
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn) {
-      return replaceSubElements(
-        replaceAttrFn, [](mlir::Type type) { return type; });
-    }
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn) {
-      return replaceSubElements(
-        replaceAttrFn,
-        [](mlir::Type type) { return std::make_pair(type, WalkResult::advance()); }
-      );
-    }
-    /// Recursively replace all of the nested sub-types using the provided map
-    /// function. Returns nullptr in the case of failure.
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
-      return replaceSubElements(
-        [](mlir::Attribute attr) { return attr; }, replaceTypeFn);
-    }
-    }] # attrOrType # [{ replaceSubElements(
-      mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
-      return replaceSubElements(
-        [](mlir::Attribute attr) { return std::make_pair(attr, WalkResult::advance()); },
-        replaceTypeFn
-      );
-    }
   }];
 }
 
index fd05b9d..88aeaf1 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/SubElementInterfaces.h"
+#include "mlir/IR/Operation.h"
 
 #include "llvm/ADT/DenseSet.h"
 
@@ -91,116 +92,146 @@ void SubElementTypeInterface::walkSubElements(
 }
 
 //===----------------------------------------------------------------------===//
-// ReplaceSubElements
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
 
-template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
-static void updateSubElementImpl(
-    T element, function_ref<std::pair<T, WalkResult>(T)> walkFn,
-    DenseMap<T, T> &visited, SmallVectorImpl<T> &newElements,
-    FailureOr<bool> &changed, ReplaceSubElementFnT &&replaceSubElementFn) {
-  // Bail early if we failed at any point.
-  if (failed(changed))
-    return;
-  newElements.push_back(element);
+void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceLocs,
+                                         bool replaceTypes) {
+  // Functor that replaces the given element if the new value is different,
+  // otherwise returns nullptr.
+  auto replaceIfDifferent = [&](auto element) {
+    auto replacement = replace(element);
+    return (replacement && replacement != element) ? replacement : nullptr;
+  };
+  // Check the attribute dictionary for replacements.
+  if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
+    op->setAttrs(cast<DictionaryAttr>(newAttrs));
 
-  // Guard against potentially null inputs. We always map null to null.
-  if (!element)
+  // If we aren't updating locations or types, we're done.
+  if (!replaceTypes && !replaceLocs)
     return;
 
-  // Check for an existing mapping for this element, and walk it if we haven't
-  // yet.
-  T *mappedElement = &visited[element];
-  if (!*mappedElement) {
-    WalkResult result = WalkResult::advance();
-    std::tie(*mappedElement, result) = walkFn(element);
-
-    // Try walking this element.
-    if (result.wasInterrupted() || !*mappedElement) {
-      changed = failure();
-      return;
-    }
+  // Update the location.
+  if (replaceLocs) {
+    if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
+      op->setLoc(cast<LocationAttr>(newLoc));
+  }
 
-    // Handle replacing sub-elements if this element is also a container.
-    if (!result.wasSkipped()) {
-      if (auto interface = mappedElement->template dyn_cast<InterfaceT>()) {
-        // Cache the size of the `visited` map since it may grow when calling
-        // `replaceSubElementFn` and we would need to fetch again the (now
-        // invalidated) reference to `mappedElement`.
-        size_t visitedSize = visited.size();
-        auto replacedElement = replaceSubElementFn(interface);
-        if (!replacedElement) {
-          changed = failure();
-          return;
+  // Update the result types.
+  if (replaceTypes) {
+    for (OpResult result : op->getResults())
+      if (Type newType = replaceIfDifferent(result.getType()))
+        result.setType(newType);
+  }
+
+  // Update any nested block arguments.
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      for (BlockArgument &arg : block.getArguments()) {
+        if (replaceLocs) {
+          if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
+            arg.setLoc(cast<LocationAttr>(newLoc));
+        }
+
+        if (replaceTypes) {
+          if (Type newType = replaceIfDifferent(arg.getType()))
+            arg.setType(newType);
         }
-        if (visitedSize != visited.size())
-          mappedElement = &visited[element];
-        *mappedElement = replacedElement;
       }
     }
   }
+}
+
+template <typename T>
+static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
+                                 DenseMap<T, T> &elementMap,
+                                 SmallVectorImpl<T> &newElements,
+                                 FailureOr<bool> &changed) {
+  // Bail early if we failed at any point.
+  if (failed(changed))
+    return;
+
+  // Guard against potentially null inputs. We always map null to null.
+  if (!element) {
+    newElements.push_back(nullptr);
+    return;
+  }
 
-  // Update to the mapped element.
-  if (*mappedElement != element) {
-    newElements.back() = *mappedElement;
-    changed = true;
+  // Replace the element.
+  if (T result = replacer.replace(element)) {
+    newElements.push_back(result);
+    if (result != element)
+      changed = true;
+  } else {
+    changed = failure();
   }
 }
 
-template <typename InterfaceT>
-static typename InterfaceT::ValueType
-replaceSubElementsImpl(InterfaceT interface,
-                       SubElementResultReplFn<Attribute> walkAttrsFn,
-                       SubElementResultReplFn<Type> walkTypesFn,
-                       DenseMap<Attribute, Attribute> &visitedAttrs,
-                       DenseMap<Type, Type> &visitedTypes) {
+template <typename InterfaceT, typename T>
+T AttrTypeReplacer::replaceSubElements(InterfaceT interface,
+                                       DenseMap<T, T> &interfaceMap) {
   // Walk the current sub-elements, replacing them as necessary.
   SmallVector<Attribute, 16> newAttrs;
   SmallVector<Type, 16> newTypes;
   FailureOr<bool> changed = false;
-  auto replaceSubElementFn = [&](auto subInterface) {
-    return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn,
-                                  visitedAttrs, visitedTypes);
-  };
   interface.walkImmediateSubElements(
       [&](Attribute element) {
-        updateSubElementImpl<SubElementAttrInterface>(
-            element, walkAttrsFn, visitedAttrs, newAttrs, changed,
-            replaceSubElementFn);
+        updateSubElementImpl(element, *this, attrMap, newAttrs, changed);
       },
       [&](Type element) {
-        updateSubElementImpl<SubElementTypeInterface>(
-            element, walkTypesFn, visitedTypes, newTypes, changed,
-            replaceSubElementFn);
+        updateSubElementImpl(element, *this, typeMap, newTypes, changed);
       });
   if (failed(changed))
-    return {};
+    return nullptr;
 
-  // If the sub-elements didn't change, just return the original value.
-  if (!*changed)
-    return interface;
+  // If any sub-elements changed, use the new elements during the replacement.
+  T result = interface;
+  if (*changed)
+    result = interface.replaceImmediateSubElements(newAttrs, newTypes);
+  return result;
+}
+
+/// Shared implementation of replacing a given attribute or type element.
+template <typename InterfaceT, typename ReplaceFns, typename T>
+T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns,
+                                DenseMap<T, T> &map) {
+  auto [it, inserted] = map.try_emplace(element, element);
+  if (!inserted)
+    return it->second;
+
+  T result = element;
+  WalkResult walkResult = WalkResult::advance();
+  for (auto &replaceFn : llvm::reverse(replaceFns)) {
+    if (Optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
+      std::tie(result, walkResult) = *newRes;
+      break;
+    }
+  }
+
+  // If an error occurred, return nullptr to indicate failure.
+  if (walkResult.wasInterrupted() || !result)
+    return map[element] = nullptr;
+
+  // Handle replacing sub-elements if this element is also a container.
+  if (!walkResult.wasSkipped()) {
+    if (auto interface = dyn_cast<InterfaceT>(result)) {
+      // Replace the sub elements of this element, bailing if we fail.
+      if (!(result = replaceSubElements(interface, map)))
+        return map[element] = nullptr;
+    }
+  }
 
-  // Use the new elements during the replacement.
-  return interface.replaceImmediateSubElements(newAttrs, newTypes);
+  return map[element] = result;
 }
 
-Attribute SubElementAttrInterface::replaceSubElements(
-    SubElementResultReplFn<Attribute> replaceAttrFn,
-    SubElementResultReplFn<Type> replaceTypeFn) {
-  assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
-  DenseMap<Attribute, Attribute> visitedAttrs;
-  DenseMap<Type, Type> visitedTypes;
-  return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
-                                visitedAttrs, visitedTypes);
+Attribute AttrTypeReplacer::replace(Attribute attr) {
+  return replaceImpl<SubElementAttrInterface>(attr, attrReplacementFns,
+                                              attrMap);
 }
 
-Type SubElementTypeInterface::replaceSubElements(
-    SubElementResultReplFn<Attribute> replaceAttrFn,
-    SubElementResultReplFn<Type> replaceTypeFn) {
-  assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
-  DenseMap<Attribute, Attribute> visitedAttrs;
-  DenseMap<Type, Type> visitedTypes;
-  return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
-                                visitedAttrs, visitedTypes);
+Type AttrTypeReplacer::replace(Type type) {
+  return replaceImpl<SubElementTypeInterface>(type, typeReplacementFns,
+                                              typeMap);
 }
 
 //===----------------------------------------------------------------------===//
index 1acaea7..2874ddc 100644 (file)
@@ -853,40 +853,31 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
     SymbolRefAttr oldAttr = scope.symbol;
     SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
-
-    auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
-      auto remapAttrFn =
-          [&](Attribute attr) -> std::pair<Attribute, WalkResult> {
-        // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
-        // want to accidentally replace an inner reference.
-        if (attr == oldAttr)
-          return {newAttr, WalkResult::skip()};
-        // Handle prefix matches.
-        if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
-          if (isReferencePrefixOf(oldAttr, symRef)) {
+    AttrTypeReplacer replacer;
+    replacer.addReplacement(
+        [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
+          // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
+          // want to accidentally replace an inner reference.
+          if (attr == oldAttr)
+            return {newAttr, WalkResult::skip()};
+          // Handle prefix matches.
+          if (isReferencePrefixOf(oldAttr, attr)) {
             auto oldNestedRefs = oldAttr.getNestedReferences();
-            auto nestedRefs = symRef.getNestedReferences();
+            auto nestedRefs = attr.getNestedReferences();
             if (oldNestedRefs.empty())
               return {SymbolRefAttr::get(newSymbol, nestedRefs),
                       WalkResult::skip()};
 
             auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
             newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
-            return {
-                SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs),
-                WalkResult::skip()};
+            return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
+                    WalkResult::skip()};
           }
           return {attr, WalkResult::skip()};
-        }
-        return {attr, WalkResult::advance()};
-      };
-      // Generate a new attribute dictionary by replacing references to the old
-      // symbol.
-      auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn);
-      if (!newDict)
-        return WalkResult::interrupt();
-
-      op->setAttrs(newDict.template cast<DictionaryAttr>());
+        });
+
+    auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
+      replacer.replaceElementsIn(op);
       return WalkResult::advance();
     };
     if (!scope.walkSymbolTable(walkFn))