Add support for early exit walk methods.
authorRiver Riddle <riverriddle@google.com>
Fri, 30 Aug 2019 19:47:24 +0000 (12:47 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 30 Aug 2019 19:47:53 +0000 (12:47 -0700)
This is done by providing a walk callback that returns a WalkResult. This result is either `advance` or `interrupt`. `advance` means that the walk should continue, whereas `interrupt` signals that the walk should stop immediately. An example is shown below:

auto result = op->walk([](Operation *op) {
  if (some_invariant)
    return WalkResult::interrupt();
  return WalkResult::advance();
});

if (result.wasInterrupted())
  ...;

PiperOrigin-RevId: 266436700

mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/IR/Visitors.cpp
mlir/lib/Transforms/Utils/LoopFusionUtils.cpp

index 31b5120..decf4cc 100644 (file)
@@ -291,7 +291,8 @@ public:
   /// Walk the operations in this block in postorder, calling the callback for
   /// each operation.
   /// See Operation::walk for more details.
-  template <typename FnT> void walk(FnT &&callback) {
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  RetT walk(FnT &&callback) {
     return walk(begin(), end(), std::forward<FnT>(callback));
   }
 
@@ -299,10 +300,24 @@ public:
   /// postorder, calling the callback for each operation. This method is invoked
   /// for void return callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT>
-  void walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
+  walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      detail::walkOperations(&op, std::forward<FnT>(callback));
+      detail::walkOperations(&op, callback);
+  }
+
+  /// Walk the operations in the specified [begin, end) range of this block in
+  /// postorder, calling the callback for each operation. This method is invoked
+  /// for interruptible callbacks.
+  /// See Operation::walk for more details.
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
+  walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
+      if (detail::walkOperations(&op, callback).wasInterrupted())
+        return WalkResult::interrupt();
+    return WalkResult::advance();
   }
 
   //===--------------------------------------------------------------------===//
index 18c4615..883e994 100644 (file)
@@ -190,8 +190,9 @@ public:
   /// Walk the operation in postorder, calling the callback for each nested
   /// operation(including this one).
   /// See Operation::walk for more details.
-  template <typename FnT> void walk(FnT &&callback) {
-    state->walk(std::forward<FnT>(callback));
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  RetT walk(FnT &&callback) {
+    return state->walk(std::forward<FnT>(callback));
   }
 
   // These are default implementations of customization hooks.
index c3f4c0e..97225d0 100644 (file)
@@ -494,12 +494,21 @@ public:
   /// Walk the operation in postorder, calling the callback for each nested
   /// operation(including this one). The callback method can take any of the
   /// following forms:
-  ///   (void)(Operation*) : Walk all operations opaquely.
+  ///   void(Operation*) : Walk all operations opaquely.
   ///     * op->walk([](Operation *nestedOp) { ...});
-  ///   (void)(OpT) : Walk all operations of the given derived type.
+  ///   void(OpT) : Walk all operations of the given derived type.
   ///     * op->walk([](ReturnOp returnOp) { ...});
-  template <typename FnT> void walk(FnT &&callback) {
-    detail::walkOperations(this, std::forward<FnT>(callback));
+  ///   WalkResult(Operation*|OpT) : Walk operations, but allow for
+  ///                                interruption/cancellation.
+  ///     * op->walk([](... op) {
+  ///         // Interrupt, i.e cancel, the walk based on some invariant.
+  ///         if (some_invariant)
+  ///           return WalkResult::interrupt();
+  ///         return WalkResult::advance();
+  ///       });
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  RetT walk(FnT &&callback) {
+    return detail::walkOperations(this, std::forward<FnT>(callback));
   }
 
   //===--------------------------------------------------------------------===//
index 8b7db9a..96aeafb 100644 (file)
@@ -125,13 +125,27 @@ public:
   void dropAllReferences();
 
   /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation.
+  /// each operation. This method is invoked for void-returning callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT> void walk(FnT &&callback) {
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
+  walk(FnT &&callback) {
     for (auto &block : *this)
       block.walk(callback);
   }
 
+  /// Walk the operations in this region in postorder, calling the callback for
+  /// each operation. This method is invoked for interruptible callbacks.
+  /// See Operation::walk for more details.
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
+  walk(FnT &&callback) {
+    for (auto &block : *this)
+      if (block.walk(callback).wasInterrupted())
+        return WalkResult::interrupt();
+    return WalkResult::advance();
+  }
+
   /// Displays the CFG in a window. This is for use from the debugger and
   /// depends on Graphviz to generate the graph.
   /// This function is defined in ViewRegionGraph and only works with that
index f1876df..395a4e7 100644 (file)
 #define MLIR_IR_VISITORS_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
+class Diagnostic;
+class InFlightDiagnostic;
 class Operation;
 
+/// A utility result that is used to signal if a walk method should be
+/// interrupted or advance.
+class WalkResult {
+  enum ResultEnum { Interrupt, Advance } result;
+
+public:
+  WalkResult(ResultEnum result) : result(result) {}
+
+  /// Allow LogicalResult to interrupt the walk on failure.
+  WalkResult(LogicalResult result)
+      : result(failed(result) ? Interrupt : Advance) {}
+
+  /// Allow diagnostics to interrupt the walk.
+  WalkResult(Diagnostic &&) : result(Interrupt) {}
+  WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
+
+  bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
+
+  static WalkResult interrupt() { return {Interrupt}; }
+  static WalkResult advance() { return {Advance}; }
+
+  /// Returns if the walk was interrupted.
+  bool wasInterrupted() const { return result == Interrupt; }
+};
+
 namespace detail {
 /// Helper templates to deduce the first argument of a callback parameter.
 template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
@@ -45,9 +73,15 @@ using first_argument = decltype(first_argument_type(std::declval<T>()));
 /// Walk all of the operations nested under and including the given operation.
 void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
 
+/// Walk all of the operations nested under and including the given operation.
+/// This methods walks operations until an interrupt result is returned by the
+/// callback.
+WalkResult walkOperations(Operation *op,
+                          function_ref<WalkResult(Operation *op)> callback);
+
 // Below are a set of functions to walk nested operations. Users should favor
 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
-// methods. They are also templated to allow for dynamically dispatching based
+// methods. They are also templated to allow for statically dispatching based
 // upon the type of the callback function.
 
 /// Walk all of the operations nested under and including the given operation.
@@ -65,7 +99,7 @@ walkOperations(Operation *op, FuncTy &&callback) {
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
 /// given operation. This method is selected for void returning callbacks that
-/// operation on a specific derived operation type.
+/// operate on a specific derived operation type.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) { ... });
@@ -82,6 +116,35 @@ walkOperations(Operation *op, FuncTy &&callback) {
   };
   return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
 }
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for WalkReturn returning
+/// interruptible callbacks that operate on a specific derived operation type.
+///
+/// Example:
+///   op->walk([](ReturnOp op) {
+///     if (some_invariant)
+///       return WalkResult::interrupt();
+///     return WalkResult::advance();
+///   });
+template <
+    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+                            std::is_same<RetT, WalkResult>::value,
+                        RetT>::type
+walkOperations(Operation *op, FuncTy &&callback) {
+  auto wrapperFn = [&](Operation *op) {
+    if (auto derivedOp = dyn_cast<ArgT>(op))
+      return callback(derivedOp);
+    return WalkResult::advance();
+  };
+  return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
+}
+
+/// Utility to provide the return type of a templated walk method.
+template <typename FnT>
+using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>()));
 } // end namespace detail
 
 } // namespace mlir
index aaefd98..660b77e 100644 (file)
@@ -905,11 +905,10 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
   SmallDenseMap<Value *, std::unique_ptr<MemRefRegion>, 4> regions;
 
   // Walk this 'affine.for' operation to gather all memory regions.
-  bool error = false;
-  block.walk(start, end, [&](Operation *opInst) {
+  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
     if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
       // Neither load nor a store op.
-      return;
+      return WalkResult::advance();
     }
 
     // Compute the memref region symbolic in any IVs enclosing this block.
@@ -917,23 +916,20 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
     if (failed(
             region->compute(opInst,
                             /*loopDepth=*/getNestingDepth(*block.begin())))) {
-      opInst->emitError("Error obtaining memory region\n");
-      error = true;
-      return;
+      return opInst->emitError("Error obtaining memory region\n");
     }
+
     auto it = regions.find(region->memref);
     if (it == regions.end()) {
       regions[region->memref] = std::move(region);
     } else if (failed(it->second->unionBoundingBox(*region))) {
-      opInst->emitWarning(
+      return opInst->emitWarning(
           "getMemoryFootprintBytes: unable to perform a union on a memory "
           "region");
-      error = true;
-      return;
     }
+    return WalkResult::advance();
   });
-
-  if (error)
+  if (result.wasInterrupted())
     return None;
 
   int64_t totalSizeInBytes = 0;
@@ -969,17 +965,18 @@ void mlir::getSequentialLoops(
 bool mlir::isLoopParallel(AffineForOp forOp) {
   // Collect all load and store ops in loop nest rooted at 'forOp'.
   SmallVector<Operation *, 8> loadAndStoreOpInsts;
-  bool hasSideEffectingOps = false;
-  forOp.getOperation()->walk([&](Operation *opInst) {
+  auto walkResult = forOp.walk([&](Operation *opInst) {
     if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
-      return loadAndStoreOpInsts.push_back(opInst);
-    if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
-        !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect()) {
-      hasSideEffectingOps = true;
-    }
+      loadAndStoreOpInsts.push_back(opInst);
+    else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
+             !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect())
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
   });
+
   // Stop early if the loop has unknown ops with side effects.
-  if (hasSideEffectingOps)
+  if (walkResult.wasInterrupted())
     return false;
 
   // Dep check depth would be number of enclosing loops + 1.
index 4622098..ea2a6d6 100644 (file)
@@ -32,3 +32,20 @@ void detail::walkOperations(Operation *op,
 
   callback(op);
 }
+
+/// Walk all of the operations nested under and including the given operations.
+/// This methods walks operations until an interrupt signal is received.
+WalkResult
+detail::walkOperations(Operation *op,
+                       function_ref<WalkResult(Operation *op)> callback) {
+  // TODO(b/140235992) This walk should be iterative over the operations.
+  for (auto &region : op->getRegions()) {
+    for (auto &block : region) {
+      // Early increment here in the case where the operation is erased.
+      for (auto &nestedOp : llvm::make_early_inc_range(block))
+        if (walkOperations(&nestedOp, callback).wasInterrupted())
+          return WalkResult::interrupt();
+    }
+  }
+  return callback(op);
+}
index 99f315e..8f96cc2 100644 (file)
@@ -114,12 +114,12 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
        it != Block::reverse_iterator(opA); ++it) {
     Operation *opX = &(*it);
     opX->walk([&](Operation *op) {
-      if (lastDepOp)
-        return;
       if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
-        if (isDependentLoadOrStoreOp(op, values))
+        if (isDependentLoadOrStoreOp(op, values)) {
           lastDepOp = opX;
-        return;
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
       }
       for (auto *value : op->getResults()) {
         for (auto *user : value->getUsers()) {
@@ -128,9 +128,11 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
           getLoopIVs(*user, &loops);
           if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
             lastDepOp = opX;
+            return WalkResult::interrupt();
           }
         }
       }
+      return WalkResult::advance();
     });
     if (lastDepOp)
       break;
@@ -257,15 +259,13 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
 /// returns false otherwise.
 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
-  bool ret = true;
-  forOpRoot.walk([&](AffineForOp forOp) {
+  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
     auto *childForOp = forOp.getOperation();
     auto *parentForOp = forOp.getOperation()->getParentOp();
     if (!llvm::isa<FuncOp>(parentForOp)) {
       if (!isa<AffineForOp>(parentForOp)) {
         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
-        ret = false;
-        return;
+        return WalkResult::interrupt();
       }
       // Add mapping to 'forOp' from its parent AffineForOp.
       stats->loopMap[parentForOp].push_back(forOp);
@@ -279,18 +279,20 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
         ++count;
     }
     stats->opCountMap[childForOp] = count;
+
     // Record trip count for 'forOp'. Set flag if trip count is not
     // constant.
     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
     if (!maybeConstTripCount.hasValue()) {
       // Currently only constant trip count loop nests are supported.
       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
-      ret = false;
-      return;
+      return WalkResult::interrupt();
     }
+
     stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
+    return WalkResult::advance();
   });
-  return ret;
+  return !walkResult.wasInterrupted();
 }
 
 // Computes the total cost of the loop nest rooted at 'forOp'.