[spirv] Fix function entry block erase after moving to spv.selection
authorLei Zhang <antiagainst@google.com>
Mon, 7 Oct 2019 19:36:39 +0000 (12:36 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 Oct 2019 19:37:13 +0000 (12:37 -0700)
The structured selection/loop's entry block does not have arguments.
If the function's header block is also part of the structured control
flow, we cannot just simply erase it because it may contain arguments
matching the function signature and used by the cloned blocks. Instead,
turn it into a block only containing a spv.Branch op.

Also, we can directly emit instructions for the spv.selection header
block to the block containing the spv.selection op. This eliminates
unnecessary branches in the SPIR-V blob.

Added a test for nested spv.loop.

PiperOrigin-RevId: 273351424

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/loop.mlir
mlir/test/Dialect/SPIRV/Serialization/selection.mlir

index 504e70a..3fbf5ee 100644 (file)
@@ -785,8 +785,9 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
     return failure();
   }
   if (opcode == spirv::Opcode::OpFunctionEnd) {
-    LLVM_DEBUG(llvm::dbgs() << "[fn] completed function " << fnName << " (type="
-                            << fnType << ", id=" << operands[1] << ")\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "[fn] completed function '" << fnName << "' (type=" << fnType
+               << ", id=" << operands[1] << ")\n");
     return processFunctionEnd(instOperands);
   }
   if (opcode != spirv::Opcode::OpLabel) {
@@ -813,7 +814,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
     return failure();
   }
 
-  LLVM_DEBUG(llvm::dbgs() << "[fn] completed function " << fnName << " (type="
+  LLVM_DEBUG(llvm::dbgs() << "[fn] completed function '" << fnName << "' (type="
                           << fnType << ", id=" << operands[1] << ")\n");
   return processFunctionEnd(instOperands);
 }
@@ -1641,8 +1642,8 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
   Region &body = op->getRegion(0);
 
   BlockAndValueMapping mapper;
-  // All references to the old merge block should be directed to the loop
-  // merge block in the LoopOp's region.
+  // All references to the old merge block should be directed to the
+  // selection/loop merge block in the SelectionOp/LoopOp's region.
   mapper.map(mergeBlock, &body.back());
 
   collectBlocksInConstruct();
@@ -1673,13 +1674,12 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
   // block in this loop construct.
   OpBuilder builder(body);
   for (auto *block : constructBlocks) {
-    assert(block->getNumArguments() == 0 &&
-           "block in loop construct should not have arguments");
-
-    // Create an block and insert it before the loop merge block in the
-    // LoopOp's region.
+    // Create a block and insert it before the selection/loop merge block in the
+    // SelectionOp/LoopOp's region.
     auto *newBlock = builder.createBlock(&body.back());
     mapper.map(block, newBlock);
+    LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
+                            << " from block " << block << "\n");
 
     for (auto &op : *block)
       newBlock->push_back(op.clone(mapper));
@@ -1714,15 +1714,38 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
   }
 
   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
-  // deleted.
+  // cleaned up.
+  LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n");
+  // First we need to drop all uses on ops inside all blocks. This is needed
+  // because we can have blocks referencing SSA values from one another.
   for (auto *block : constructBlocks)
-    block->erase();
+    block->dropAllReferences();
+
+  // Then erase all blocks except the old header block.
+  for (auto *block : constructBlocks) {
+    // The structured selection/loop's entry block does not have arguments.
+    // If the function's header block is also part of the structured control
+    // flow, we cannot just simply erase it because it may contain arguments
+    // matching the function signature and used by the cloned blocks.
+    if (block->isEntryBlock() && isa<FuncOp>(block->getParentOp())) {
+      LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
+                              << " to only contain a spv.Branch op\n");
+      // Still keep the function entry block for the potential block arguments,
+      // but replace all ops inside with a branch to the merge block.
+      block->clear();
+      builder.setInsertionPointToEnd(block);
+      builder.create<spirv::BranchOp>(location, mergeBlock);
+    } else {
+      LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n");
+      block->erase();
+    }
+  }
 
   return success();
 }
 
 LogicalResult Deserializer::structurizeControlFlow() {
-  LLVM_DEBUG(llvm::dbgs() << "[cf] structurizing control flow\n");
+  LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
 
   while (!blockMergeInfo.empty()) {
     auto *headerBlock = blockMergeInfo.begin()->first;
index 445f02c..f2314ff 100644 (file)
@@ -243,11 +243,12 @@ private:
   uint32_t assignBlockID(Block *block);
 
   // Processes the given `block` and emits SPIR-V instructions for all ops
-  // inside. `actionBeforeTerminator` is a callback that will be invoked before
-  // handling the terminator op. It can be used to inject the Op*Merge
-  // instruction if this is a SPIR-V selection/loop header block.
+  // inside. Does not emit OpLabel for this block if `omitLabel` is true.
+  // `actionBeforeTerminator` is a callback that will be invoked before handling
+  // the terminator op. It can be used to inject the Op*Merge instruction if
+  // this is a SPIR-V selection/loop header block.
   LogicalResult
-  processBlock(Block *block,
+  processBlock(Block *block, bool omitLabel = false,
                llvm::function_ref<void()> actionBeforeTerminator = nullptr);
 
   LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
@@ -1194,15 +1195,17 @@ uint32_t Serializer::assignBlockID(Block *block) {
 }
 
 LogicalResult
-Serializer::processBlock(Block *block,
+Serializer::processBlock(Block *block, bool omitLabel,
                          llvm::function_ref<void()> actionBeforeTerminator) {
-  auto blockID = findBlockID(block);
-  if (blockID == 0) {
-    blockID = assignBlockID(block);
-  }
+  if (!omitLabel) {
+    auto blockID = findBlockID(block);
+    if (blockID == 0) {
+      blockID = assignBlockID(block);
+    }
 
-  // Emit OpLabel for this block.
-  encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID});
+    // Emit OpLabel for this block.
+    encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID});
+  }
 
   // Process each op in this block except the terminator.
   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
@@ -1294,16 +1297,8 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
 
   auto *headerBlock = selectionOp.getHeaderBlock();
   auto *mergeBlock = selectionOp.getMergeBlock();
-  auto headerID = findBlockID(headerBlock);
   auto mergeID = findBlockID(mergeBlock);
 
-  // This selection is in some MLIR block with preceding and following ops. In
-  // the binary format, it should reside in separate SPIR-V blocks from its
-  // preceding and following ops. So we need to emit unconditional branches to
-  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
-  // afterwards.
-  encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
-
   // Emit the selection header block, which dominates all other blocks, first.
   // We need to emit an OpSelectionMerge instruction before the loop header
   // block's terminator.
@@ -1313,7 +1308,13 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
         functions, spirv::Opcode::OpSelectionMerge,
         {mergeID, static_cast<uint32_t>(spirv::LoopControl::None)});
   };
-  if (failed(processBlock(headerBlock, emitSelectionMerge)))
+  // For structured selection, we cannot have blocks in the selection construct
+  // branching to the selection header block. Entering the selection (and
+  // reaching the selection header) must be from the block containing the
+  // spv.selection op. If there are ops ahead of the spv.selection op in the
+  // block, we can "merge" them into the selection header. So here we don't need
+  // to emit a separate block; just continue with the existing block.
+  if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge)))
     return failure();
 
   // Process all blocks with a depth-first visitor starting from the header
@@ -1363,7 +1364,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
         functions, spirv::Opcode::OpLoopMerge,
         {mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)});
   };
-  if (failed(processBlock(headerBlock, emitLoopMerge)))
+  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
     return failure();
 
   // Process all blocks with a depth-first visitor starting from the header
index dcfed79..9254d2c 100644 (file)
@@ -1,4 +1,6 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+// Single loop
 
 spv.module "Logical" "GLSL450" {
   // for (int i = 0; i < count; ++i) {}
@@ -57,3 +59,109 @@ spv.module "Logical" "GLSL450" {
   capabilities = ["Shader"]
 }
 
+// -----
+
+// Nested loop
+
+spv.module "Logical" "GLSL450" {
+  // for (int i = 0; i < count; ++i) {
+  //   for (int j = 0; j < count; ++j) { }
+  // }
+  func @loop(%count : i32) -> () {
+    %zero = spv.constant 0: i32
+    %one = spv.constant 1: i32
+    %ivar = spv.Variable init(%zero) : !spv.ptr<i32, Function>
+    %jvar = spv.Variable init(%zero) : !spv.ptr<i32, Function>
+
+// CHECK:        spv.Branch ^bb1
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   spv.loop
+    spv.loop {
+// CHECK-NEXT:     spv.Branch ^bb1
+      spv.Branch ^header
+
+// CHECK-NEXT:   ^bb1:
+    ^header:
+// CHECK-NEXT:     spv.Load
+      %ival0 = spv.Load "Function" %ivar : i32
+// CHECK-NEXT:     spv.SLessThan
+      %icmp = spv.SLessThan %ival0, %count : i32
+// CHECK-NEXT:     spv.BranchConditional %{{.*}}, ^bb2, ^bb5
+      spv.BranchConditional %icmp, ^body, ^merge
+
+// CHECK-NEXT:   ^bb2:
+    ^body:
+// CHECK-NEXT:     spv.constant 0
+// CHECK-NEXT:                  spv.Store
+      spv.Store "Function" %jvar, %zero : i32
+// CHECK-NEXT:     spv.Branch ^bb3
+// CHECK-NEXT:   ^bb3:
+// CHECK-NEXT:     spv.loop {
+      spv.loop {
+// CHECK-NEXT:       spv.Branch ^bb1
+        spv.Branch ^header
+
+// CHECK-NEXT:     ^bb1:
+      ^header:
+// CHECK-NEXT:       spv.Load
+        %jval0 = spv.Load "Function" %jvar : i32
+// CHECK-NEXT:       spv.SLessThan
+        %jcmp = spv.SLessThan %jval0, %count : i32
+// CHECK-NEXT:       spv.BranchConditional %{{.*}}, ^bb2, ^bb4
+        spv.BranchConditional %jcmp, ^body, ^merge
+
+// CHECK-NEXT:     ^bb2:
+      ^body:
+        // Do nothing
+// CHECK-NEXT:       spv.Branch ^bb3
+        spv.Branch ^continue
+
+// CHECK-NEXT:     ^bb3:
+      ^continue:
+// CHECK-NEXT:       spv.Load
+        %jval1 = spv.Load "Function" %jvar : i32
+// CHECK-NEXT:       spv.constant 1
+// CHECK-NEXT:       spv.IAdd
+        %add = spv.IAdd %jval1, %one : i32
+// CHECK-NEXT:       spv.Store
+        spv.Store "Function" %jvar, %add : i32
+// CHECK-NEXT:       spv.Branch ^bb1
+        spv.Branch ^header
+
+// CHECK-NEXT:     ^bb4:
+      ^merge:
+// CHECK-NEXT:       spv._merge
+        spv._merge
+      } // end inner loop
+
+// CHECK:          spv.Branch ^bb4
+      spv.Branch ^continue
+
+// CHECK-NEXT:   ^bb4:
+    ^continue:
+// CHECK-NEXT:     spv.Load
+      %ival1 = spv.Load "Function" %ivar : i32
+// CHECK-NEXT:     spv.constant 1
+// CHECK-NEXT:     spv.IAdd
+      %add = spv.IAdd %ival1, %one : i32
+// CHECK-NEXT:     spv.Store
+      spv.Store "Function" %ivar, %add : i32
+// CHECK-NEXT:     spv.Branch ^bb1
+      spv.Branch ^header
+
+// CHECK-NEXT:   ^bb5:
+// CHECK-NEXT:     spv._merge
+    ^merge:
+      spv._merge
+    } // end outer loop
+    spv.Return
+  }
+
+  func @main() -> () {
+    spv.Return
+  }
+  spv.EntryPoint "GLCompute" @main
+} attributes {
+  capabilities = ["Shader"]
+}
+
index 7aa1691..31797e7 100644 (file)
@@ -1,15 +1,19 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+// Selection with both then and else branches
 
 spv.module "Logical" "GLSL450" {
   func @selection(%cond: i1) -> () {
+// CHECK:        spv.Branch ^bb1
+// CHECK-NEXT: ^bb1:
     %zero = spv.constant 0: i32
     %one = spv.constant 1: i32
     %two = spv.constant 2: i32
     %var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
 
-// CHECK:        spv.Branch ^bb1
-// CHECK-NEXT: ^bb1:
-// CHECK-NEXT:   spv.selection
+// CHECK-NEXT:   spv.selection {
+// CHECK-NEXT:     spv.constant 0
+// CHECK-NEXT:     spv.Variable
     spv.selection {
 // CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
       spv.BranchConditional %cond, ^then, ^else
@@ -47,3 +51,43 @@ spv.module "Logical" "GLSL450" {
 } attributes {
   capabilities = ["Shader"]
 }
+
+// -----
+
+// Selection with only then branch
+// Selection in function entry block
+
+spv.module "Logical" "GLSL450" {
+// CHECK:      func @selection(%[[ARG:.*]]: i1
+  func @selection(%cond: i1) -> (i32) {
+// CHECK:        spv.Branch ^bb1
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   spv.selection
+    spv.selection {
+// CHECK-NEXT: spv.BranchConditional %[[ARG]], ^bb1, ^bb2
+      spv.BranchConditional %cond, ^then, ^merge
+
+// CHECK:        ^bb1:
+    ^then:
+      %zero = spv.constant 0 : i32
+      spv.ReturnValue  %zero : i32
+
+// CHECK:        ^bb2:
+    ^merge:
+// CHECK-NEXT:     spv._merge
+      spv._merge
+    }
+
+    %one = spv.constant 1 : i32
+    spv.ReturnValue  %one : i32
+  }
+
+  func @main() -> () {
+    spv.Return
+  }
+  spv.EntryPoint "GLCompute" @main
+  spv.ExecutionMode @main "LocalSize", 1, 1, 1
+} attributes {
+  capabilities = ["Shader"]
+}
+