NFC: Cleanup the implementation of walkSymbolUses.
authorRiver Riddle <riverriddle@google.com>
Sat, 19 Oct 2019 04:28:47 +0000 (21:28 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 19 Oct 2019 04:29:15 +0000 (21:29 -0700)
Refactor the implementation to be much cleaner by adding a `make_second_range` utility to walk the `second` value of a range of pairs.

PiperOrigin-RevId: 275598985

mlir/include/mlir/Support/STLExtras.h
mlir/lib/IR/SymbolTable.cpp

index 6ee2754..2ab890e 100644 (file)
@@ -24,8 +24,7 @@
 #define MLIR_SUPPORT_STLEXTRAS_H
 
 #include "mlir/Support/LLVM.h"
-#include "llvm/ADT/iterator.h"
-#include <tuple>
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 
@@ -185,6 +184,15 @@ protected:
   ptrdiff_t index;
 };
 
+/// Given a container of pairs, return a range over the second elements.
+template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
+  return llvm::map_range(
+      std::forward<ContainerTy>(c),
+      [](decltype((*std::begin(c))) elt) -> decltype((elt.second)) {
+        return elt.second;
+      });
+}
+
 } // end namespace mlir
 
 // Allow tuples to be usable as DenseMap keys.
index 181acd5..144be24 100644 (file)
@@ -163,44 +163,6 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//
 
-/// A utility result for walking a nested attribute for symbol uses.
-enum HandlerResult {
-  /// The walk of the containter can continue.
-  Continue = 0,
-  /// The walk should recurse into the given attribute, as it is a container.
-  RecurseNestedAttribute,
-  /// The walk should end immediately, as an interrupt has been signaled.
-  Interrupt
-};
-
-/// Utility function used to handle a nested attribute during a walk of symbol
-/// uses. It returns the above HandlerResult signaling the next action for the
-/// walk.
-HandlerResult handleAttrDuringSymbolWalk(
-    Operation *op, Attribute attr,
-    SmallVectorImpl<std::pair<Attribute, unsigned>> &worklist,
-    function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
-  switch (attr.getKind()) {
-  /// Check for a nested container attribute, these will also need to be
-  /// walked.
-  case StandardAttributes::Array:
-  case StandardAttributes::Dictionary: {
-    worklist.push_back({attr, /*index*/ 0});
-    return HandlerResult::RecurseNestedAttribute;
-  }
-
-  // Invoke the provided callback if we find a symbol use and check for a
-  // requested interrupt.
-  case StandardAttributes::SymbolRef: {
-    SymbolTable::SymbolUse use(op, attr.cast<SymbolRefAttr>());
-    return callback(use).wasInterrupted() ? HandlerResult::Interrupt
-                                          : HandlerResult::Continue;
-  }
-  default:
-    return HandlerResult::Continue;
-  }
-}
-
 /// Walk all of the symbol references within the given operation, invoking the
 /// provided callback for each found use.
 static WalkResult
@@ -215,37 +177,44 @@ walkSymbolRefs(Operation *op,
   // attribute list.
   SmallVector<std::pair<Attribute, unsigned>, 1> worklist;
   worklist.push_back({attrDict, /*index*/ 0});
-  while (!worklist.empty()) {
-    Attribute attr = worklist.back().first;
-    unsigned &index = worklist.back().second;
 
-    // Iterate over the given attribute, which is guaranteed to be a container.
-    HandlerResult handlerResult = HandlerResult::Continue;
-    if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
-      ArrayRef<Attribute> attrs = arrayAttr.getValue();
-      unsigned attrSize = attrs.size();
-      while (index != attrSize)
-        if ((handlerResult = handleAttrDuringSymbolWalk(op, attrs[index++],
-                                                        worklist, callback)))
-          break;
-    } else {
-      auto dictAttr = attr.cast<DictionaryAttr>();
-      ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
-      unsigned attrSize = attrs.size();
-      while (index != attrSize)
-        if ((handlerResult = handleAttrDuringSymbolWalk(
-                 op, attrs[index++].second, worklist, callback)))
-          break;
+  // Process the symbol references within the given nested attribute range.
+  auto processAttrs = [&](unsigned &index, auto attrRange) -> WalkResult {
+    for (Attribute attr : llvm::drop_begin(attrRange, index)) {
+      // Make sure to keep the index counter in sync.
+      ++index;
+
+      /// Check for a nested container attribute, these will also need to be
+      /// walked.
+      if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) {
+        worklist.push_back({attr, /*index*/ 0});
+        return WalkResult::advance();
+      }
+
+      // Invoke the provided callback if we find a symbol use and check for a
+      // requested interrupt.
+      if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>())
+        if (callback(SymbolTable::SymbolUse(op, symbolRef)).wasInterrupted())
+          return WalkResult::interrupt();
     }
-    if (handlerResult == HandlerResult::Interrupt)
-      return WalkResult::interrupt();
 
-    // If we didn't encounter a nested attribute, pop the last item from the
-    // worklist.
-    if (handlerResult != HandlerResult::RecurseNestedAttribute)
-      worklist.pop_back();
-  }
-  return WalkResult::advance();
+    // Pop this container attribute from the worklist.
+    worklist.pop_back();
+    return WalkResult::advance();
+  };
+
+  WalkResult result = WalkResult::advance();
+  do {
+    Attribute attr = worklist.back().first;
+    unsigned &index = worklist.back().second;
+
+    // Process the given attribute, which is guaranteed to be a container.
+    if (auto dict = attr.dyn_cast<DictionaryAttr>())
+      result = processAttrs(index, make_second_range(dict.getValue()));
+    else
+      result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
+  } while (!worklist.empty() && !result.wasInterrupted());
+  return result;
 }
 
 /// Walk all of the uses, for any symbol, that are nested within the given