Initial implementation of merge return pass.
authorAlan Baker <alanbaker@google.com>
Wed, 8 Nov 2017 21:22:10 +0000 (16:22 -0500)
committerDavid Neto <dneto@google.com>
Wed, 15 Nov 2017 15:27:04 +0000 (10:27 -0500)
Works with current DefUseManager infrastructure.

Added merge return to the standard opts.

Added validation to passes.

Disabled pass for shader capabilty.

Android.mk
include/spirv-tools/optimizer.hpp
source/opt/CMakeLists.txt
source/opt/merge_return_pass.cpp [new file with mode: 0644]
source/opt/merge_return_pass.h [new file with mode: 0644]
source/opt/optimizer.cpp
source/opt/passes.h
test/opt/CMakeLists.txt
test/opt/pass_merge_return_test.cpp [new file with mode: 0644]
tools/opt/opt.cpp

index a6456f8..afc1514 100644 (file)
@@ -89,7 +89,8 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/strip_debug_info_pass.cpp \
                source/opt/type_manager.cpp \
                source/opt/types.cpp \
-               source/opt/unify_const_pass.cpp
+               source/opt/unify_const_pass.cpp \
+               source/opt/merge_return_pass.cpp
 
 # Locations of grammar files.
 SPV_CORE10_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/1.0/spirv.core.grammar.json
index 8d9f94d..b3d34b4 100644 (file)
@@ -406,6 +406,18 @@ Optimizer::PassToken CreateCFGCleanupPass();
 // that are not referenced.
 Optimizer::PassToken CreateDeadVariableEliminationPass();
 
+// Create merge return pass.
+// This pass replaces all returns with unconditional branches to a new block
+// containing a return. If necessary, this new block will contain a PHI node to
+// select the correct return value.
+//
+// This pass does not consider unreachable code, nor does it perform any other
+// optimizations.
+//
+// This pass does not currently support structured control flow. It bails out if
+// the shader capability is detected.
+Optimizer::PassToken CreateMergeReturnPass();
+
 }  // namespace spvtools
 
 #endif  // SPIRV_TOOLS_OPTIMIZER_HPP_
index ff8583e..45b71ec 100644 (file)
@@ -43,6 +43,7 @@ add_library(SPIRV-Tools-opt
   local_single_store_elim_pass.h
   local_ssa_elim_pass.h
   log.h
+  merge_return_pass.h
   module.h
   null_pass.h
   reflect.h
@@ -88,6 +89,7 @@ add_library(SPIRV-Tools-opt
   local_single_block_elim_pass.cpp
   local_single_store_elim_pass.cpp
   local_ssa_elim_pass.cpp
+  merge_return_pass.cpp
   module.cpp
   eliminate_dead_functions_pass.cpp
   remove_duplicates_pass.cpp
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
new file mode 100644 (file)
index 0000000..9374a91
--- /dev/null
@@ -0,0 +1,120 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "merge_return_pass.h"
+
+#include "instruction.h"
+#include "ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status MergeReturnPass::Process(ir::IRContext* irContext) {
+  InitializeProcessing(irContext);
+
+  // TODO (alanbaker): Support structured control flow. Bail out in the
+  // meantime.
+  if (get_module()->HasCapability(SpvCapabilityShader))
+    return Status::SuccessWithoutChange;
+
+  bool modified = false;
+  for (auto& function : *get_module()) {
+    std::vector<ir::BasicBlock*> returnBlocks = CollectReturnBlocks(&function);
+    modified |= MergeReturnBlocks(&function, returnBlocks);
+  }
+
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+std::vector<ir::BasicBlock*> MergeReturnPass::CollectReturnBlocks(
+    ir::Function* function) {
+  std::vector<ir::BasicBlock*> returnBlocks;
+  for (auto& block : *function) {
+    ir::Instruction& terminator = *block.tail();
+    if (terminator.opcode() == SpvOpReturn ||
+        terminator.opcode() == SpvOpReturnValue) {
+      returnBlocks.push_back(&block);
+    }
+  }
+
+  return returnBlocks;
+}
+
+bool MergeReturnPass::MergeReturnBlocks(
+    ir::Function* function, const std::vector<ir::BasicBlock*>& returnBlocks) {
+  if (returnBlocks.size() <= 1) {
+    // No work to do.
+    return false;
+  }
+
+  // Create a label for the new return block
+  std::unique_ptr<ir::Instruction> returnLabel(
+      new ir::Instruction(SpvOpLabel, 0u, TakeNextId(), {}));
+  uint32_t returnId = returnLabel->result_id();
+
+  // Create the new basic block
+  std::unique_ptr<ir::BasicBlock> returnBlock(
+      new ir::BasicBlock(std::move(returnLabel)));
+  function->AddBasicBlock(std::move(returnBlock));
+  ir::Function::iterator retBlockIter = --function->end();
+
+  // Create the PHI for the merged block (if necessary)
+  // Create new return
+  std::vector<ir::Operand> phiOps;
+  for (auto block : returnBlocks) {
+    if (block->tail()->opcode() == SpvOpReturnValue) {
+      phiOps.push_back(
+          {SPV_OPERAND_TYPE_ID, {block->tail()->GetSingleWordInOperand(0u)}});
+      phiOps.push_back({SPV_OPERAND_TYPE_ID, {block->id()}});
+    }
+  }
+
+  if (!phiOps.empty()) {
+    // Need a PHI node to select the correct return value.
+    uint32_t phiResultId = TakeNextId();
+    uint32_t phiTypeId = function->type_id();
+    std::unique_ptr<ir::Instruction> phiInst(
+        new ir::Instruction(SpvOpPhi, phiTypeId, phiResultId, phiOps));
+    retBlockIter->AddInstruction(std::move(phiInst));
+    ir::BasicBlock::iterator phiIter = retBlockIter->tail();
+
+    std::unique_ptr<ir::Instruction> returnInst(new ir::Instruction(
+        SpvOpReturnValue, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {phiResultId}}}));
+    retBlockIter->AddInstruction(std::move(returnInst));
+    ir::BasicBlock::iterator ret = retBlockIter->tail();
+
+    get_def_use_mgr()->AnalyzeInstDefUse(&*phiIter);
+    get_def_use_mgr()->AnalyzeInstDef(&*ret);
+  } else {
+    std::unique_ptr<ir::Instruction> returnInst(
+        new ir::Instruction(SpvOpReturn));
+    retBlockIter->AddInstruction(std::move(returnInst));
+  }
+
+  // Replace returns with branches
+  for (auto block : returnBlocks) {
+    context()->KillInst(&*block->tail());
+    block->tail()->SetOpcode(SpvOpBranch);
+    block->tail()->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {returnId}}});
+    get_def_use_mgr()->AnalyzeInstUse(&*block->tail());
+    get_def_use_mgr()->AnalyzeInstUse(block->GetLabelInst());
+  }
+
+  get_def_use_mgr()->AnalyzeInstDefUse(retBlockIter->GetLabelInst());
+
+  return true;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
new file mode 100644 (file)
index 0000000..8cccd96
--- /dev/null
@@ -0,0 +1,55 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_MERGE_RETURN_PASS_H_
+#define LIBSPIRV_OPT_MERGE_RETURN_PASS_H_
+
+#include "basic_block.h"
+#include "function.h"
+#include "pass.h"
+
+#include <vector>
+
+namespace spvtools {
+namespace opt {
+
+// Documented in optimizer.hpp
+class MergeReturnPass : public Pass {
+ public:
+  MergeReturnPass() = default;
+  const char* name() const override { return "merge-return-pass"; }
+  Status Process(ir::IRContext*) override;
+
+  ir::IRContext::Analysis GetPreservedAnalyses() override {
+    return ir::IRContext::kAnalysisDefUse;
+  }
+
+ private:
+  // Returns all BasicBlocks terminated by OpReturn or OpReturnValue in
+  // |function|.
+  std::vector<ir::BasicBlock*> CollectReturnBlocks(ir::Function* function);
+
+  // Returns |true| if returns were merged, |false| otherwise.
+  //
+  // Creates a new basic block with a single return. If |function| returns a
+  // value, a phi node is created to select the correct value to return.
+  // Replaces old returns with an unconditional branch to the new block.
+  bool MergeReturnBlocks(ir::Function* function,
+                         const std::vector<ir::BasicBlock*>& returnBlocks);
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_MERGE_RETURN_PASS_H_
index 53692d4..2fe8410 100644 (file)
@@ -67,7 +67,8 @@ Optimizer& Optimizer::RegisterPass(PassToken&& p) {
 }
 
 Optimizer& Optimizer::RegisterPerformancePasses() {
-  return RegisterPass(CreateInlineExhaustivePass())
+  return RegisterPass(CreateMergeReturnPass())
+      .RegisterPass(CreateInlineExhaustivePass())
       .RegisterPass(CreateLocalAccessChainConvertPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
@@ -83,7 +84,8 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
 }
 
 Optimizer& Optimizer::RegisterSizePasses() {
-  return RegisterPass(CreateInlineExhaustivePass())
+  return RegisterPass(CreateMergeReturnPass())
+      .RegisterPass(CreateInlineExhaustivePass())
       .RegisterPass(CreateLocalAccessChainConvertPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
@@ -240,6 +242,11 @@ Optimizer::PassToken CreateCompactIdsPass() {
       MakeUnique<opt::CompactIdsPass>());
 }
 
+Optimizer::PassToken CreateMergeReturnPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::MergeReturnPass>());
+}
+
 std::vector<const char*> Optimizer::GetPassNames() const {
   std::vector<const char*> v;
   for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); i++) {
index ef01b48..3121f9d 100644 (file)
@@ -41,5 +41,7 @@
 #include "strength_reduction_pass.h"
 #include "strip_debug_info_pass.h"
 #include "unify_const_pass.h"
+#include "eliminate_dead_functions_pass.h"
+#include "merge_return_pass.h"
 
 #endif  // LIBSPIRV_OPT_PASSES_H_
index ecff063..159a011 100644 (file)
@@ -204,3 +204,7 @@ add_spvtools_unittest(TARGET ir_context
   LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET pass_merge_return
+  SRCS pass_merge_return_test.cpp pass_utils.cpp
+  LIBS SPIRV-Tools-opt
+)
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
new file mode 100644 (file)
index 0000000..ff2840c
--- /dev/null
@@ -0,0 +1,291 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include "spirv-tools/libspirv.hpp"
+#include "spirv-tools/optimizer.hpp"
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using MergeReturnPassTest = PassTest<::testing::Test>;
+
+TEST_F(MergeReturnPassTest, OneReturn) {
+  const std::string before =
+      R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %1 "simple_kernel"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string after = before;
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::MergeReturnPass>(before, after, false, true);
+}
+
+TEST_F(MergeReturnPassTest, TwoReturnsNoValue) {
+  const std::string before =
+      R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %6 "simple_kernel"
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantFalse %3
+%1 = OpTypeFunction %2
+%6 = OpFunction %2 None %1
+%7 = OpLabel
+OpBranchConditional %4 %8 %9
+%8 = OpLabel
+OpReturn
+%9 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string after =
+      R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %6 "simple_kernel"
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantFalse %3
+%1 = OpTypeFunction %2
+%6 = OpFunction %2 None %1
+%7 = OpLabel
+OpBranchConditional %4 %8 %9
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::MergeReturnPass>(before, after, false, true);
+}
+
+TEST_F(MergeReturnPassTest, TwoReturnsWithValues) {
+  const std::string before =
+      R"(OpCapability Linkage
+OpCapability Kernel
+OpMemoryModel Logical OpenCL
+%1 = OpTypeInt 32 0
+%2 = OpTypeBool
+%3 = OpConstantFalse %2
+%4 = OpConstant %1 0
+%5 = OpConstant %1 1
+%6 = OpTypeFunction %1
+%7 = OpFunction %1 None %6
+%8 = OpLabel
+OpBranchConditional %3 %9 %10
+%9 = OpLabel
+OpReturnValue %4
+%10 = OpLabel
+OpReturnValue %5
+OpFunctionEnd
+)";
+
+  const std::string after =
+      R"(OpCapability Linkage
+OpCapability Kernel
+OpMemoryModel Logical OpenCL
+%1 = OpTypeInt 32 0
+%2 = OpTypeBool
+%3 = OpConstantFalse %2
+%4 = OpConstant %1 0
+%5 = OpConstant %1 1
+%6 = OpTypeFunction %1
+%7 = OpFunction %1 None %6
+%8 = OpLabel
+OpBranchConditional %3 %9 %10
+%9 = OpLabel
+OpBranch %11
+%10 = OpLabel
+OpBranch %11
+%11 = OpLabel
+%12 = OpPhi %1 %4 %9 %5 %10
+OpReturnValue %12
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::MergeReturnPass>(before, after, false, true);
+}
+
+TEST_F(MergeReturnPassTest, UnreachableReturnsNoValue) {
+  const std::string before =
+      R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %6 "simple_kernel"
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantFalse %3
+%1 = OpTypeFunction %2
+%6 = OpFunction %2 None %1
+%7 = OpLabel
+OpReturn
+%8 = OpLabel
+OpBranchConditional %4 %9 %10
+%9 = OpLabel
+OpReturn
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string after =
+      R"(OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %6 "simple_kernel"
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantFalse %3
+%1 = OpTypeFunction %2
+%6 = OpFunction %2 None %1
+%7 = OpLabel
+OpBranch %11
+%8 = OpLabel
+OpBranchConditional %4 %9 %10
+%9 = OpLabel
+OpBranch %11
+%10 = OpLabel
+OpBranch %11
+%11 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::MergeReturnPass>(before, after, false, true);
+}
+
+TEST_F(MergeReturnPassTest, UnreachableReturnsWithValues) {
+  const std::string before =
+      R"(OpCapability Linkage
+OpCapability Kernel
+OpMemoryModel Logical OpenCL
+%1 = OpTypeInt 32 0
+%2 = OpTypeBool
+%3 = OpConstantFalse %2
+%4 = OpConstant %1 0
+%5 = OpConstant %1 1
+%6 = OpTypeFunction %1
+%7 = OpFunction %1 None %6
+%8 = OpLabel
+%9 = OpIAdd %1 %4 %5
+OpReturnValue %9
+%10 = OpLabel
+OpBranchConditional %3 %11 %12
+%11 = OpLabel
+OpReturnValue %4
+%12 = OpLabel
+OpReturnValue %5
+OpFunctionEnd
+)";
+
+  const std::string after =
+      R"(OpCapability Linkage
+OpCapability Kernel
+OpMemoryModel Logical OpenCL
+%1 = OpTypeInt 32 0
+%2 = OpTypeBool
+%3 = OpConstantFalse %2
+%4 = OpConstant %1 0
+%5 = OpConstant %1 1
+%6 = OpTypeFunction %1
+%7 = OpFunction %1 None %6
+%8 = OpLabel
+%9 = OpIAdd %1 %4 %5
+OpBranch %13
+%10 = OpLabel
+OpBranchConditional %3 %11 %12
+%11 = OpLabel
+OpBranch %13
+%12 = OpLabel
+OpBranch %13
+%13 = OpLabel
+%14 = OpPhi %1 %9 %8 %4 %11 %5 %12
+OpReturnValue %14
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::MergeReturnPass>(before, after, false, true);
+}
+
+TEST_F(MergeReturnPassTest, StructuredControlFlowNOP) {
+  const std::string before =
+      R"(OpCapability Addresses
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %6 "simple_shader"
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantFalse %3
+%1 = OpTypeFunction %2
+%6 = OpFunction %2 None %1
+%7 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %4 %8 %9
+%8 = OpLabel
+OpReturn
+%9 = OpLabel
+OpReturn
+%10 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+  const std::string after = before;
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::MergeReturnPass>(before, after, false, true);
+}
+
+}  // anonymous namespace
index 4236c93..c8d8575 100644 (file)
@@ -152,6 +152,12 @@ Options:
                Join two blocks into a single block if the second has the
                first as its only predecessor. Performed only on entry point
                call tree functions.
+  --merge-return
+               Replace all return instructions with unconditional branches to
+               a new basic block containing an unified return.
+
+               This pass does not currently support structured control flow. It
+               makes no changes if the shader capability is detected.
   --strength-reduction
                Replaces instructions with equivalent and less expensive ones.
   --eliminate-dead-variables
@@ -352,6 +358,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
         optimizer->RegisterPass(CreateLocalSingleStoreElimPass());
       } else if (0 == strcmp(cur_arg, "--merge-blocks")) {
         optimizer->RegisterPass(CreateBlockMergePass());
+      } else if (0 == strcmp(cur_arg, "--merge-return")) {
+        optimizer->RegisterPass(CreateMergeReturnPass());
       } else if (0 == strcmp(cur_arg, "--eliminate-dead-branches")) {
         optimizer->RegisterPass(CreateDeadBranchElimPass());
       } else if (0 == strcmp(cur_arg, "--eliminate-dead-functions")) {