[mlir][IR] Add Iterator template option to IR walkers
authorMatthias Springer <me@m-sp.org>
Fri, 24 Feb 2023 09:13:36 +0000 (10:13 +0100)
committerMatthias Springer <me@m-sp.org>
Fri, 24 Feb 2023 09:24:47 +0000 (10:24 +0100)
This allows users to specify a  top-down or bottom-up traversal of the IR, in addition to the already existing WalkOrder.

Certain transformations work better with a forward traversal. E.g., when cloning a piece of IR, operations should be cloned top-down so that all uses are defined when creating an op.

Certain transformations work better with a reverse traversal. E.g., when erasing a piece of IR, operations should be erased bottom-up to avoid erasing operations that still have users.

Differential Revision: https://reviews.llvm.org/D144257

mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/IR/Visitors.cpp

index 6c753cb..aadcca0 100644 (file)
@@ -258,42 +258,47 @@ public:
 
   /// Walk the operations in this block. The callback method is called for each
   /// nested region, block or operation, depending on the callback provided.
-  /// Regions, blocks and operations at the same nesting level are visited in
-  /// lexicographical order. The walk order for enclosing regions, blocks and
-  /// operations with respect to their nested ones is specified by 'Order'
+  /// The order in which regions, blocks and operations at the same nesting
+  /// level are visited (e.g., lexicographical or reverse lexicographical order)
+  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
+  /// and operations with respect to their nested ones is specified by 'Order'
   /// (post-order by default). A callback on a block or operation is allowed to
   /// erase that block or operation if either:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return walk<Order>(begin(), end(), std::forward<FnT>(callback));
+    return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
   }
 
   /// Walk the operations in the specified [begin, end) range of this block. The
   /// callback method is called for each nested region, block or operation,
-  /// depending on the callback provided. Regions, blocks and operations at the
-  /// same nesting level are visited in lexicographical order. The walk order
+  /// depending on the callback provided. The order in which regions, blocks and
+  /// operations at the same nesting level are visited (e.g., lexicographical or
+  /// reverse lexicographical order) is determined by 'Iterator'. The walk order
   /// for enclosing regions, blocks and operations with respect to their nested
   /// ones is specified by 'Order' (post-order by default). This method is
   /// invoked for void-returning callbacks. A callback on a block or operation
   /// is allowed to erase that block or operation only if the walk is in
   /// post-order. See non-void method for pre-order erasure.
   /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   std::enable_if_t<std::is_same<RetT, void>::value, RetT>
   walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      detail::walk<Order>(&op, callback);
+      detail::walk<Order, Iterator>(&op, callback);
   }
 
   /// Walk the operations in the specified [begin, end) range of this block. The
   /// callback method is called for each nested region, block or operation,
-  /// depending on the callback provided. Regions, blocks and operations at the
-  /// same nesting level are visited in lexicographical order. The walk order
+  /// depending on the callback provided. The order in which regions, blocks and
+  /// operations at the same nesting level are visited (e.g., lexicographical or
+  /// reverse lexicographical order) is determined by 'Iterator'. The walk order
   /// for enclosing regions, blocks and operations with respect to their nested
   /// ones is specified by 'Order' (post-order by default). This method is
   /// invoked for skippable or interruptible callbacks. A callback on a block or
@@ -301,12 +306,13 @@ public:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
   walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      if (detail::walk<Order>(&op, callback).wasInterrupted())
+      if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
         return WalkResult::interrupt();
     return WalkResult::advance();
   }
index f7d8436..a86acaf 100644 (file)
@@ -132,20 +132,22 @@ public:
 
   /// Walk the operation by calling the callback for each nested operation
   /// (including this one), block or region, depending on the callback provided.
-  /// Regions, blocks and operations at the same nesting level are visited in
-  /// lexicographical order. The walk order for enclosing regions, blocks and
-  /// operations with respect to their nested ones is specified by 'Order'
+  /// The order in which regions, blocks and operations the same nesting level
+  /// are visited (e.g., lexicographical or reverse lexicographical order) is
+  /// determined by 'Iterator'. The walk order for enclosing regions, blocks
+  /// and operations with respect to their nested ones is specified by 'Order'
   /// (post-order by default). A callback on a block or operation is allowed to
   /// erase that block or operation if either:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1,
                    RetT>
   walk(FnT &&callback) {
-    return state->walk<Order>(std::forward<FnT>(callback));
+    return state->walk<Order, Iterator>(std::forward<FnT>(callback));
   }
 
   /// Generic walker with a stage aware callback. Walk the operation by calling
index ac6bdfc..250ae95 100644 (file)
@@ -607,9 +607,10 @@ public:
 
   /// Walk the operation by calling the callback for each nested operation
   /// (including this one), block or region, depending on the callback provided.
-  /// Regions, blocks and operations at the same nesting level are visited in
-  /// lexicographical order. The walk order for enclosing regions, blocks and
-  /// operations with respect to their nested ones is specified by 'Order'
+  /// The order in which regions, blocks and operations at the same nesting
+  /// level are visited (e.g., lexicographical or reverse lexicographical order)
+  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
+  /// and operations with respect to their nested ones is specified by 'Order'
   /// (post-order by default). A callback on a block or operation is allowed to
   /// erase that block or operation if either:
   ///   * the walk is in post-order, or
@@ -631,12 +632,13 @@ public:
   ///           return WalkResult::interrupt();
   ///         return WalkResult::advance();
   ///       });
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1,
                    RetT>
   walk(FnT &&callback) {
-    return detail::walk<Order>(this, std::forward<FnT>(callback));
+    return detail::walk<Order, Iterator>(this, std::forward<FnT>(callback));
   }
 
   /// Generic walker with a stage aware callback. Walk the operation by calling
index 7b2927c..4f4812d 100644 (file)
@@ -265,37 +265,41 @@ public:
 
   /// Walk the operations in this region. The callback method is called for each
   /// nested region, block or operation, depending on the callback provided.
-  /// Regions, blocks and operations at the same nesting level are visited in
-  /// lexicographical order. The walk order for enclosing regions, blocks and
-  /// operations with respect to their nested ones is specified by 'Order'
+  /// The order in which regions, blocks and operations at the same nesting
+  /// level are visited (e.g., lexicographical or reverse lexicographical order)
+  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
+  /// and operations with respect to their nested ones is specified by 'Order'
   /// (post-order by default). This method is invoked for void-returning
   /// callbacks. A callback on a block or operation is allowed to erase that
   /// block or operation only if the walk is in post-order. See non-void method
   /// for pre-order erasure. See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
     for (auto &block : *this)
-      block.walk<Order>(callback);
+      block.walk<Order, Iterator>(callback);
   }
 
   /// Walk the operations in this region. The callback method is called for each
   /// nested region, block or operation, depending on the callback provided.
-  /// Regions, blocks and operations at the same nesting level are visited in
-  /// lexicographical order. The walk order for enclosing regions, blocks and
-  /// operations with respect to their nested ones is specified by 'Order'
+  /// The order in which regions, blocks and operations at the same nesting
+  /// level are visited (e.g., lexicographical or reverse lexicographical order)
+  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
+  /// and operations with respect to their nested ones is specified by 'Order'
   /// (post-order by default). This method is invoked for skippable or
   /// interruptible callbacks. A callback on a block or operation is allowed to
   /// erase that block or operation if either:
   ///   * the walk is in post-order,
   ///   * or the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+  template <WalkOrder Order = WalkOrder::PostOrder,
+            typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
   walk(FnT &&callback) {
     for (auto &block : *this)
-      if (block.walk<Order>(callback).wasInterrupted())
+      if (block.walk<Order, Iterator>(callback).wasInterrupted())
         return WalkResult::interrupt();
     return WalkResult::advance();
   }
index 51e7933..3fc1055 100644 (file)
@@ -62,6 +62,24 @@ public:
 /// Traversal order for region, block and operation walk utilities.
 enum class WalkOrder { PreOrder, PostOrder };
 
+/// This iterator enumerates the elements in "forward" order.
+struct ForwardIterator {
+  template <typename RangeT>
+  static constexpr RangeT &makeRange(RangeT &range) {
+    return range;
+  }
+};
+
+/// This iterator enumerates elements in "reverse" order. It is a wrapper around
+/// llvm::reverse.
+struct ReverseIterator {
+  template <typename RangeT>
+  static constexpr auto makeRange(RangeT &&range) {
+    // llvm::reverse uses RangeT::rbegin and RangeT::rend.
+    return llvm::reverse(std::forward<RangeT>(range));
+  }
+};
+
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
@@ -113,31 +131,39 @@ template <typename T>
 using first_argument = decltype(first_argument_type(std::declval<T>()));
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. Regions, blocks and operations at the same nesting
-/// level are visited in lexicographical order. The walk order for enclosing
-/// regions, blocks and operations with respect to their nested ones is
-/// specified by 'order'. These methods are invoked for void-returning
+/// the given operation. The order in which regions, blocks and operations at
+/// the same nesting level are visited (e.g., lexicographical or reverse
+/// lexicographical order) is determined by 'Iterator'. The walk order for
+/// enclosing regions, blocks and operations with respect to their nested ones
+/// is specified by 'order'. These methods are invoked for void-returning
 /// callbacks. A callback on a block or operation is allowed to erase that block
 /// or operation only if the walk is in post-order. See non-void method for
 /// pre-order erasure.
+template <typename Iterator>
 void walk(Operation *op, function_ref<void(Region *)> callback,
           WalkOrder order);
+template <typename Iterator>
 void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
+template <typename Iterator>
 void walk(Operation *op, function_ref<void(Operation *)> callback,
           WalkOrder order);
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. Regions, blocks and operations at the same nesting
-/// level are visited in lexicographical order. The walk order for enclosing
-/// regions, blocks and operations with respect to their nested ones is
-/// specified by 'order'. This method is invoked for skippable or interruptible
-/// callbacks. A callback on a block or operation is allowed to erase that block
-/// or operation if either:
+/// the given operation. The order in which regions, blocks and operations at
+/// the same nesting level are visited (e.g., lexicographical or reverse
+/// lexicographical order) is determined by 'Iterator'. The walk order for
+/// enclosing regions, blocks and operations with respect to their nested ones
+/// is specified by 'order'. This method is invoked for skippable or
+/// interruptible callbacks. A callback on a block or operation is allowed to
+/// erase that block or operation if either:
 ///   * the walk is in post-order, or
 ///   * the walk is in pre-order and the walk is skipped after the erasure.
+template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
                 WalkOrder order);
+template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
                 WalkOrder order);
+template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
                 WalkOrder order);
 
@@ -147,10 +173,11 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
 // upon the type of the callback function.
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. Regions, blocks and operations at the same nesting
-/// level are visited in lexicographical order. The walk order for enclosing
-/// regions, blocks and operations with respect to their nested ones is
-/// specified by 'Order' (post-order by default). A callback on a block or
+/// the given operation. The order in which regions, blocks and operations at
+/// the same nesting level are visited (e.g., lexicographical or reverse
+/// lexicographical order) is determined by 'Iterator'. The walk order for
+/// enclosing regions, blocks and operations with respect to their nested ones
+/// is specified by 'Order' (post-order by default). A callback on a block or
 /// operation is allowed to erase that block or operation if either:
 ///   * the walk is in post-order, or
 ///   * the walk is in pre-order and the walk is skipped after the erasure.
@@ -162,20 +189,21 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
 ///   op->walk([](Block *b) { ... });
 ///   op->walk([](Operation *op) { ... });
 template <
-    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
-    typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
+    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
                  RetT>
 walk(Operation *op, FuncTy &&callback) {
-  return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);
+  return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order);
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. Regions, blocks and operations at the same nesting
-/// level are visited in lexicographical order. The walk order for enclosing
-/// regions, blocks and operations with respect to their nested ones is
-/// specified by 'order' (post-order by default). This method is selected for
+/// given operation. The order in which regions, blocks and operations at
+/// the same nesting are visited (e.g., lexicographical or reverse
+/// lexicographical order) is determined by 'Iterator'. The walk order for
+/// enclosing regions, blocks and operations with respect to their nested ones
+/// is specified by 'order' (post-order by default). This method is selected for
 /// void-returning callbacks that operate on a specific derived operation type.
 /// A callback on an operation is allowed to erase that operation only if the
 /// walk is in post-order. See non-void method for pre-order erasure.
@@ -183,8 +211,8 @@ walk(Operation *op, FuncTy &&callback) {
 /// Example:
 ///   op->walk([](ReturnOp op) { ... });
 template <
-    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
-    typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
+    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 std::enable_if_t<
     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -195,17 +223,19 @@ walk(Operation *op, FuncTy &&callback) {
     if (auto derivedOp = dyn_cast<ArgT>(op))
       callback(derivedOp);
   };
-  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
+  return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
+                                Order);
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. Regions, blocks and operations at the same nesting level
-/// are visited in lexicographical order. The walk order for enclosing regions,
-/// blocks and operations with respect to their nested ones is specified by
-/// 'Order' (post-order by default). This method is selected for WalkReturn
-/// returning skippable or interruptible callbacks that operate on a specific
-/// derived operation type. A callback on an operation is allowed to erase that
-/// operation if either:
+/// given operation. The order in which regions, blocks and operations at
+/// the same nesting are visited (e.g., lexicographical or reverse
+/// lexicographical order) is determined by 'Iterator'. The walk order for
+/// enclosing regions, blocks and operations with respect to their nested ones
+/// is specified by 'Order' (post-order by default). This method is selected for
+/// WalkReturn returning skippable or interruptible callbacks that operate on a
+/// specific derived operation type. A callback on an operation is allowed to
+/// erase that operation if either:
 ///   * the walk is in post-order, or
 ///   * the walk is in pre-order and the walk is skipped after the erasure.
 ///
@@ -218,8 +248,8 @@ walk(Operation *op, FuncTy &&callback) {
 ///     return WalkResult::advance();
 ///   });
 template <
-    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
-    typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
+    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 std::enable_if_t<
     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -231,7 +261,8 @@ walk(Operation *op, FuncTy &&callback) {
       return callback(derivedOp);
     return WalkResult::advance();
   };
-  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
+  return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
+                                Order);
 }
 
 /// Generic walkers with stage aware callbacks.
index 74d6750..a54eca0 100644 (file)
@@ -15,60 +15,91 @@ WalkStage::WalkStage(Operation *op)
     : numRegions(op->getNumRegions()), nextRegion(0) {}
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. Regions, blocks and operations at the same nesting level
-/// are visited in lexicographical order. The walk order for enclosing regions,
-/// blocks and operations with respect to their nested ones is specified by
-/// 'order'. These methods are invoked for void-returning callbacks. A callback
-/// on a block or operation is allowed to erase that block or operation only if
-/// the walk is in post-order. See non-void method for pre-order erasure.
+/// given operation. The order in which regions, blocks and operations at the
+/// same nesting level are visited (e.g., lexicographical or reverse
+/// lexicographical order) is determined by 'Iterator'. The walk order for
+/// enclosing regions,  blocks and operations with respect to their nested ones
+/// is specified by 'order'. These methods are invoked for void-returning
+/// callbacks. A callback on a block or operation is allowed to erase that block
+/// or operation only if the walk is in post-order. See non-void method for
+/// pre-order erasure.
+template <typename Iterator>
 void detail::walk(Operation *op, function_ref<void(Region *)> callback,
                   WalkOrder order) {
   // We don't use early increment for regions because they can't be erased from
   // a callback.
-  for (auto &region : op->getRegions()) {
+  MutableArrayRef<Region> regions = op->getRegions();
+  for (auto &region : Iterator::makeRange(regions)) {
     if (order == WalkOrder::PreOrder)
       callback(&region);
-    for (auto &block : region) {
-      for (auto &nestedOp : block)
-        walk(&nestedOp, callback, order);
+    for (auto &block : Iterator::makeRange(region)) {
+      for (auto &nestedOp : Iterator::makeRange(block))
+        walk<Iterator>(&nestedOp, callback, order);
     }
     if (order == WalkOrder::PostOrder)
       callback(&region);
   }
 }
+// Explicit template instantiations for all supported iterators.
+template void detail::walk<ForwardIterator>(Operation *,
+                                            function_ref<void(Region *)>,
+                                            WalkOrder);
+template void detail::walk<ReverseIterator>(Operation *,
+                                            function_ref<void(Region *)>,
+                                            WalkOrder);
 
+template <typename Iterator>
 void detail::walk(Operation *op, function_ref<void(Block *)> callback,
                   WalkOrder order) {
-  for (auto &region : op->getRegions()) {
+  MutableArrayRef<Region> regions = op->getRegions();
+  for (auto &region : Iterator::makeRange(regions)) {
     // Early increment here in the case where the block is erased.
-    for (auto &block : llvm::make_early_inc_range(region)) {
+    for (auto &block :
+         llvm::make_early_inc_range(Iterator::makeRange(region))) {
       if (order == WalkOrder::PreOrder)
         callback(&block);
-      for (auto &nestedOp : block)
-        walk(&nestedOp, callback, order);
+      for (auto &nestedOp : Iterator::makeRange(block))
+        walk<Iterator>(&nestedOp, callback, order);
       if (order == WalkOrder::PostOrder)
         callback(&block);
     }
   }
 }
+// Explicit template instantiations for all supported iterators.
+template void detail::walk<ForwardIterator>(Operation *,
+                                            function_ref<void(Block *)>,
+                                            WalkOrder);
+template void detail::walk<ReverseIterator>(Operation *,
+                                            function_ref<void(Block *)>,
+                                            WalkOrder);
 
+template <typename Iterator>
 void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
                   WalkOrder order) {
   if (order == WalkOrder::PreOrder)
     callback(op);
 
   // TODO: This walk should be iterative over the operations.
-  for (auto &region : op->getRegions()) {
-    for (auto &block : region) {
+  MutableArrayRef<Region> regions = op->getRegions();
+  for (auto &region : Iterator::makeRange(regions)) {
+    for (auto &block : Iterator::makeRange(region)) {
       // Early increment here in the case where the operation is erased.
-      for (auto &nestedOp : llvm::make_early_inc_range(block))
-        walk(&nestedOp, callback, order);
+      for (auto &nestedOp :
+           llvm::make_early_inc_range(Iterator::makeRange(block)))
+        walk<Iterator>(&nestedOp, callback, order);
     }
   }
 
   if (order == WalkOrder::PostOrder)
     callback(op);
 }
+// Explicit template instantiations for all supported iterators.
+template void detail::walk<ForwardIterator>(Operation *,
+                                            function_ref<void(Operation *)>,
+                                            WalkOrder);
+template void detail::walk<ReverseIterator>(Operation *,
+                                            function_ref<void(Operation *)>,
+                                            WalkOrder);
 
 void detail::walk(Operation *op,
                   function_ref<void(Operation *, const WalkStage &)> callback) {
@@ -99,12 +130,14 @@ void detail::walk(Operation *op,
 /// operation is allowed to erase that block or operation if either:
 ///   * the walk is in post-order, or
 ///   * the walk is in pre-order and the walk is skipped after the erasure.
+template <typename Iterator>
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Region *)> callback,
                         WalkOrder order) {
   // We don't use early increment for regions because they can't be erased from
   // a callback.
-  for (auto &region : op->getRegions()) {
+  MutableArrayRef<Region> regions = op->getRegions();
+  for (auto &region : Iterator::makeRange(regions)) {
     if (order == WalkOrder::PreOrder) {
       WalkResult result = callback(&region);
       if (result.wasSkipped())
@@ -112,9 +145,9 @@ WalkResult detail::walk(Operation *op,
       if (result.wasInterrupted())
         return WalkResult::interrupt();
     }
-    for (auto &block : region) {
-      for (auto &nestedOp : block)
-        if (walk(&nestedOp, callback, order).wasInterrupted())
+    for (auto &block : Iterator::makeRange(region)) {
+      for (auto &nestedOp : Iterator::makeRange(block))
+        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
           return WalkResult::interrupt();
     }
     if (order == WalkOrder::PostOrder) {
@@ -126,13 +159,23 @@ WalkResult detail::walk(Operation *op,
   }
   return WalkResult::advance();
 }
+// Explicit template instantiations for all supported iterators.
+template WalkResult
+detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Region *)>,
+                              WalkOrder);
+template WalkResult
+detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Region *)>,
+                              WalkOrder);
 
+template <typename Iterator>
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Block *)> callback,
                         WalkOrder order) {
-  for (auto &region : op->getRegions()) {
+  MutableArrayRef<Region> regions = op->getRegions();
+  for (auto &region : Iterator::makeRange(regions)) {
     // Early increment here in the case where the block is erased.
-    for (auto &block : llvm::make_early_inc_range(region)) {
+    for (auto &block :
+         llvm::make_early_inc_range(Iterator::makeRange(region))) {
       if (order == WalkOrder::PreOrder) {
         WalkResult result = callback(&block);
         if (result.wasSkipped())
@@ -140,8 +183,8 @@ WalkResult detail::walk(Operation *op,
         if (result.wasInterrupted())
           return WalkResult::interrupt();
       }
-      for (auto &nestedOp : block)
-        if (walk(&nestedOp, callback, order).wasInterrupted())
+      for (auto &nestedOp : Iterator::makeRange(block))
+        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
           return WalkResult::interrupt();
       if (order == WalkOrder::PostOrder) {
         if (callback(&block).wasInterrupted())
@@ -153,7 +196,15 @@ WalkResult detail::walk(Operation *op,
   }
   return WalkResult::advance();
 }
+// Explicit template instantiations for all supported iterators.
+template WalkResult
+detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Block *)>,
+                              WalkOrder);
+template WalkResult
+detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Block *)>,
+                              WalkOrder);
 
+template <typename Iterator>
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Operation *)> callback,
                         WalkOrder order) {
@@ -167,11 +218,13 @@ WalkResult detail::walk(Operation *op,
   }
 
   // TODO: This walk should be iterative over the operations.
-  for (auto &region : op->getRegions()) {
-    for (auto &block : region) {
+  MutableArrayRef<Region> regions = op->getRegions();
+  for (auto &region : Iterator::makeRange(regions)) {
+    for (auto &block : Iterator::makeRange(region)) {
       // Early increment here in the case where the operation is erased.
-      for (auto &nestedOp : llvm::make_early_inc_range(block)) {
-        if (walk(&nestedOp, callback, order).wasInterrupted())
+      for (auto &nestedOp :
+           llvm::make_early_inc_range(Iterator::makeRange(block))) {
+        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
           return WalkResult::interrupt();
       }
     }
@@ -181,6 +234,13 @@ WalkResult detail::walk(Operation *op,
     return callback(op);
   return WalkResult::advance();
 }
+// Explicit template instantiations for all supported iterators.
+template WalkResult
+detail::walk<ForwardIterator>(Operation *,
+                              function_ref<WalkResult(Operation *)>, WalkOrder);
+template WalkResult
+detail::walk<ReverseIterator>(Operation *,
+                              function_ref<WalkResult(Operation *)>, WalkOrder);
 
 WalkResult detail::walk(
     Operation *op,