Create the dead function elimination pass
authorSteven Perron <stevenperron@google.com>
Tue, 19 Sep 2017 14:12:13 +0000 (10:12 -0400)
committerDavid Neto <dneto@google.com>
Tue, 26 Sep 2017 15:18:06 +0000 (11:18 -0400)
Creates a pass called eliminate dead functions that looks for functions
that could never be called, and deletes them from the module.

To support this change a new function was added to the Pass class to
traverse the call trees from diffent starting points.

Includes a test to ensure that annotations are removed when deleting a
dead function.  They were not, so fixed that up as well.

Did some cleanup of the assembly for the test in pass_test.cpp.  Trying
to make them smaller and easier to read.

12 files changed:
include/spirv-tools/optimizer.hpp
source/opt/CMakeLists.txt
source/opt/eliminate_dead_functions_pass.cpp [new file with mode: 0644]
source/opt/eliminate_dead_functions_pass.h [new file with mode: 0644]
source/opt/optimizer.cpp
source/opt/pass.cpp
source/opt/pass.h
source/opt/passes.h
test/opt/CMakeLists.txt
test/opt/eliminate_dead_functions_test.cpp [new file with mode: 0644]
test/opt/pass_test.cpp [new file with mode: 0644]
tools/opt/opt.cpp

index d438eebac0fb7fa4c9ea145223ad7f3bfb05749e..4464d3b940a043962c2040b02efdd1d4b8702153 100644 (file)
@@ -102,6 +102,12 @@ Optimizer::PassToken CreateNullPass();
 // Section 3.32.2 of the SPIR-V spec) of the SPIR-V module to be optimized.
 Optimizer::PassToken CreateStripDebugInfoPass();
 
+// Creates an eliminate-dead-functions pass.
+// An eliminate-dead-functions pass will remove all functions that are not in the
+// call trees rooted at entry points and exported functions.  These functions
+// are not needed because they will never be called.
+Optimizer::PassToken CreateEliminateDeadFunctionsPass();
+
 // Creates a set-spec-constant-default-value pass from a mapping from spec-ids
 // to the default values in the form of string.
 // A set-spec-constant-default-value pass sets the default values for the
@@ -204,7 +210,7 @@ Optimizer::PassToken CreateStrengthReductionPass();
 // this time it does not guarantee all such sequences are eliminated.
 //
 // Presence of phi instructions can inhibit this optimization. Handling
-// these is left for future improvements. 
+// these is left for future improvements.
 Optimizer::PassToken CreateBlockMergePass();
 
 // Creates an exhaustive inline pass.
@@ -215,7 +221,7 @@ Optimizer::PassToken CreateBlockMergePass();
 // there is no attempt to optimize for size or runtime performance. Functions
 // that are not in the call tree of an entry point are not changed.
 Optimizer::PassToken CreateInlineExhaustivePass();
-  
+
 // Creates an opaque inline pass.
 // An opaque inline pass inlines all function calls in all functions in all
 // entry point call trees where the called function contains an opaque type
@@ -226,9 +232,9 @@ Optimizer::PassToken CreateInlineExhaustivePass();
 // not legal in Vulkan. Functions that are not in the call tree of an entry
 // point are not changed.
 Optimizer::PassToken CreateInlineOpaquePass();
-  
+
 // Creates a single-block local variable load/store elimination pass.
-// For every entry point function, do single block memory optimization of 
+// For every entry point function, do single block memory optimization of
 // function variables referenced only with non-access-chain loads and stores.
 // For each targeted variable load, if previous store to that variable in the
 // block, replace the load's result id with the value id of the store.
@@ -240,9 +246,9 @@ Optimizer::PassToken CreateInlineOpaquePass();
 // The presence of access chain references and function calls can inhibit
 // the above optimization.
 //
-// Only modules with logical addressing are currently processed. 
+// Only modules with logical addressing are currently processed.
 //
-// This pass is most effective if preceeded by Inlining and 
+// This pass is most effective if preceeded by Inlining and
 // LocalAccessChainConvert. This pass will reduce the work needed to be done
 // by LocalSingleStoreElim and LocalMultiStoreElim.
 //
@@ -266,7 +272,7 @@ Optimizer::PassToken CreateDeadBranchElimPass();
 // Creates an SSA local variable load/store elimination pass.
 // For every entry point function, eliminate all loads and stores of function
 // scope variables only referenced with non-access-chain loads and stores.
-// Eliminate the variables as well. 
+// Eliminate the variables as well.
 //
 // The presence of access chain references and function calls can inhibit
 // the above optimization.
@@ -275,7 +281,7 @@ Optimizer::PassToken CreateDeadBranchElimPass();
 // Currently modules with any extensions enabled are not processed. This
 // is left for future work.
 //
-// This pass is most effective if preceeded by Inlining and 
+// This pass is most effective if preceeded by Inlining and
 // LocalAccessChainConvert. LocalSingleStoreElim and LocalSingleBlockElim
 // will reduce the work that this pass has to do.
 Optimizer::PassToken CreateLocalMultiStoreElimPass();
@@ -320,7 +326,7 @@ Optimizer::PassToken CreateLocalAccessChainConvertPass();
 Optimizer::PassToken CreateAggressiveDCEPass();
 
 // Creates a local single store elimination pass.
-// For each entry point function, this pass eliminates loads and stores for 
+// For each entry point function, this pass eliminates loads and stores for
 // function scope variable that are stored to only once, where possible. Only
 // whole variable loads and stores are eliminated; access-chain references are
 // not optimized. Replace all loads of such variables with the value that is
@@ -362,7 +368,7 @@ Optimizer::PassToken CreateDeadBranchElimPass();
 
 // Creates a pass to consolidate uniform references.
 // For each entry point function in the module, first change all constant index
-// access chain loads into equivalent composite extracts. Then consolidate 
+// access chain loads into equivalent composite extracts. Then consolidate
 // identical uniform loads into one uniform load. Finally, consolidate
 // identical uniform extracts into one uniform extract. This may require
 // moving a load or extract to a point which dominates all uses.
index f91e5a215d9b567ce621ba64d9d832abecb3d062..67d3b5de7fa2ba11c6479304bb013308856faa57 100644 (file)
@@ -44,6 +44,7 @@ add_library(SPIRV-Tools-opt
   pass.h
   passes.h
   pass_manager.h
+  eliminate_dead_functions_pass.h
   set_spec_constant_default_value_pass.h
   strength_reduction_pass.h
   strip_debug_info_pass.h
@@ -75,6 +76,7 @@ add_library(SPIRV-Tools-opt
   local_single_store_elim_pass.cpp
   local_ssa_elim_pass.cpp
   module.cpp
+  eliminate_dead_functions_pass.cpp
   set_spec_constant_default_value_pass.cpp
   optimizer.cpp
   mem_pass.cpp
diff --git a/source/opt/eliminate_dead_functions_pass.cpp b/source/opt/eliminate_dead_functions_pass.cpp
new file mode 100644 (file)
index 0000000..037dbd0
--- /dev/null
@@ -0,0 +1,61 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "eliminate_dead_functions_pass.h"
+
+#include <unordered_set>
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status EliminateDeadFunctionsPass::Process(ir::Module* module) {
+  bool modified = false;
+  module_ = module;
+
+  // Identify live functions first.  Those that are not live
+  // are dead.
+  std::unordered_set<const ir::Function*> live_function_set;
+  ProcessFunction mark_live = [&live_function_set](ir::Function* fp) {
+    live_function_set.insert(fp);
+    return false;
+  };
+  ProcessReachableCallTree(mark_live, module);
+
+  def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
+  FindNamedOrDecoratedIds();
+  for (auto funcIter = module->begin(); funcIter != module->end();) {
+    if (live_function_set.count(&*funcIter) == 0) {
+      modified = true;
+      EliminateFunction(&*funcIter);
+      funcIter = funcIter.Erase();
+    } else {
+      ++funcIter;
+    }
+  }
+
+  return modified ? Pass::Status::SuccessWithChange
+                  : Pass::Status::SuccessWithoutChange;
+}
+
+void EliminateDeadFunctionsPass::EliminateFunction(ir::Function* func) {
+  // Remove all of the instruction in the function body
+  func->ForEachInst(
+      [this](ir::Instruction* inst) {
+        KillNamesAndDecorates(inst);
+        def_use_mgr_->KillInst(inst);
+      },
+      true);
+}
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/eliminate_dead_functions_pass.h b/source/opt/eliminate_dead_functions_pass.h
new file mode 100644 (file)
index 0000000..a7d0742
--- /dev/null
@@ -0,0 +1,39 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_
+#define LIBSPIRV_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_
+
+#include "def_use_manager.h"
+#include "function.h"
+#include "mem_pass.h"
+#include "module.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class EliminateDeadFunctionsPass : public MemPass {
+ public:
+  const char* name() const override { return "eliminate-dead-functions"; }
+  Status Process(ir::Module*) override;
+
+ private:
+  void EliminateFunction(ir::Function* func);
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_
index 16600ba223cced157494f100b395225df49fe733..66fbad5382eb8935b17245269cb0855b9d47e9d8 100644 (file)
@@ -95,6 +95,13 @@ Optimizer::PassToken CreateStripDebugInfoPass() {
       MakeUnique<opt::StripDebugInfoPass>());
 }
 
+
+Optimizer::PassToken CreateEliminateDeadFunctionsPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::EliminateDeadFunctionsPass>());
+}
+
+
 Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
     const std::unordered_map<uint32_t, std::string>& id_value_map) {
   return MakeUnique<Optimizer::PassToken::Impl>(
index 626b3239c38ab0937876fe797b91280c94c0b8e2..c9d7c11ff08f0557f9b5605b8c4ee3dd5d493f19 100644 (file)
@@ -25,41 +25,78 @@ namespace {
 
 const uint32_t kEntryPointFunctionIdInIdx = 1;
 
-}  // namespace anonymous
+}  // namespace
 
-void Pass::AddCalls(ir::Function* func,
-    std::queue<uint32_t>* todo) {
+void Pass::AddCalls(ir::Function* func, std::queue<uint32_t>* todo) {
   for (auto bi = func->begin(); bi != func->end(); ++bi)
     for (auto ii = bi->begin(); ii != bi->end(); ++ii)
       if (ii->opcode() == SpvOpFunctionCall)
         todo->push(ii->GetSingleWordInOperand(0));
 }
 
-bool Pass::ProcessEntryPointCallTree(
-    ProcessFunction& pfn, ir::Module* module) {
+bool Pass::ProcessEntryPointCallTree(ProcessFunction& pfn, ir::Module* module) {
   // Map from function's result id to function
   std::unordered_map<uint32_t, ir::Function*> id2function;
-  for (auto& fn : *module)
-    id2function[fn.result_id()] = &fn;
+  for (auto& fn : *module) id2function[fn.result_id()] = &fn;
+
+  // Collect all of the entry points as the roots.
+  std::queue<uint32_t> roots;
+  for (auto& e : module->entry_points())
+    roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
+  return ProcessCallTreeFromRoots(pfn, id2function, &roots);
+}
+
+bool Pass::ProcessReachableCallTree(ProcessFunction& pfn, ir::Module* module) {
+  // Map from function's result id to function
+  std::unordered_map<uint32_t, ir::Function*> id2function;
+  for (auto& fn : *module) id2function[fn.result_id()] = &fn;
+
+  std::queue<uint32_t> roots;
+
+  // Add all entry points since they can be reached from outside the module.
+  for (auto& e : module->entry_points())
+    roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
+
+  // Add all exported functions since they can be reached from outside the
+  // module.
+  for (auto& a : module->annotations()) {
+    // TODO: Handle group decorations as well.  Currently not generate by any
+    // front-end, but could be coming.
+    if (a.opcode() == SpvOp::SpvOpDecorate) {
+      if (a.GetSingleWordOperand(1) ==
+          SpvDecoration::SpvDecorationLinkageAttributes) {
+        uint32_t lastOperand = a.NumOperands() - 1;
+        if (a.GetSingleWordOperand(lastOperand) ==
+            SpvLinkageType::SpvLinkageTypeExport) {
+          uint32_t id = a.GetSingleWordOperand(0);
+          if (id2function.count(id) != 0) roots.push(id);
+        }
+      }
+    }
+  }
+
+  return ProcessCallTreeFromRoots(pfn, id2function, &roots);
+}
+
+bool Pass::ProcessCallTreeFromRoots(
+    ProcessFunction& pfn,
+    const std::unordered_map<uint32_t, ir::Function*>& id2function,
+    std::queue<uint32_t>* roots) {
   // Process call tree
   bool modified = false;
-  std::queue<uint32_t> todo;
   std::unordered_set<uint32_t> done;
-  for (auto& e : module->entry_points())
-    todo.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
-  while (!todo.empty()) {
-    const uint32_t fi = todo.front();
-    if (done.find(fi) == done.end()) {
-      ir::Function* fn = id2function[fi];
+
+  while (!roots->empty()) {
+    const uint32_t fi = roots->front();
+    roots->pop();
+    if (done.insert(fi).second) {
+      ir::Function* fn = id2function.at(fi);
       modified = pfn(fn) || modified;
-      done.insert(fi);
-      AddCalls(fn, &todo);
+      AddCalls(fn, roots);
     }
-    todo.pop();
   }
   return modified;
 }
-
 }  // namespace opt
 }  // namespace spvtools
 
index 897203f152207b7c3a2f3a5bc707faa61bb48964..78622705fdccd909e9f86a3249759bb4ea2eac6f 100644 (file)
@@ -67,9 +67,25 @@ class Pass {
   // Add to |todo| all ids of functions called in |func|.
   void AddCalls(ir::Function* func, std::queue<uint32_t>* todo);
 
-  // 
+  // Applies |pfn| to every function in the call trees that are rooted at the
+  // entry points.  Returns true if any call |pfn| returns true.  By convention
+  // |pfn| should return true if it modified the module.
   bool ProcessEntryPointCallTree(ProcessFunction& pfn, ir::Module* module);
 
+  // Applies |pfn| to every function in the call trees rooted at the entry
+  // points and exported functions.  Returns true if any call |pfn| returns
+  // true.  By convention |pfn| should return true if it modified the module.
+  bool ProcessReachableCallTree(ProcessFunction& pfn, ir::Module* module);
+
+  // Applies |pfn| to every function in the call trees rooted at the elements of
+  // |roots|.  Returns true if any call to |pfn| returns true.  By convention
+  // |pfn| should return true if it modified the module.  After returning
+  // |roots| will be empty.
+  bool ProcessCallTreeFromRoots(
+      ProcessFunction& pfn,
+      const std::unordered_map<uint32_t, ir::Function*>& id2function,
+      std::queue<uint32_t>* roots);
+
   // Processes the given |module|. Returns Status::Failure if errors occur when
   // processing. Returns the corresponding Status::Success if processing is
   // succesful to indicate whether changes are made to the module.
index 572387fde7101b42afac401c67a7e25c85064400..fb5c3ecc8d55f90414571ef8ea22c478d9de0b82 100644 (file)
@@ -38,5 +38,6 @@
 #include "strength_reduction_pass.h"
 #include "strip_debug_info_pass.h"
 #include "unify_const_pass.h"
+#include "eliminate_dead_functions_pass.h"
 
 #endif  // LIBSPIRV_OPT_PASSES_H_
index 64e46b8af429252c2240f0b192862e9b4e457b98..16f50ed390675ed22d400b7f6f06c51e0c35091a 100644 (file)
@@ -113,6 +113,16 @@ add_spvtools_unittest(TARGET pass_eliminate_dead_const
   LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET pass_eliminate_dead_functions
+  SRCS eliminate_dead_functions_test.cpp pass_utils.cpp
+  LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET pass_pass
+  SRCS pass_test.cpp pass_utils.cpp
+  LIBS SPIRV-Tools-opt
+)
+
 add_spvtools_unittest(TARGET pass_utils
   SRCS utils_test.cpp pass_utils.cpp
   LIBS SPIRV-Tools-opt
diff --git a/test/opt/eliminate_dead_functions_test.cpp b/test/opt/eliminate_dead_functions_test.cpp
new file mode 100644 (file)
index 0000000..6f40a91
--- /dev/null
@@ -0,0 +1,206 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "assembly_builder.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::HasSubstr;
+
+using EliminateDeadFunctionsBasicTest = PassTest<::testing::Test>;
+
+TEST_F(EliminateDeadFunctionsBasicTest, BasicDeleteDeadFunction) {
+  // The function Dead should be removed because it is never called.
+  const std::vector<const char*> common_code = {
+      // clang-format off
+               "OpCapability Shader",
+               "OpMemoryModel Logical GLSL450",
+               "OpEntryPoint Fragment %main \"main\"",
+               "OpName %main \"main\"",
+               "OpName %Live \"Live\"",
+       "%void = OpTypeVoid",
+          "%7 = OpTypeFunction %void",
+       "%main = OpFunction %void None %7",
+         "%15 = OpLabel",
+         "%16 = OpFunctionCall %void %Live",
+         "%17 = OpFunctionCall %void %Live",
+               "OpReturn",
+               "OpFunctionEnd",
+  "%Live = OpFunction %void None %7",
+         "%20 = OpLabel",
+               "OpReturn",
+               "OpFunctionEnd"
+      // clang-format on
+  };
+
+  const std::vector<const char*> dead_function = {
+      // clang-format off
+      "%Dead = OpFunction %void None %7",
+         "%19 = OpLabel",
+               "OpReturn",
+               "OpFunctionEnd",
+      // clang-format on
+  };
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<opt::EliminateDeadFunctionsPass>(
+      JoinAllInsts(Concat(common_code, dead_function)),
+      JoinAllInsts(common_code), /* skip_nop = */ true);
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepLiveFunction) {
+  // Everything is reachable from an entry point, so no functions should be
+  // deleted.
+  const std::vector<const char*> text = {
+      // clang-format off
+               "OpCapability Shader",
+               "OpMemoryModel Logical GLSL450",
+               "OpEntryPoint Fragment %main \"main\"",
+               "OpName %main \"main\"",
+               "OpName %Live1 \"Live1\"",
+               "OpName %Live2 \"Live2\"",
+       "%void = OpTypeVoid",
+          "%7 = OpTypeFunction %void",
+       "%main = OpFunction %void None %7",
+         "%15 = OpLabel",
+         "%16 = OpFunctionCall %void %Live2",
+         "%17 = OpFunctionCall %void %Live1",
+               "OpReturn",
+               "OpFunctionEnd",
+      "%Live1 = OpFunction %void None %7",
+         "%19 = OpLabel",
+               "OpReturn",
+               "OpFunctionEnd",
+      "%Live2 = OpFunction %void None %7",
+         "%20 = OpLabel",
+               "OpReturn",
+               "OpFunctionEnd"
+      // clang-format on
+  };
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  std::string assembly = JoinAllInsts(text);
+  auto result = SinglePassRunAndDisassemble<opt::EliminateDeadFunctionsPass>(
+      assembly, /* skip_nop = */ true);
+  EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
+  EXPECT_EQ(assembly, std::get<0>(result));
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepExportFunctions) {
+  // All functions are reachable.  In particular, ExportedFunc and Constant are
+  // reachable because ExportedFunc is exported.  Nothing should be removed.
+  const std::vector<const char*> text = {
+      // clang-format off
+               "OpCapability Shader",
+               "OpCapability Linkage",
+               "OpMemoryModel Logical GLSL450",
+               "OpEntryPoint Fragment %main \"main\"",
+               "OpName %main \"main\"",
+               "OpName %ExportedFunc \"ExportedFunc\"",
+               "OpName %Live \"Live\"",
+               "OpDecorate %ExportedFunc LinkageAttributes \"ExportedFunc\" Export",
+       "%void = OpTypeVoid",
+          "%7 = OpTypeFunction %void",
+       "%main = OpFunction %void None %7",
+         "%15 = OpLabel",
+               "OpReturn",
+               "OpFunctionEnd",
+"%ExportedFunc = OpFunction %void None %7",
+         "%19 = OpLabel",
+         "%16 = OpFunctionCall %void %Live",
+               "OpReturn",
+               "OpFunctionEnd",
+  "%Live = OpFunction %void None %7",
+         "%20 = OpLabel",
+               "OpReturn",
+               "OpFunctionEnd"
+      // clang-format on
+  };
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  std::string assembly = JoinAllInsts(text);
+  auto result = SinglePassRunAndDisassemble<opt::EliminateDeadFunctionsPass>(
+      assembly, /* skip_nop = */ true);
+  EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
+  EXPECT_EQ(assembly, std::get<0>(result));
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, BasicRemoveDecorationsAndNames) {
+  // We want to remove the names and decorations associated with results that
+  // are removed.  This test will check for that.
+  const std::string text = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %main "main"
+               OpName %main "main"
+               OpName %Dead "Dead"
+               OpName %x "x"
+               OpName %y "y"
+               OpName %z "z"
+               OpDecorate %x RelaxedPrecision
+               OpDecorate %y RelaxedPrecision
+               OpDecorate %z RelaxedPrecision
+               OpDecorate %6 RelaxedPrecision
+               OpDecorate %7 RelaxedPrecision
+               OpDecorate %8 RelaxedPrecision
+       %void = OpTypeVoid
+         %10 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+    %float_1 = OpConstant %float 1
+       %main = OpFunction %void None %10
+         %14 = OpLabel
+               OpReturn
+               OpFunctionEnd
+       %Dead = OpFunction %void None %10
+         %15 = OpLabel
+          %x = OpVariable %_ptr_Function_float Function
+          %y = OpVariable %_ptr_Function_float Function
+          %z = OpVariable %_ptr_Function_float Function
+               OpStore %x %float_1
+               OpStore %y %float_1
+          %6 = OpLoad %float %x
+          %7 = OpLoad %float %y
+          %8 = OpFAdd %float %6 %7
+               OpStore %z %8
+               OpReturn
+               OpFunctionEnd)";
+
+  const std::string expected_output = R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%float_1 = OpConstant %float 1
+%main = OpFunction %void None %10
+%14 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<opt::EliminateDeadFunctionsPass>(text, expected_output,
+                                                         /* skip_nop = */ true);
+}
+}  // anonymous namespace
diff --git a/test/opt/pass_test.cpp b/test/opt/pass_test.cpp
new file mode 100644 (file)
index 0000000..a6298ca
--- /dev/null
@@ -0,0 +1,241 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "assembly_builder.h"
+#include "opt/pass.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+using namespace spvtools;
+class DummyPass : public opt::Pass {
+ public:
+  const char* name() const override { return "dummy-pass"; }
+  Status Process(ir::Module* module) override {
+    return module ? Status::SuccessWithoutChange : Status::Failure;
+  }
+};
+}  // namespace
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+using PassClassTest = PassTest<::testing::Test>;
+
+TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
+  // Make sure we visit the entry point, and the function it calls.
+  // Do not visit Dead or Exported.
+  const std::string text = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %10 "main"
+               OpName %10 "main"
+               OpName %Dead "Dead"
+               OpName %11 "Constant"
+               OpName %ExportedFunc "ExportedFunc"
+               OpDecorate %ExportedFunc LinkageAttributes "ExportedFunc" Export
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+         %10 = OpFunction %void None %6
+         %14 = OpLabel
+         %15 = OpFunctionCall %void %11
+         %16 = OpFunctionCall %void %11
+               OpReturn
+               OpFunctionEnd
+         %11 = OpFunction %void None %6
+         %18 = OpLabel
+               OpReturn
+               OpFunctionEnd
+       %Dead = OpFunction %void None %6
+         %19 = OpLabel
+               OpReturn
+               OpFunctionEnd
+%ExportedFunc = OpFunction %void None %7
+         %20 = OpLabel
+         %21 = OpFunctionCall %void %11
+               OpReturn
+               OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::Module> module =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  DummyPass testPass;
+  std::vector<uint32_t> processed;
+  opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) {
+    processed.push_back(fp->result_id());
+    return false;
+  };
+  testPass.ProcessEntryPointCallTree(mark_visited, module.get());
+  EXPECT_THAT(processed, UnorderedElementsAre(10, 11));
+}
+
+TEST_F(PassClassTest, BasicVisitReachable) {
+  // Make sure we visit the entry point, exported function, and the function
+  // they call. Do not visit Dead.
+  const std::string text = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %10 "main"
+               OpName %10 "main"
+               OpName %Dead "Dead"
+               OpName %11 "Constant"
+               OpName %12 "ExportedFunc"
+               OpName %13 "Constant2"
+               OpDecorate %12 LinkageAttributes "ExportedFunc" Export
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+         %10 = OpFunction %void None %6
+         %14 = OpLabel
+         %15 = OpFunctionCall %void %11
+         %16 = OpFunctionCall %void %11
+               OpReturn
+               OpFunctionEnd
+         %11 = OpFunction %void None %6
+         %18 = OpLabel
+               OpReturn
+               OpFunctionEnd
+       %Dead = OpFunction %void None %6
+         %19 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %12 = OpFunction %void None %9
+         %20 = OpLabel
+         %21 = OpFunctionCall %void %13
+               OpReturn
+               OpFunctionEnd
+         %13 = OpFunction %void None %6
+         %22 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::Module> module =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  DummyPass testPass;
+  std::vector<uint32_t> processed;
+  opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) {
+    processed.push_back(fp->result_id());
+    return false;
+  };
+  testPass.ProcessReachableCallTree(mark_visited, module.get());
+  EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13));
+}
+
+TEST_F(PassClassTest, BasicVisitOnlyOnce) {
+  // Make sure we visit %11 only once, even if it is called from two different
+  // functions.
+  const std::string text = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %10 "main" %gl_FragColor
+               OpName %10 "main"
+               OpName %Dead "Dead"
+               OpName %11 "Constant"
+               OpName %12 "ExportedFunc"
+               OpDecorate %12 LinkageAttributes "ExportedFunc" Export
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+         %10 = OpFunction %void None %6
+         %14 = OpLabel
+         %15 = OpFunctionCall %void %11
+         %16 = OpFunctionCall %void %12
+               OpReturn
+               OpFunctionEnd
+         %11 = OpFunction %void None %6
+         %18 = OpLabel
+         %19 = OpFunctionCall %void %12
+               OpReturn
+               OpFunctionEnd
+       %Dead = OpFunction %void None %6
+         %20 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %12 = OpFunction %void None %9
+         %21 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::Module> module =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  DummyPass testPass;
+  std::vector<uint32_t> processed;
+  opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) {
+    processed.push_back(fp->result_id());
+    return false;
+  };
+  testPass.ProcessReachableCallTree(mark_visited, module.get());
+  EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12));
+}
+
+TEST_F(PassClassTest, BasicDontVisitExportedVariable) {
+  // Make sure we only visit functions and not exported variables.
+  const std::string text = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %10 "main" %gl_FragColor
+               OpExecutionMode %10 OriginUpperLeft
+               OpSource GLSL 150
+               OpName %10 "main"
+               OpName %Dead "Dead"
+               OpName %11 "Constant"
+               OpName %12 "export_var"
+               OpDecorate %12 LinkageAttributes "export_var" Export
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+  %float_1 = OpConstant %float 1
+         %12 = OpVariable %float Output
+         %10 = OpFunction %void None %6
+         %14 = OpLabel
+               OpStore %12 %float_1
+               OpReturn
+               OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::Module> module =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  DummyPass testPass;
+  std::vector<uint32_t> processed;
+  opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) {
+    processed.push_back(fp->result_id());
+    return false;
+  };
+  testPass.ProcessReachableCallTree(mark_visited, module.get());
+  EXPECT_THAT(processed, UnorderedElementsAre(10));
+}
+}  // namespace
index 1a5334758c7254da133ade2150f218942cf92ef3..ac037c51fc482c1e1e595821e6c7f402faea50ad 100644 (file)
@@ -108,6 +108,9 @@ Options:
                Convert conditional branches with constant condition to the
                indicated unconditional brranch. Delete all resulting dead
                code. Performed only on entry point call tree functions.
+  --eliminate-dead-functions
+               Deletes functions that cannot be reached from entry points or
+               exported functions.
   --merge-blocks
                Join two blocks into a single block if the second has the
                first as its only predecessor. Performed only on entry point
@@ -194,6 +197,8 @@ int main(int argc, char** argv) {
         optimizer.RegisterPass(CreateBlockMergePass());
       } else if (0 == strcmp(cur_arg, "--eliminate-dead-branches")) {
         optimizer.RegisterPass(CreateDeadBranchElimPass());
+      } else if (0 == strcmp(cur_arg, "--eliminate-dead-functions")) {
+        optimizer.RegisterPass(CreateEliminateDeadFunctionsPass());
       } else if (0 == strcmp(cur_arg, "--eliminate-local-multi-store")) {
         optimizer.RegisterPass(CreateLocalMultiStoreElimPass());
       } else if (0 == strcmp(cur_arg, "--eliminate-common-uniform")) {