/// Walk the operations in this block in postorder, calling the callback for
/// each operation.
/// See Operation::walk for more details.
- template <typename FnT> void walk(FnT &&callback) {
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ RetT walk(FnT &&callback) {
return walk(begin(), end(), std::forward<FnT>(callback));
}
/// postorder, calling the callback for each operation. This method is invoked
/// for void return callbacks.
/// See Operation::walk for more details.
- template <typename FnT>
- void walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
+ walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- detail::walkOperations(&op, std::forward<FnT>(callback));
+ detail::walkOperations(&op, callback);
+ }
+
+ /// Walk the operations in the specified [begin, end) range of this block in
+ /// postorder, calling the callback for each operation. This method is invoked
+ /// for interruptible callbacks.
+ /// See Operation::walk for more details.
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
+ walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+ for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
+ if (detail::walkOperations(&op, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
/// Walk the operation in postorder, calling the callback for each nested
/// operation(including this one).
/// See Operation::walk for more details.
- template <typename FnT> void walk(FnT &&callback) {
- state->walk(std::forward<FnT>(callback));
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ RetT walk(FnT &&callback) {
+ return state->walk(std::forward<FnT>(callback));
}
// These are default implementations of customization hooks.
/// Walk the operation in postorder, calling the callback for each nested
/// operation(including this one). The callback method can take any of the
/// following forms:
- /// (void)(Operation*) : Walk all operations opaquely.
+ /// void(Operation*) : Walk all operations opaquely.
/// * op->walk([](Operation *nestedOp) { ...});
- /// (void)(OpT) : Walk all operations of the given derived type.
+ /// void(OpT) : Walk all operations of the given derived type.
/// * op->walk([](ReturnOp returnOp) { ...});
- template <typename FnT> void walk(FnT &&callback) {
- detail::walkOperations(this, std::forward<FnT>(callback));
+ /// WalkResult(Operation*|OpT) : Walk operations, but allow for
+ /// interruption/cancellation.
+ /// * op->walk([](... op) {
+ /// // Interrupt, i.e cancel, the walk based on some invariant.
+ /// if (some_invariant)
+ /// return WalkResult::interrupt();
+ /// return WalkResult::advance();
+ /// });
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ RetT walk(FnT &&callback) {
+ return detail::walkOperations(this, std::forward<FnT>(callback));
}
//===--------------------------------------------------------------------===//
void dropAllReferences();
/// Walk the operations in this region in postorder, calling the callback for
- /// each operation.
+ /// each operation. This method is invoked for void-returning callbacks.
/// See Operation::walk for more details.
- template <typename FnT> void walk(FnT &&callback) {
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
+ walk(FnT &&callback) {
for (auto &block : *this)
block.walk(callback);
}
+ /// Walk the operations in this region in postorder, calling the callback for
+ /// each operation. This method is invoked for interruptible callbacks.
+ /// See Operation::walk for more details.
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
+ walk(FnT &&callback) {
+ for (auto &block : *this)
+ if (block.walk(callback).wasInterrupted())
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ }
+
/// Displays the CFG in a window. This is for use from the debugger and
/// depends on Graphviz to generate the graph.
/// This function is defined in ViewRegionGraph and only works with that
#define MLIR_IR_VISITORS_H
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
+class Diagnostic;
+class InFlightDiagnostic;
class Operation;
+/// A utility result that is used to signal if a walk method should be
+/// interrupted or advance.
+class WalkResult {
+ enum ResultEnum { Interrupt, Advance } result;
+
+public:
+ WalkResult(ResultEnum result) : result(result) {}
+
+ /// Allow LogicalResult to interrupt the walk on failure.
+ WalkResult(LogicalResult result)
+ : result(failed(result) ? Interrupt : Advance) {}
+
+ /// Allow diagnostics to interrupt the walk.
+ WalkResult(Diagnostic &&) : result(Interrupt) {}
+ WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
+
+ bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
+
+ static WalkResult interrupt() { return {Interrupt}; }
+ static WalkResult advance() { return {Advance}; }
+
+ /// Returns if the walk was interrupted.
+ bool wasInterrupted() const { return result == Interrupt; }
+};
+
namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
/// Walk all of the operations nested under and including the given operation.
void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
+/// Walk all of the operations nested under and including the given operation.
+/// This methods walks operations until an interrupt result is returned by the
+/// callback.
+WalkResult walkOperations(Operation *op,
+ function_ref<WalkResult(Operation *op)> callback);
+
// Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
-// methods. They are also templated to allow for dynamically dispatching based
+// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
/// Walk all of the operations nested under and including the given operation.
/// Walk all of the operations of type 'ArgT' nested under and including the
/// given operation. This method is selected for void returning callbacks that
-/// operation on a specific derived operation type.
+/// operate on a specific derived operation type.
///
/// Example:
/// op->walk([](ReturnOp op) { ... });
};
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for WalkReturn returning
+/// interruptible callbacks that operate on a specific derived operation type.
+///
+/// Example:
+/// op->walk([](ReturnOp op) {
+/// if (some_invariant)
+/// return WalkResult::interrupt();
+/// return WalkResult::advance();
+/// });
+template <
+ typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+ std::is_same<RetT, WalkResult>::value,
+ RetT>::type
+walkOperations(Operation *op, FuncTy &&callback) {
+ auto wrapperFn = [&](Operation *op) {
+ if (auto derivedOp = dyn_cast<ArgT>(op))
+ return callback(derivedOp);
+ return WalkResult::advance();
+ };
+ return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
+}
+
+/// Utility to provide the return type of a templated walk method.
+template <typename FnT>
+using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>()));
} // end namespace detail
} // namespace mlir
SmallDenseMap<Value *, std::unique_ptr<MemRefRegion>, 4> regions;
// Walk this 'affine.for' operation to gather all memory regions.
- bool error = false;
- block.walk(start, end, [&](Operation *opInst) {
+ auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
// Neither load nor a store op.
- return;
+ return WalkResult::advance();
}
// Compute the memref region symbolic in any IVs enclosing this block.
if (failed(
region->compute(opInst,
/*loopDepth=*/getNestingDepth(*block.begin())))) {
- opInst->emitError("Error obtaining memory region\n");
- error = true;
- return;
+ return opInst->emitError("Error obtaining memory region\n");
}
+
auto it = regions.find(region->memref);
if (it == regions.end()) {
regions[region->memref] = std::move(region);
} else if (failed(it->second->unionBoundingBox(*region))) {
- opInst->emitWarning(
+ return opInst->emitWarning(
"getMemoryFootprintBytes: unable to perform a union on a memory "
"region");
- error = true;
- return;
}
+ return WalkResult::advance();
});
-
- if (error)
+ if (result.wasInterrupted())
return None;
int64_t totalSizeInBytes = 0;
bool mlir::isLoopParallel(AffineForOp forOp) {
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Operation *, 8> loadAndStoreOpInsts;
- bool hasSideEffectingOps = false;
- forOp.getOperation()->walk([&](Operation *opInst) {
+ auto walkResult = forOp.walk([&](Operation *opInst) {
if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
- return loadAndStoreOpInsts.push_back(opInst);
- if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
- !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect()) {
- hasSideEffectingOps = true;
- }
+ loadAndStoreOpInsts.push_back(opInst);
+ else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
+ !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect())
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
});
+
// Stop early if the loop has unknown ops with side effects.
- if (hasSideEffectingOps)
+ if (walkResult.wasInterrupted())
return false;
// Dep check depth would be number of enclosing loops + 1.
callback(op);
}
+
+/// Walk all of the operations nested under and including the given operations.
+/// This methods walks operations until an interrupt signal is received.
+WalkResult
+detail::walkOperations(Operation *op,
+ function_ref<WalkResult(Operation *op)> callback) {
+ // TODO(b/140235992) This walk should be iterative over the operations.
+ for (auto ®ion : op->getRegions()) {
+ for (auto &block : region) {
+ // Early increment here in the case where the operation is erased.
+ for (auto &nestedOp : llvm::make_early_inc_range(block))
+ if (walkOperations(&nestedOp, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ }
+ return callback(op);
+}
it != Block::reverse_iterator(opA); ++it) {
Operation *opX = &(*it);
opX->walk([&](Operation *op) {
- if (lastDepOp)
- return;
if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
- if (isDependentLoadOrStoreOp(op, values))
+ if (isDependentLoadOrStoreOp(op, values)) {
lastDepOp = opX;
- return;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
}
for (auto *value : op->getResults()) {
for (auto *user : value->getUsers()) {
getLoopIVs(*user, &loops);
if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
lastDepOp = opX;
+ return WalkResult::interrupt();
}
}
}
+ return WalkResult::advance();
});
if (lastDepOp)
break;
/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
/// returns false otherwise.
bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
- bool ret = true;
- forOpRoot.walk([&](AffineForOp forOp) {
+ auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
auto *childForOp = forOp.getOperation();
auto *parentForOp = forOp.getOperation()->getParentOp();
if (!llvm::isa<FuncOp>(parentForOp)) {
if (!isa<AffineForOp>(parentForOp)) {
LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
- ret = false;
- return;
+ return WalkResult::interrupt();
}
// Add mapping to 'forOp' from its parent AffineForOp.
stats->loopMap[parentForOp].push_back(forOp);
++count;
}
stats->opCountMap[childForOp] = count;
+
// Record trip count for 'forOp'. Set flag if trip count is not
// constant.
Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount.hasValue()) {
// Currently only constant trip count loop nests are supported.
LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
- ret = false;
- return;
+ return WalkResult::interrupt();
}
+
stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
+ return WalkResult::advance();
});
- return ret;
+ return !walkResult.wasInterrupted();
}
// Computes the total cost of the loop nest rooted at 'forOp'.