[spirv] Use LLVM graph traversal utility for PrettyBlockOrderVisitor
authorLei Zhang <antiagainst@google.com>
Tue, 29 Oct 2019 14:03:26 +0000 (07:03 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 29 Oct 2019 14:04:00 +0000 (07:04 -0700)
This removes a bunch of special tailored DFS code in favor of the common
LLVM utility. Besides, we avoid recursion with system stack given that
llvm::depth_first_ext is iterator based and maintains its own stack.
PiperOrigin-RevId: 277272961

mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp

index afb9e0b..f92b9ae 100644 (file)
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -52,76 +54,35 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
   return success();
 }
 
-namespace {
-/// A pre-order depth-first vistor for processing basic blocks.
+/// A pre-order depth-first visitor function for processing basic blocks.
 ///
-/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
-/// of blocks in a function must satisfy the rule that blocks appear before all
-/// blocks they dominate." This can be achieved by a pre-order CFG traversal
-/// algorithm. To make the serialization output more logical and readable to
-/// human, we perform depth-first CFG traversal and delay the serialization of
-/// the merge block and the continue block, if exists, until after all other
-/// blocks have been processed.
+/// Visits the basic blocks starting from the given `headerBlock` in pre-order
+/// depth-first manner and calls `blockHandler` on each block. Skips handling
+/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
+/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
+/// successors.
 ///
-/// This visitor is special tailored for SPIR-V functions, spv.selection or
-/// spv.loop block serialization to satisfy SPIR-V validation rules. It should
-/// not be used as a general depth-first block visitor.
-class PrettyBlockOrderVisitor {
-public:
-  using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>;
-
-  /// Visits the basic blocks starting from the given `headerBlock`'s successors
-  /// in pre-order depth-first manner and calls `blockHandler` on each block.
-  /// Skips handling blocks in the `skipBlocks` list. If `headerBlock` is also
-  /// in `skipBlocks` list, still handles all its successors.
-  static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler,
-                             ArrayRef<Block *> skipBlocks = {}) {
-    return PrettyBlockOrderVisitor(blockHandler, skipBlocks)
-        .visitHeaderBlock(headerBlock);
-  }
-
-private:
-  PrettyBlockOrderVisitor(BlockHandlerType blockHandler,
-                          ArrayRef<Block *> skipBlocks)
-      : blockHandler(blockHandler),
-        doneBlocks(skipBlocks.begin(), skipBlocks.end()) {}
-
-  LogicalResult visitHeaderBlock(Block *header) {
-    // Skip processing the header block if requested.
-    if (!llvm::is_contained(doneBlocks, header)) {
-      if (failed(blockHandler(header)))
-        return failure();
-      doneBlocks.insert(header);
-    }
-
-    for (auto *successor : header->getSuccessors()) {
-      if (failed(visitNormalBlock(successor)))
-        return failure();
-    }
-
-    return success();
-  }
-
-  LogicalResult visitNormalBlock(Block *block) {
-    if (doneBlocks.count(block))
-      return success();
-
+/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
+/// of blocks in a function must satisfy the rule that blocks appear before
+/// all blocks they dominate." This can be achieved by a pre-order CFG
+/// traversal algorithm. To make the serialization output more logical and
+/// readable to human, we perform depth-first CFG traversal and delay the
+/// serialization of the merge block and the continue block, if exists, until
+/// after all other blocks have been processed.
+static LogicalResult visitInPrettyBlockOrder(
+    Block *headerBlock, llvm::function_ref<LogicalResult(Block *)> blockHandler,
+    bool skipHeader = false, ArrayRef<Block *> skipBlocks = {}) {
+  llvm::df_iterator_default_set<Block *, 4> doneBlocks;
+  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
+
+  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
+    if (skipHeader && block == headerBlock)
+      continue;
     if (failed(blockHandler(block)))
       return failure();
-    doneBlocks.insert(block);
-
-    for (auto *successor : block->getSuccessors()) {
-      if (failed(visitNormalBlock(successor)))
-        return failure();
-    }
-
-    return success();
   }
-
-  BlockHandlerType blockHandler;
-  SmallPtrSet<Block *, 4> doneBlocks;
-};
-} // namespace
+  return success();
+}
 
 namespace {
 
@@ -757,7 +718,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
     return op.emitError("external function is unhandled");
   }
 
-  if (failed(PrettyBlockOrderVisitor::visit(
+  if (failed(visitInPrettyBlockOrder(
           &op.front(), [&](Block *block) { return processBlock(block); })))
     return failure();
 
@@ -1519,9 +1480,9 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   // Process all blocks with a depth-first visitor starting from the header
   // block. The selection header block and merge block are skipped by this
   // visitor.
-  auto handleBlock = [&](Block *block) { return processBlock(block); };
-  if (failed(PrettyBlockOrderVisitor::visit(headerBlock, handleBlock,
-                                            {headerBlock, mergeBlock})))
+  if (failed(visitInPrettyBlockOrder(
+          headerBlock, [&](Block *block) { return processBlock(block); },
+          /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
     return failure();
 
   // There is nothing to do for the merge block in the selection, which just
@@ -1569,9 +1530,9 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   // Process all blocks with a depth-first visitor starting from the header
   // block. The loop header block, loop continue block, and loop merge block are
   // skipped by this visitor and handled later in this function.
-  auto handleBlock = [&](Block *block) { return processBlock(block); };
-  if (failed(PrettyBlockOrderVisitor::visit(
-          headerBlock, handleBlock, {headerBlock, continueBlock, mergeBlock})))
+  if (failed(visitInPrettyBlockOrder(
+          headerBlock, [&](Block *block) { return processBlock(block); },
+          /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
     return failure();
 
   // We have handled all other blocks. Now get to the loop continue block.