Add loop unswitch pass.
authorVictor Lomuller <victor@codeplay.com>
Mon, 12 Feb 2018 21:42:15 +0000 (21:42 +0000)
committerDiego Novillo <dnovillo@google.com>
Tue, 27 Feb 2018 13:52:46 +0000 (08:52 -0500)
It moves all conditional branching and switch whose conditions are loop
invariant and uniform. Before performing the loop unswitch we check that
the loop does not contain any instruction that would prevent it
(barriers, group instructions etc.).

24 files changed:
Android.mk
include/spirv-tools/optimizer.hpp
source/opt/CMakeLists.txt
source/opt/basic_block.cpp
source/opt/basic_block.h
source/opt/cfg.h
source/opt/dominator_tree.cpp
source/opt/dominator_tree.h
source/opt/function.h
source/opt/ir_builder.h
source/opt/iterator.h
source/opt/loop_descriptor.cpp
source/opt/loop_descriptor.h
source/opt/loop_unswitch_pass.cpp [new file with mode: 0644]
source/opt/loop_unswitch_pass.h [new file with mode: 0644]
source/opt/loop_utils.cpp
source/opt/loop_utils.h
source/opt/mem_pass.cpp
source/opt/optimizer.cpp
source/opt/passes.h
test/opt/loop_optimizations/CMakeLists.txt
test/opt/loop_optimizations/loop_descriptions.cpp
test/opt/loop_optimizations/unswitch.cpp [new file with mode: 0644]
tools/opt/opt.cpp

index 4429ad1..074ceb2 100644 (file)
@@ -101,6 +101,8 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/local_ssa_elim_pass.cpp \
                source/opt/loop_descriptor.cpp \
                source/opt/loop_unroller.cpp \
+               source/opt/loop_unswitch_pass.cpp \
+               source/opt/loop_utils.cpp \
                source/opt/mem_pass.cpp \
                source/opt/merge_return_pass.cpp \
                source/opt/module.cpp \
index adb014a..bedf436 100644 (file)
@@ -468,6 +468,12 @@ Optimizer::PassToken CreateLocalRedundancyEliminationPass();
 // the loops preheader.
 Optimizer::PassToken CreateLoopInvariantCodeMotionPass();
 
+// Creates a loop unswitch pass.
+// This pass will look for loop independent branch conditions and move the
+// condition out of the loop and version the loop based on the taken branch.
+// Works best after LICM and local multi store elimination pass.
+Optimizer::PassToken CreateLoopUnswitchPass();
+
 // Create global value numbering pass.
 // This pass will look for instructions where the same value is computed on all
 // paths leading to the instruction.  Those instructions are deleted.
index 0194851..854c950 100644 (file)
@@ -60,6 +60,7 @@ add_library(SPIRV-Tools-opt
   loop_descriptor.h
   loop_unroller.h
   loop_utils.h
+  loop_unswitch_pass.h
   make_unique.h
   mem_pass.h
   merge_return_pass.h
@@ -132,6 +133,7 @@ add_library(SPIRV-Tools-opt
   loop_descriptor.cpp
   loop_utils.cpp
   loop_unroller.cpp
+  loop_unswitch_pass.cpp
   mem_pass.cpp
   merge_return_pass.cpp
   module.cpp
index b07696b..65030ea 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "basic_block.h"
 #include "function.h"
+#include "ir_context.h"
 #include "module.h"
 #include "reflect.h"
 
@@ -89,6 +90,14 @@ Instruction* BasicBlock::GetLoopMergeInst() {
   return nullptr;
 }
 
+void BasicBlock::KillAllInsts(bool killLabel) {
+  ForEachInst([killLabel](ir::Instruction* ip) {
+    if (killLabel || ip->opcode() != SpvOpLabel) {
+      ip->context()->KillInst(ip);
+    }
+  });
+}
+
 void BasicBlock::ForEachSuccessorLabel(
     const std::function<void(const uint32_t)>& f) const {
   const auto br = &insts_.back();
index d0186e6..5f3d393 100644 (file)
@@ -171,6 +171,10 @@ class BasicBlock {
   // Returns true if this basic block exits this function or aborts execution.
   bool IsReturnOrAbort() const { return ctail()->IsReturnOrAbort(); }
 
+  // Kill all instructions in this block. Whether or not to kill the label is
+  // indicated by |killLabel|.
+  void KillAllInsts(bool killLabel);
+
  private:
   // The enclosing function.
   Function* function_;
index 53dddd2..b21a273 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "basic_block.h"
 
+#include <algorithm>
 #include <list>
 #include <unordered_map>
 #include <unordered_set>
@@ -83,6 +84,22 @@ class CFG {
     AddEdges(blk);
   }
 
+  // Removes from the CFG any mapping for the basic block id |blk_id|.
+  void ForgetBlock(const ir::BasicBlock* blk) {
+    id2block_.erase(blk->id());
+    label2preds_.erase(blk->id());
+    blk->ForEachSuccessorLabel(
+        [blk, this](uint32_t succ_id) { RemoveEdge(blk->id(), succ_id); });
+  }
+
+  void RemoveEdge(uint32_t pred_blk_id, uint32_t succ_blk_id) {
+    auto pred_it = label2preds_.find(succ_blk_id);
+    if (pred_it == label2preds_.end()) return;
+    auto& preds_list = pred_it->second;
+    auto it = std::find(preds_list.begin(), preds_list.end(), pred_blk_id);
+    if (it != preds_list.end()) preds_list.erase(it);
+  }
+
   // Registers |blk| to all of its successors.
   void AddEdges(ir::BasicBlock* blk);
 
index c22d743..776adf4 100644 (file)
@@ -358,6 +358,10 @@ void DominatorTree::InitializeTree(const ir::Function* f, const ir::CFG& cfg) {
     second->children_.push_back(first);
   }
 
+  ResetDFNumbering();
+}
+
+void DominatorTree::ResetDFNumbering() {
   int index = 0;
   auto preFunc = [&index](const DominatorTreeNode* node) {
     const_cast<DominatorTreeNode*>(node)->dfs_num_pre_ = ++index;
index 5221eea..39d5e02 100644 (file)
@@ -15,6 +15,7 @@
 #ifndef LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_
 #define LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_
 
+#include <algorithm>
 #include <cstdint>
 #include <map>
 #include <utility>
@@ -195,7 +196,9 @@ class DominatorTree {
   }
 
   // Returns true if the basic block id |a| is reachable by this tree.
-  bool ReachableFromRoots(uint32_t a) const;
+  bool ReachableFromRoots(uint32_t a) const {
+    return GetTreeNode(a) != nullptr;
+  }
 
   // Returns true if this tree is a post dominator tree.
   bool IsPostDominator() const { return postdominator_; }
@@ -267,11 +270,14 @@ class DominatorTree {
     return &node_iter->second;
   }
 
- private:
   // Adds the basic block |bb| to the tree structure if it doesn't already
   // exist.
   DominatorTreeNode* GetOrInsertNode(ir::BasicBlock* bb);
 
+  // Recomputes the DF numbering of the tree.
+  void ResetDFNumbering();
+
+ private:
   // Wrapper function which gets the list of pairs of each BasicBlocks to its
   // immediately  dominating BasicBlock and stores the result in the the edges
   // parameter.
index 0da62a8..17e0637 100644 (file)
@@ -59,6 +59,10 @@ class Function {
   inline void AddParameter(std::unique_ptr<Instruction> p);
   // Appends a basic block to this function.
   inline void AddBasicBlock(std::unique_ptr<BasicBlock> b);
+  // Appends a basic block to this function at the position |ip|.
+  inline void AddBasicBlock(std::unique_ptr<BasicBlock> b, iterator ip);
+  template <typename T>
+  inline void AddBasicBlocks(T begin, T end, iterator ip);
 
   // Saves the given function end instruction.
   inline void SetFunctionEnd(std::unique_ptr<Instruction> end_inst);
@@ -73,6 +77,11 @@ class Function {
   // Returns function's return type id
   inline uint32_t type_id() const { return def_inst_->type_id(); }
 
+  // Returns the basic block container for this function.
+  const std::vector<std::unique_ptr<BasicBlock>>* GetBlocks() const {
+    return &blocks_;
+  }
+
   // Returns the entry basic block for this function.
   const std::unique_ptr<BasicBlock>& entry() const { return blocks_.front(); }
 
@@ -123,7 +132,18 @@ inline void Function::AddParameter(std::unique_ptr<Instruction> p) {
 }
 
 inline void Function::AddBasicBlock(std::unique_ptr<BasicBlock> b) {
-  blocks_.emplace_back(std::move(b));
+  AddBasicBlock(std::move(b), end());
+}
+
+inline void Function::AddBasicBlock(std::unique_ptr<BasicBlock> b,
+                                    iterator ip) {
+  ip.InsertBefore(std::move(b));
+}
+
+template <typename T>
+inline void Function::AddBasicBlocks(T src_begin, T src_end, iterator ip) {
+  blocks_.insert(ip.Get(), std::make_move_iterator(src_begin),
+                 std::make_move_iterator(src_end));
 }
 
 inline void Function::SetFunctionEnd(std::unique_ptr<Instruction> end_inst) {
index a1a1d1e..aa722cb 100644 (file)
@@ -105,6 +105,44 @@ class InstructionBuilder {
     return AddInstruction(std::move(new_branch));
   }
 
+  // Creates a new switch instruction and the associated selection merge
+  // instruction if requested.
+  // The id |selector_id| is the id of the selector instruction, must be of
+  // type int.
+  // The id |default_id| is the id of the default basic block to branch to.
+  // The vector |targets| is the pair of literal/branch id.
+  // The id |merge_id| is the id of the merge basic block for the selection
+  // merge instruction. If |merge_id| equals kInvalidId then no selection merge
+  // instruction will be created.
+  // The value |selection_control| is the selection control flag for the
+  // selection merge instruction.
+  // Note that the user must make sure the final basic block is
+  // well formed.
+  ir::Instruction* AddSwitch(
+      uint32_t selector_id, uint32_t default_id,
+      const std::vector<std::pair<std::vector<uint32_t>, uint32_t>>& targets,
+      uint32_t merge_id = kInvalidId,
+      uint32_t selection_control = SpvSelectionControlMaskNone) {
+    if (merge_id != kInvalidId) {
+      AddSelectionMerge(merge_id, selection_control);
+    }
+    std::vector<ir::Operand> operands;
+    operands.emplace_back(
+        ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {selector_id}});
+    operands.emplace_back(
+        ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {default_id}});
+    for (auto& target : targets) {
+      operands.emplace_back(
+          ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
+                      target.first});
+      operands.emplace_back(ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
+                                        {target.second}});
+    }
+    std::unique_ptr<ir::Instruction> new_switch(
+        new ir::Instruction(GetContext(), SpvOpSwitch, 0, 0, operands));
+    return AddInstruction(std::move(new_switch));
+  }
+
   // Creates a phi instruction.
   // The id |type| must be the id of the phi instruction's type.
   // The vector |incomings| must be a sequence of pairs of <definition id,
@@ -215,6 +253,14 @@ class InstructionBuilder {
     return AddInstruction(std::move(new_inst));
   }
 
+  // Creates an unreachable instruction.
+  ir::Instruction* AddUnreachable() {
+    std::unique_ptr<ir::Instruction> select(
+        new ir::Instruction(GetContext(), SpvOpUnreachable, 0, 0,
+                            std::initializer_list<ir::Operand>{}));
+    return AddInstruction(std::move(select));
+  }
+
   // Inserts the new instruction before the insertion point.
   ir::Instruction* AddInstruction(std::unique_ptr<ir::Instruction>&& insn) {
     ir::Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn));
index 52a8d86..d43dfbe 100644 (file)
@@ -99,6 +99,14 @@ class UptrVectorIterator
   inline typename std::enable_if<!IsConstForMethod, UptrVectorIterator>::type
   Erase();
 
+  // Returns the underlying iterator.
+  UnderlyingIterator Get() const { return iterator_; }
+
+  // Returns a valid end iterator for the underlying container.
+  UptrVectorIterator End() const {
+    return UptrVectorIterator(container_, container_->end());
+  }
+
  private:
   UptrVector* container_;        // The container we are manipulating.
   UnderlyingIterator iterator_;  // The raw iterator from the container.
index 131363c..60d9468 100644 (file)
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "opt/loop_descriptor.h"
+#include <algorithm>
 #include <iostream>
 #include <type_traits>
 #include <utility>
@@ -245,11 +246,10 @@ bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
   assert(bb->GetParent() && "The basic block does not belong to a function");
   opt::DominatorAnalysis* dom_analysis =
       context_->GetDominatorAnalysis(bb->GetParent(), *context_->cfg());
-  if (!dom_analysis->Dominates(GetHeaderBlock(), bb)) return false;
+  if (dom_analysis->IsReachable(bb) &&
+      !dom_analysis->Dominates(GetHeaderBlock(), bb))
+    return false;
 
-  opt::PostDominatorAnalysis* postdom_analysis =
-      context_->GetPostDominatorAnalysis(bb->GetParent(), *context_->cfg());
-  if (!postdom_analysis->Dominates(GetMergeBlock(), bb)) return false;
   return true;
 }
 
@@ -378,6 +378,17 @@ void Loop::SetMergeBlock(BasicBlock* merge) {
   }
 }
 
+void Loop::SetPreHeaderBlock(BasicBlock* preheader) {
+  assert(!IsInsideLoop(preheader) && "The preheader block is in the loop");
+  assert(preheader->tail()->opcode() == SpvOpBranch &&
+         "The preheader block does not unconditionally branch to the header "
+         "block");
+  assert(preheader->tail()->GetSingleWordOperand(0) == GetHeaderBlock()->id() &&
+         "The preheader block does not unconditionally branch to the header "
+         "block");
+  loop_preheader_ = preheader;
+}
+
 void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
   ir::CFG* cfg = context_->cfg();
   exit_blocks->clear();
@@ -412,6 +423,43 @@ void Loop::GetMergingBlocks(
   }
 }
 
+namespace {
+
+static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) {
+  for (ir::Instruction& inst : *bb) {
+    if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst))
+      return false;
+  }
+
+  return true;
+}
+
+}  // namespace
+
+bool Loop::IsSafeToClone() const {
+  ir::CFG& cfg = *context_->cfg();
+
+  for (uint32_t bb_id : GetBlocks()) {
+    BasicBlock* bb = cfg.block(bb_id);
+    assert(bb);
+    if (!IsBasicBlockSafeToClone(context_, bb)) return false;
+  }
+
+  // Look at the merge construct.
+  if (GetHeaderBlock()->GetLoopMergeInst()) {
+    std::unordered_set<uint32_t> blocks;
+    GetMergingBlocks(&blocks);
+    blocks.erase(GetMergeBlock()->id());
+    for (uint32_t bb_id : blocks) {
+      BasicBlock* bb = cfg.block(bb_id);
+      assert(bb);
+      if (!IsBasicBlockSafeToClone(context_, bb)) return false;
+    }
+  }
+
+  return true;
+}
+
 bool Loop::IsLCSSA() const {
   ir::CFG* cfg = context_->cfg();
   opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
@@ -482,7 +530,8 @@ void Loop::ComputeLoopStructuredOrder(
     ordered_loop_blocks->push_back(loop_merge_);
 }
 
-LoopDescriptor::LoopDescriptor(const Function* f) : loops_() {
+LoopDescriptor::LoopDescriptor(const Function* f)
+    : loops_(), dummy_top_loop_(nullptr) {
   PopulateList(f);
 }
 
@@ -503,6 +552,17 @@ void LoopDescriptor::PopulateList(const Function* f) {
        ir::make_range(dom_tree.post_begin(), dom_tree.post_end())) {
     Instruction* merge_inst = node.bb_->GetLoopMergeInst();
     if (merge_inst) {
+      bool all_backedge_unreachable = true;
+      for (uint32_t pid : context->cfg()->preds(node.bb_->id())) {
+        if (dom_analysis->IsReachable(pid) &&
+            dom_analysis->Dominates(node.bb_->id(), pid)) {
+          all_backedge_unreachable = false;
+          break;
+        }
+      }
+      if (all_backedge_unreachable)
+        continue;  // ignore this one, we actually never branch back.
+
       // The id of the merge basic block of this loop.
       uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);
 
@@ -888,5 +948,48 @@ void LoopDescriptor::ClearLoops() {
   }
   loops_.clear();
 }
+
+// Adds a new loop nest to the descriptor set.
+ir::Loop* LoopDescriptor::AddLoopNest(std::unique_ptr<ir::Loop> new_loop) {
+  ir::Loop* loop = new_loop.release();
+  if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
+  // Iterate from inner to outer most loop, adding basic block to loop mapping
+  // as we go.
+  for (ir::Loop& current_loop :
+       make_range(iterator::begin(loop), iterator::end(nullptr))) {
+    loops_.push_back(&current_loop);
+    for (uint32_t bb_id : current_loop.GetBlocks())
+      basic_block_to_loop_.insert(std::make_pair(bb_id, &current_loop));
+  }
+
+  return loop;
+}
+
+void LoopDescriptor::RemoveLoop(ir::Loop* loop) {
+  ir::Loop* parent = loop->GetParent() ? loop->GetParent() : &dummy_top_loop_;
+  parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(),
+                                        parent->nested_loops_.end(), loop));
+  std::for_each(
+      loop->nested_loops_.begin(), loop->nested_loops_.end(),
+      [loop](ir::Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); });
+  parent->nested_loops_.insert(parent->nested_loops_.end(),
+                               loop->nested_loops_.begin(),
+                               loop->nested_loops_.end());
+  for (uint32_t bb_id : loop->GetBlocks()) {
+    ir::Loop* l = FindLoopForBasicBlock(bb_id);
+    if (l == loop) {
+      SetBasicBlockToLoop(bb_id, l->GetParent());
+    } else {
+      ForgetBasicBlock(bb_id);
+    }
+  }
+
+  LoopContainerType::iterator it =
+      std::find(loops_.begin(), loops_.end(), loop);
+  assert(it != loops_.end());
+  delete loop;
+  loops_.erase(it);
+}
+
 }  // namespace ir
 }  // namespace spvtools
index d0421d6..05acce2 100644 (file)
@@ -47,8 +47,8 @@ class Loop {
   using const_iterator = ChildrenList::const_iterator;
   using BasicBlockListTy = std::unordered_set<uint32_t>;
 
-  Loop()
-      : context_(nullptr),
+  explicit Loop(IRContext* context)
+      : context_(context),
         loop_header_(nullptr),
         loop_continue_(nullptr),
         loop_merge_(nullptr),
@@ -59,6 +59,8 @@ class Loop {
   Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header,
        BasicBlock* continue_target, BasicBlock* merge_target);
 
+  ~Loop() {}
+
   // Iterators over the immediate sub-loops.
   inline iterator begin() { return nested_loops_.begin(); }
   inline iterator end() { return nested_loops_.end(); }
@@ -115,6 +117,11 @@ class Loop {
 
   // Returns the loop pre-header.
   inline const BasicBlock* GetPreHeaderBlock() const { return loop_preheader_; }
+  // Sets |preheader| as the loop preheader block. A preheader block must have
+  // the following properties:
+  //  - |merge| must not be in the loop;
+  //  - have an unconditional branch to the loop header.
+  void SetPreHeaderBlock(BasicBlock* preheader);
 
   // Returns the loop pre-header, if there is no suitable preheader it will be
   // created.
@@ -190,7 +197,16 @@ class Loop {
   // Adds the Basic Block with |id| to this loop and its parents.
   void AddBasicBlock(uint32_t id) {
     for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
-      loop_basic_blocks_.insert(id);
+      loop->loop_basic_blocks_.insert(id);
+    }
+  }
+
+  // Removes the Basic Block id |bb_id| from this loop and its parents.
+  // It the user responsibility to make sure the removed block is not a merge,
+  // header or continue block.
+  void RemoveBasicBlock(uint32_t bb_id) {
+    for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
+      loop->loop_basic_blocks_.erase(bb_id);
     }
   }
 
@@ -264,6 +280,10 @@ class Loop {
     return true;
   }
 
+  // Checks if the loop contains any instruction that will prevent it from being
+  // cloned. If the loop is structured, the merge construct is also considered.
+  bool IsSafeToClone() const;
+
   // Sets the parent loop of this loop, that is, a loop which contains this loop
   // as a nested child loop.
   inline void SetParent(Loop* parent) { parent_ = parent; }
@@ -384,7 +404,7 @@ class LoopDescriptor {
   // Disable copy constructor, to avoid double-free on destruction.
   LoopDescriptor(const LoopDescriptor&) = delete;
   // Move constructor.
-  LoopDescriptor(LoopDescriptor&& other) {
+  LoopDescriptor(LoopDescriptor&& other) : dummy_top_loop_(nullptr) {
     // We need to take ownership of the Loop objects in the other
     // LoopDescriptor, to avoid double-free.
     loops_ = std::move(other.loops_);
@@ -446,6 +466,28 @@ class LoopDescriptor {
   // for addition with AddLoop or MarkLoopForRemoval.
   void PostModificationCleanup();
 
+  // Removes the basic block id |bb_id| from the block to loop mapping.
+  inline void ForgetBasicBlock(uint32_t bb_id) {
+    basic_block_to_loop_.erase(bb_id);
+  }
+
+  // Adds the loop |new_loop| and all its nested loops to the descriptor set.
+  // The object takes ownership of all the loops.
+  ir::Loop* AddLoopNest(std::unique_ptr<ir::Loop> new_loop);
+
+  // Remove the loop |loop|.
+  void RemoveLoop(ir::Loop* loop);
+
+  void SetAsTopLoop(ir::Loop* loop) {
+    assert(std::find(dummy_top_loop_.begin(), dummy_top_loop_.end(), loop) ==
+               dummy_top_loop_.end() &&
+           "already registered");
+    dummy_top_loop_.nested_loops_.push_back(loop);
+  }
+
+  Loop* GetDummyRootLoop() { return &dummy_top_loop_; }
+  const Loop* GetDummyRootLoop() const { return &dummy_top_loop_; }
+
  private:
   // TODO(dneto): This should be a vector of unique_ptr.  But VisualStudio 2013
   // is unable to compile it.
diff --git a/source/opt/loop_unswitch_pass.cpp b/source/opt/loop_unswitch_pass.cpp
new file mode 100644 (file)
index 0000000..53f6299
--- /dev/null
@@ -0,0 +1,908 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "loop_unswitch_pass.h"
+
+#include <functional>
+#include <list>
+#include <memory>
+#include <type_traits>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "basic_block.h"
+#include "dominator_tree.h"
+#include "fold.h"
+#include "function.h"
+#include "instruction.h"
+#include "ir_builder.h"
+#include "ir_context.h"
+#include "loop_descriptor.h"
+
+#include "loop_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+static const uint32_t kTypePointerStorageClassInIdx = 0;
+static const uint32_t kBranchCondTrueLabIdInIdx = 1;
+static const uint32_t kBranchCondFalseLabIdInIdx = 2;
+
+}  // anonymous namespace
+
+namespace {
+
+// This class handle the unswitch procedure for a given loop.
+// The unswitch will not happen if:
+//  - The loop has any instruction that will prevent it;
+//  - The loop invariant condition is not uniform.
+class LoopUnswitch {
+ public:
+  LoopUnswitch(ir::IRContext* context, ir::Function* function, ir::Loop* loop,
+               ir::LoopDescriptor* loop_desc)
+      : function_(function),
+        loop_(loop),
+        loop_desc_(*loop_desc),
+        context_(context),
+        switch_block_(nullptr) {}
+
+  // Returns true if the loop can be unswitched.
+  // Can be unswitch if:
+  //  - The loop has no instructions that prevents it (such as barrier);
+  //  - The loop has one conditional branch or switch that do not depends on the
+  //  loop;
+  //  - The loop invariant condition is uniform;
+  bool CanUnswitchLoop() {
+    if (switch_block_) return true;
+    if (loop_->IsSafeToClone()) return false;
+
+    ir::CFG& cfg = *context_->cfg();
+
+    for (uint32_t bb_id : loop_->GetBlocks()) {
+      ir::BasicBlock* bb = cfg.block(bb_id);
+      if (bb->terminator()->IsBranch() &&
+          bb->terminator()->opcode() != SpvOpBranch) {
+        if (IsConditionLoopInvariant(bb->terminator())) {
+          switch_block_ = bb;
+          break;
+        }
+      }
+    }
+
+    return switch_block_;
+  }
+
+  // Return the iterator to the basic block |bb|.
+  ir::Function::iterator FindBasicBlockPosition(ir::BasicBlock* bb_to_find) {
+    ir::Function::iterator it = std::find_if(
+        function_->begin(), function_->end(),
+        [bb_to_find](const ir::BasicBlock& bb) { return bb_to_find == &bb; });
+    assert(it != function_->end() && "Basic Block not found");
+    return it;
+  }
+
+  // Creates a new basic block and insert it into the function |fn| at the
+  // position |ip|. This function preserves the def/use and instr to block
+  // managers.
+  ir::BasicBlock* CreateBasicBlock(ir::Function::iterator ip) {
+    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+    ir::BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<ir::BasicBlock>(
+        new ir::BasicBlock(std::unique_ptr<ir::Instruction>(new ir::Instruction(
+            context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
+    bb->SetParent(function_);
+    def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
+    context_->set_instr_block(bb->GetLabelInst(), bb);
+
+    return bb;
+  }
+
+  // Unswitches |loop_|.
+  void PerformUnswitch() {
+    assert(CanUnswitchLoop() &&
+           "Cannot unswitch if there is not constant condition");
+    assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
+    assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");
+
+    ir::CFG& cfg = *context_->cfg();
+    DominatorTree* dom_tree =
+        &context_->GetDominatorAnalysis(function_, *context_->cfg())
+             ->GetDomTree();
+    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+    LoopUtils loop_utils(context_, loop_);
+
+    //////////////////////////////////////////////////////////////////////////////
+    // Step 1: Create the if merge block for structured modules.
+    //    To do so, the |loop_| merge block will become the if's one and we
+    //    create a merge for the loop. This will limit the amount of duplicated
+    //    code the structured control flow imposes.
+    //    For non structured program, the new loop will be connected to
+    //    the old loop's exit blocks.
+    //////////////////////////////////////////////////////////////////////////////
+
+    // Get the merge block if it exists.
+    ir::BasicBlock* if_merge_block = loop_->GetMergeBlock();
+    // The merge block is only created if the loop has a unique exit block. We
+    // have this guarantee for structured loops, for compute loop it will
+    // trivially help maintain both a structured-like form and LCSAA.
+    ir::BasicBlock* loop_merge_block =
+        if_merge_block
+            ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
+            : nullptr;
+    if (loop_merge_block) {
+      // Add the instruction and update managers.
+      opt::InstructionBuilder builder(
+          context_, loop_merge_block,
+          ir::IRContext::kAnalysisDefUse |
+              ir::IRContext::kAnalysisInstrToBlockMapping);
+      builder.AddBranch(if_merge_block->id());
+      builder.SetInsertPoint(&*loop_merge_block->begin());
+      cfg.RegisterBlock(loop_merge_block);
+      def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
+      // Update CFG.
+      if_merge_block->ForEachPhiInst(
+          [loop_merge_block, &builder, this](ir::Instruction* phi) {
+            ir::Instruction* cloned = phi->Clone(context_);
+            builder.AddInstruction(std::unique_ptr<ir::Instruction>(cloned));
+            phi->SetInOperand(0, {cloned->result_id()});
+            phi->SetInOperand(1, {loop_merge_block->id()});
+            for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
+              phi->RemoveInOperand(j);
+          });
+      // Copy the predecessor list (will get invalidated otherwise).
+      std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
+      for (uint32_t pid : preds) {
+        if (pid == loop_merge_block->id()) continue;
+        ir::BasicBlock* p_bb = cfg.block(pid);
+        p_bb->ForEachSuccessorLabel(
+            [if_merge_block, loop_merge_block](uint32_t* id) {
+              if (*id == if_merge_block->id()) *id = loop_merge_block->id();
+            });
+        cfg.AddEdge(pid, loop_merge_block->id());
+      }
+      cfg.RemoveNonExistingEdges(if_merge_block->id());
+      // Update loop descriptor.
+      if (ir::Loop* ploop = loop_->GetParent()) {
+        ploop->AddBasicBlock(loop_merge_block);
+        loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
+      }
+
+      // Update the dominator tree.
+      DominatorTreeNode* loop_merge_dtn =
+          dom_tree->GetOrInsertNode(loop_merge_block);
+      DominatorTreeNode* if_merge_block_dtn =
+          dom_tree->GetOrInsertNode(if_merge_block);
+      loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
+      loop_merge_dtn->children_.push_back(if_merge_block_dtn);
+      loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
+      if_merge_block_dtn->parent_->children_.erase(std::find(
+          if_merge_block_dtn->parent_->children_.begin(),
+          if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));
+
+      loop_->SetMergeBlock(loop_merge_block);
+    }
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Step 2: Build a new preheader for |loop_|, use the old one
+    //         for the constant branch.
+    ////////////////////////////////////////////////////////////////////////////
+
+    ir::BasicBlock* if_block = loop_->GetPreHeaderBlock();
+    // If this preheader is the parent loop header,
+    // we need to create a dedicated block for the if.
+    ir::BasicBlock* loop_pre_header =
+        CreateBasicBlock(++FindBasicBlockPosition(if_block));
+    opt::InstructionBuilder(context_, loop_pre_header,
+                            ir::IRContext::kAnalysisDefUse |
+                                ir::IRContext::kAnalysisInstrToBlockMapping)
+        .AddBranch(loop_->GetHeaderBlock()->id());
+
+    if_block->tail()->SetInOperand(0, {loop_pre_header->id()});
+
+    // Update loop descriptor.
+    if (ir::Loop* ploop = loop_desc_[if_block]) {
+      ploop->AddBasicBlock(loop_pre_header);
+      loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
+    }
+
+    // Update the CFG.
+    cfg.RegisterBlock(loop_pre_header);
+    def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
+    cfg.AddEdge(if_block->id(), loop_pre_header->id());
+    cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());
+
+    loop_->GetHeaderBlock()->ForEachPhiInst(
+        [loop_pre_header, if_block](ir::Instruction* phi) {
+          phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
+            if (*id == if_block->id()) {
+              *id = loop_pre_header->id();
+            }
+          });
+        });
+    loop_->SetPreHeaderBlock(loop_pre_header);
+
+    // Update the dominator tree.
+    DominatorTreeNode* loop_pre_header_dtn =
+        dom_tree->GetOrInsertNode(loop_pre_header);
+    DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
+    loop_pre_header_dtn->parent_ = if_block_dtn;
+    assert(
+        if_block_dtn->children_.size() == 1 &&
+        "A loop preheader should only have the header block as a child in the "
+        "dominator tree");
+    loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
+    if_block_dtn->children_.clear();
+    if_block_dtn->children_.push_back(loop_pre_header_dtn);
+
+    // Make domination queries valid.
+    dom_tree->ResetDFNumbering();
+
+    // Compute an ordered list of basic block to clone: loop blocks + pre-header
+    // + merge block.
+    loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);
+
+    /////////////////////////////
+    // Do the actual unswitch: //
+    //   - Clone the loop      //
+    //   - Connect exits       //
+    //   - Specialize the loop //
+    /////////////////////////////
+
+    ir::Instruction* iv_condition = &*switch_block_->tail();
+    SpvOp iv_opcode = iv_condition->opcode();
+    ir::Instruction* condition =
+        def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);
+
+    analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
+    const analysis::Type* cond_type =
+        context_->get_type_mgr()->GetType(condition->type_id());
+
+    // Build the list of value for which we need to clone and specialize the
+    // loop.
+    std::vector<std::pair<ir::Instruction*, ir::BasicBlock*>> constant_branch;
+    // Special case for the original loop
+    ir::Instruction* original_loop_constant_value;
+    ir::BasicBlock* original_loop_target;
+    if (iv_opcode == SpvOpBranchConditional) {
+      constant_branch.emplace_back(
+          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
+          nullptr);
+      original_loop_constant_value =
+          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
+    } else {
+      // We are looking to take the default branch, so we can't provide a
+      // specific value.
+      original_loop_constant_value = nullptr;
+      for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
+        constant_branch.emplace_back(
+            cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
+                cond_type, iv_condition->GetInOperand(i).words)),
+            nullptr);
+      }
+    }
+
+    // Get the loop landing pads.
+    std::unordered_set<uint32_t> if_merging_blocks;
+    std::function<bool(uint32_t)> is_from_original_loop;
+    if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
+      if_merging_blocks.insert(if_merge_block->id());
+      is_from_original_loop = [this](uint32_t id) {
+        return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
+      };
+    } else {
+      loop_->GetExitBlocks(&if_merging_blocks);
+      is_from_original_loop = [this](uint32_t id) {
+        return loop_->IsInsideLoop(id);
+      };
+    }
+
+    for (auto& specialisation_pair : constant_branch) {
+      ir::Instruction* specialisation_value = specialisation_pair.first;
+      //////////////////////////////////////////////////////////
+      // Step 3: Duplicate |loop_|.
+      //////////////////////////////////////////////////////////
+      LoopUtils::LoopCloningResult clone_result;
+
+      ir::Loop* cloned_loop =
+          loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
+      specialisation_pair.second = cloned_loop->GetPreHeaderBlock();
+
+      ////////////////////////////////////
+      // Step 4: Specialize the loop.   //
+      ////////////////////////////////////
+
+      {
+        std::unordered_set<uint32_t> dead_blocks;
+        std::unordered_set<uint32_t> unreachable_merges;
+        SimplifyLoop(
+            ir::make_range(
+                ir::UptrVectorIterator<ir::BasicBlock>(
+                    &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()),
+                ir::UptrVectorIterator<ir::BasicBlock>(
+                    &clone_result.cloned_bb_, clone_result.cloned_bb_.end())),
+            cloned_loop, condition, specialisation_value, &dead_blocks);
+
+        // We tagged dead blocks, create the loop before we invalidate any basic
+        // block.
+        cloned_loop =
+            CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges);
+        CleanUpCFG(
+            ir::UptrVectorIterator<ir::BasicBlock>(
+                &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()),
+            dead_blocks, unreachable_merges);
+
+        ///////////////////////////////////////////////////////////
+        // Step 5: Connect convergent edges to the landing pads. //
+        ///////////////////////////////////////////////////////////
+
+        for (uint32_t merge_bb_id : if_merging_blocks) {
+          ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id);
+          // We are in LCSSA so we only care about phi instructions.
+          merge->ForEachPhiInst([is_from_original_loop, &dead_blocks,
+                                 &clone_result](ir::Instruction* phi) {
+            uint32_t num_in_operands = phi->NumInOperands();
+            for (uint32_t i = 0; i < num_in_operands; i += 2) {
+              uint32_t pred = phi->GetSingleWordInOperand(i + 1);
+              if (is_from_original_loop(pred)) {
+                pred = clone_result.value_map_.at(pred);
+                if (!dead_blocks.count(pred)) {
+                  uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
+                  // Not all the incoming value are coming from the loop.
+                  ValueMapTy::iterator new_value =
+                      clone_result.value_map_.find(incoming_value_id);
+                  if (new_value != clone_result.value_map_.end()) {
+                    incoming_value_id = new_value->second;
+                  }
+                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
+                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
+                }
+              }
+            }
+          });
+        }
+      }
+      function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
+                                clone_result.cloned_bb_.end(),
+                                ++FindBasicBlockPosition(if_block));
+    }
+
+    // Same as above but specialize the existing loop
+    {
+      std::unordered_set<uint32_t> dead_blocks;
+      std::unordered_set<uint32_t> unreachable_merges;
+      SimplifyLoop(ir::make_range(function_->begin(), function_->end()), loop_,
+                   condition, original_loop_constant_value, &dead_blocks);
+
+      for (uint32_t merge_bb_id : if_merging_blocks) {
+        ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id);
+        // LCSSA, so we only care about phi instructions.
+        // If we the phi is reduced to a single incoming branch, do not
+        // propagate it to preserve LCSSA.
+        PatchPhis(merge, dead_blocks, true);
+      }
+      if (if_merge_block) {
+        bool has_live_pred = false;
+        for (uint32_t pid : cfg.preds(if_merge_block->id())) {
+          if (!dead_blocks.count(pid)) {
+            has_live_pred = true;
+            break;
+          }
+        }
+        if (!has_live_pred) unreachable_merges.insert(if_merge_block->id());
+      }
+      original_loop_target = loop_->GetPreHeaderBlock();
+      // We tagged dead blocks, prune the loop descriptor from any dead loops.
+      // After this call, |loop_| can be nullptr (i.e. the unswitch killed this
+      // loop).
+      loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges);
+
+      CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges);
+    }
+
+    /////////////////////////////////////
+    // Finally: connect the new loops. //
+    /////////////////////////////////////
+
+    // Delete the old jump
+    context_->KillInst(&*if_block->tail());
+    opt::InstructionBuilder builder(context_, if_block);
+    if (iv_opcode == SpvOpBranchConditional) {
+      assert(constant_branch.size() == 1);
+      builder.AddConditionalBranch(
+          condition->result_id(), original_loop_target->id(),
+          constant_branch[0].second->id(),
+          if_merge_block ? if_merge_block->id() : kInvalidId);
+    } else {
+      std::vector<std::pair<std::vector<uint32_t>, uint32_t>> targets;
+      for (auto& t : constant_branch) {
+        targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
+      }
+
+      builder.AddSwitch(condition->result_id(), original_loop_target->id(),
+                        targets,
+                        if_merge_block ? if_merge_block->id() : kInvalidId);
+    }
+
+    switch_block_ = nullptr;
+    ordered_loop_blocks_.clear();
+
+    context_->InvalidateAnalysesExceptFor(
+        ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+  }
+
+  // Returns true if the unswitch killed the original |loop_|.
+  bool WasLoopKilled() const { return loop_ == nullptr; }
+
+ private:
+  using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
+  using BlockMapTy = std::unordered_map<uint32_t, ir::BasicBlock*>;
+
+  ir::Function* function_;
+  ir::Loop* loop_;
+  ir::LoopDescriptor& loop_desc_;
+  ir::IRContext* context_;
+
+  ir::BasicBlock* switch_block_;
+  // Map between instructions and if they are dynamically uniform.
+  std::unordered_map<uint32_t, bool> dynamically_uniform_;
+  // The loop basic blocks in structured order.
+  std::vector<ir::BasicBlock*> ordered_loop_blocks_;
+
+  // Returns the next usable id for the context.
+  uint32_t TakeNextId() { return context_->TakeNextId(); }
+
+  // Patches |bb|'s phi instruction by removing incoming value from unexisting
+  // or tagged as dead branches.
+  void PatchPhis(ir::BasicBlock* bb,
+                 const std::unordered_set<uint32_t>& dead_blocks,
+                 bool preserve_phi) {
+    ir::CFG& cfg = *context_->cfg();
+
+    std::vector<ir::Instruction*> phi_to_kill;
+    const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id());
+    auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) {
+      return dead_blocks.count(id) ||
+             std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
+    };
+    bb->ForEachPhiInst([&phi_to_kill, &is_branch_dead, preserve_phi,
+                        this](ir::Instruction* insn) {
+      uint32_t i = 0;
+      while (i < insn->NumInOperands()) {
+        uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
+        if (is_branch_dead(incoming_id)) {
+          // Remove the incoming block id operand.
+          insn->RemoveInOperand(i + 1);
+          // Remove the definition id operand.
+          insn->RemoveInOperand(i);
+          continue;
+        }
+        i += 2;
+      }
+      // If there is only 1 remaining edge, propagate the value and
+      // kill the instruction.
+      if (insn->NumInOperands() == 2 && !preserve_phi) {
+        phi_to_kill.push_back(insn);
+        context_->ReplaceAllUsesWith(insn->result_id(),
+                                     insn->GetSingleWordInOperand(0));
+      }
+    });
+    for (ir::Instruction* insn : phi_to_kill) {
+      context_->KillInst(insn);
+    }
+  }
+
+  // Removes any block that is tagged as dead, if the block is in
+  // |unreachable_merges| then all block's instructions are replaced by a
+  // OpUnreachable.
+  void CleanUpCFG(ir::UptrVectorIterator<ir::BasicBlock> bb_it,
+                  const std::unordered_set<uint32_t>& dead_blocks,
+                  const std::unordered_set<uint32_t>& unreachable_merges) {
+    ir::CFG& cfg = *context_->cfg();
+
+    while (bb_it != bb_it.End()) {
+      ir::BasicBlock& bb = *bb_it;
+
+      if (unreachable_merges.count(bb.id())) {
+        if (bb.begin() != bb.tail() ||
+            bb.terminator()->opcode() != SpvOpUnreachable) {
+          // Make unreachable, but leave the label.
+          bb.KillAllInsts(false);
+          opt::InstructionBuilder(context_, &bb).AddUnreachable();
+          cfg.RemoveNonExistingEdges(bb.id());
+        }
+        ++bb_it;
+      } else if (dead_blocks.count(bb.id())) {
+        cfg.ForgetBlock(&bb);
+        // Kill this block.
+        bb.KillAllInsts(true);
+        bb_it = bb_it.Erase();
+      } else {
+        cfg.RemoveNonExistingEdges(bb.id());
+        ++bb_it;
+      }
+    }
+  }
+
+  // Return true if |c_inst| is a Boolean constant and set |cond_val| with the
+  // value that |c_inst|
+  bool GetConstCondition(const ir::Instruction* c_inst, bool* cond_val) {
+    bool cond_is_const;
+    switch (c_inst->opcode()) {
+      case SpvOpConstantFalse: {
+        *cond_val = false;
+        cond_is_const = true;
+      } break;
+      case SpvOpConstantTrue: {
+        *cond_val = true;
+        cond_is_const = true;
+      } break;
+      default: { cond_is_const = false; } break;
+    }
+    return cond_is_const;
+  }
+
+  // Simplifies |loop| assuming the instruction |to_version_insn| takes the
+  // value |cst_value|. |block_range| is an iterator range returning the loop
+  // basic blocks in a structured order (dominator first).
+  // The function will ignore basic blocks returned by |block_range| if they
+  // does not belong to the loop.
+  // The set |dead_blocks| will contain all the dead basic blocks.
+  //
+  // Requirements:
+  //   - |loop| must be in the LCSSA form;
+  //   - |cst_value| must be constant or null (to represent the default target
+  //   of an OpSwitch).
+  void SimplifyLoop(
+      ir::IteratorRange<ir::UptrVectorIterator<ir::BasicBlock>> block_range,
+      ir::Loop* loop, ir::Instruction* to_version_insn,
+      ir::Instruction* cst_value, std::unordered_set<uint32_t>* dead_blocks) {
+    ir::CFG& cfg = *context_->cfg();
+    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+    std::function<bool(uint32_t)> ignore_node;
+    ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };
+
+    std::vector<std::pair<ir::Instruction*, uint32_t>> use_list;
+    def_use_mgr->ForEachUse(
+        to_version_insn, [&use_list, &ignore_node, this](
+                             ir::Instruction* inst, uint32_t operand_index) {
+          ir::BasicBlock* bb = context_->get_instr_block(inst);
+
+          if (!bb || ignore_node(bb->id())) {
+            // Out of the loop, the specialization does not apply any more.
+            return;
+          }
+          use_list.emplace_back(inst, operand_index);
+        });
+
+    // First pass: inject the specialized value into the loop (and only the
+    // loop).
+    for (auto use : use_list) {
+      ir::Instruction* inst = use.first;
+      uint32_t operand_index = use.second;
+      ir::BasicBlock* bb = context_->get_instr_block(inst);
+
+      // If it is not a branch, simply inject the value.
+      if (!inst->IsBranch()) {
+        // To also handle switch, cst_value can be nullptr: this case
+        // means that we are looking to branch to the default target of
+        // the switch. We don't actually know its value so we don't touch
+        // it if it not a switch.
+        if (cst_value) {
+          inst->SetOperand(operand_index, {cst_value->result_id()});
+          def_use_mgr->AnalyzeInstUse(inst);
+        }
+      }
+
+      // The user is a branch, kill dead branches.
+      uint32_t live_target = 0;
+      std::unordered_set<uint32_t> dead_branches;
+      switch (inst->opcode()) {
+        case SpvOpBranchConditional: {
+          assert(cst_value && "No constant value to specialize !");
+          bool branch_cond = false;
+          if (GetConstCondition(cst_value, &branch_cond)) {
+            uint32_t true_label =
+                inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
+            uint32_t false_label =
+                inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
+            live_target = branch_cond ? true_label : false_label;
+            uint32_t dead_target = !branch_cond ? true_label : false_label;
+            cfg.RemoveEdge(bb->id(), dead_target);
+          }
+          break;
+        }
+        case SpvOpSwitch: {
+          live_target = inst->GetSingleWordInOperand(1);
+          if (cst_value) {
+            if (!cst_value->IsConstant()) break;
+            const ir::Operand& cst = cst_value->GetInOperand(0);
+            for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) {
+              const ir::Operand& literal = inst->GetInOperand(i);
+              if (literal == cst) {
+                live_target = inst->GetSingleWordInOperand(i + 1);
+                break;
+              }
+            }
+          }
+          for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
+            uint32_t id = inst->GetSingleWordInOperand(i);
+            if (id != live_target) {
+              cfg.RemoveEdge(bb->id(), id);
+            }
+          }
+        }
+        default:
+          break;
+      }
+      if (live_target != 0) {
+        // Check for the presence of the merge block.
+        if (ir::Instruction* merge = bb->GetMergeInst())
+          context_->KillInst(merge);
+        context_->KillInst(&*bb->tail());
+        opt::InstructionBuilder builder(
+            context_, bb,
+            ir::IRContext::kAnalysisDefUse |
+                ir::IRContext::kAnalysisInstrToBlockMapping);
+        builder.AddBranch(live_target);
+      }
+    }
+
+    // Go through the loop basic block and tag all blocks that are obviously
+    // dead.
+    std::unordered_set<uint32_t> visited;
+    for (ir::BasicBlock& bb : block_range) {
+      if (ignore_node(bb.id())) continue;
+      visited.insert(bb.id());
+
+      // Check if this block is dead, if so tag it as dead otherwise patch phi
+      // instructions.
+      bool has_live_pred = false;
+      for (uint32_t pid : cfg.preds(bb.id())) {
+        if (!dead_blocks->count(pid)) {
+          has_live_pred = true;
+          break;
+        }
+      }
+      if (!has_live_pred) {
+        dead_blocks->insert(bb.id());
+        const ir::BasicBlock& cbb = bb;
+        // Patch the phis for any back-edge.
+        cbb.ForEachSuccessorLabel(
+            [dead_blocks, &visited, &cfg, this](uint32_t id) {
+              if (!visited.count(id) || dead_blocks->count(id)) return;
+              ir::BasicBlock* succ = cfg.block(id);
+              PatchPhis(succ, *dead_blocks, false);
+            });
+        continue;
+      }
+      // Update the phi instructions, some incoming branch have/will disappear.
+      PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false);
+    }
+  }
+
+  // Returns true if the header is not reachable or tagged as dead or if we
+  // never loop back.
+  bool IsLoopDead(ir::BasicBlock* header, ir::BasicBlock* latch,
+                  const std::unordered_set<uint32_t>& dead_blocks) {
+    if (!header || dead_blocks.count(header->id())) return true;
+    if (!latch || dead_blocks.count(latch->id())) return true;
+    for (uint32_t pid : context_->cfg()->preds(header->id())) {
+      if (!dead_blocks.count(pid)) {
+        // Seems reachable.
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Cleans the loop nest under |loop| and reflect changes to the loop
+  // descriptor. This will kill all descriptors that represent dead loops.
+  // If |loop_| is killed, it will be set to nullptr.
+  // Any merge blocks that become unreachable will be added to
+  // |unreachable_merges|.
+  // The function returns the pointer to |loop| or nullptr if the loop was
+  // killed.
+  ir::Loop* CleanLoopNest(ir::Loop* loop,
+                          const std::unordered_set<uint32_t>& dead_blocks,
+                          std::unordered_set<uint32_t>* unreachable_merges) {
+    // This represent the pair of dead loop and nearest alive parent (nullptr if
+    // no parent).
+    std::unordered_map<ir::Loop*, ir::Loop*> dead_loops;
+    auto get_parent = [&dead_loops](ir::Loop* l) -> ir::Loop* {
+      std::unordered_map<ir::Loop*, ir::Loop*>::iterator it =
+          dead_loops.find(l);
+      if (it != dead_loops.end()) return it->second;
+      return nullptr;
+    };
+
+    bool is_main_loop_dead =
+        IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks);
+    if (is_main_loop_dead) {
+      if (ir::Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) {
+        context_->KillInst(merge);
+      }
+      dead_loops[loop] = loop->GetParent();
+    } else
+      dead_loops[loop] = loop;
+    // For each loop, check if we killed it. If we did, find a suitable parent
+    // for its children.
+    for (ir::Loop& sub_loop :
+         ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop),
+                        opt::TreeDFIterator<ir::Loop>())) {
+      if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(),
+                     dead_blocks)) {
+        if (ir::Instruction* merge =
+                sub_loop.GetHeaderBlock()->GetLoopMergeInst()) {
+          context_->KillInst(merge);
+        }
+        dead_loops[&sub_loop] = get_parent(&sub_loop);
+      } else {
+        // The loop is alive, check if its merge block is dead, if it is, tag it
+        // as required.
+        if (sub_loop.GetMergeBlock()) {
+          uint32_t merge_id = sub_loop.GetMergeBlock()->id();
+          if (dead_blocks.count(merge_id)) {
+            unreachable_merges->insert(sub_loop.GetMergeBlock()->id());
+          }
+        }
+      }
+    }
+    if (!is_main_loop_dead) dead_loops.erase(loop);
+
+    // Remove dead blocks from live loops.
+    for (uint32_t bb_id : dead_blocks) {
+      ir::Loop* l = loop_desc_[bb_id];
+      if (l) {
+        l->RemoveBasicBlock(bb_id);
+        loop_desc_.ForgetBasicBlock(bb_id);
+      }
+    }
+
+    std::for_each(
+        dead_loops.begin(), dead_loops.end(),
+        [&loop, this](
+            std::unordered_map<ir::Loop*, ir::Loop*>::iterator::reference it) {
+          if (it.first == loop) loop = nullptr;
+          loop_desc_.RemoveLoop(it.first);
+        });
+
+    return loop;
+  }
+
+  // Returns true if |var| is dynamically uniform.
+  // Note: this is currently approximated as uniform.
+  bool IsDynamicallyUniform(ir::Instruction* var, const ir::BasicBlock* entry,
+                            const DominatorTree& post_dom_tree) {
+    assert(post_dom_tree.IsPostDominator());
+    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+    auto it = dynamically_uniform_.find(var->result_id());
+
+    if (it != dynamically_uniform_.end()) return it->second;
+
+    analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();
+
+    bool& is_uniform = dynamically_uniform_[var->result_id()];
+    is_uniform = false;
+
+    dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform,
+                                 [&is_uniform](const ir::Instruction&) {
+                                   is_uniform = true;
+                                   return false;
+                                 });
+    if (is_uniform) {
+      return is_uniform;
+    }
+
+    ir::BasicBlock* parent = context_->get_instr_block(var);
+    if (!parent) {
+      return is_uniform = true;
+    }
+
+    if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
+      return is_uniform = false;
+    }
+    if (var->opcode() == SpvOpLoad) {
+      const uint32_t PtrTypeId =
+          def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
+      const ir::Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
+      uint32_t storage_class =
+          PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx);
+      if (storage_class != SpvStorageClassUniform &&
+          storage_class != SpvStorageClassUniformConstant) {
+        return is_uniform = false;
+      }
+    } else {
+      if (!context_->IsCombinatorInstruction(var)) {
+        return is_uniform = false;
+      }
+    }
+
+    return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
+                                            this](const uint32_t* id) {
+      return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
+                                  entry, post_dom_tree);
+    });
+  }
+
+  // Returns true if |insn| is constant and dynamically uniform within the loop.
+  bool IsConditionLoopInvariant(ir::Instruction* insn) {
+    assert(insn->IsBranch());
+    assert(insn->opcode() != SpvOpBranch);
+    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+    ir::Instruction* condition =
+        def_use_mgr->GetDef(insn->GetOperand(0).words[0]);
+    return !loop_->IsInsideLoop(condition) &&
+           IsDynamicallyUniform(
+               condition, function_->entry().get(),
+               context_->GetPostDominatorAnalysis(function_, *context_->cfg())
+                   ->GetDomTree());
+  }
+};
+
+}  // namespace
+
+Pass::Status LoopUnswitchPass::Process(ir::IRContext* c) {
+  InitializeProcessing(c);
+
+  bool modified = false;
+  ir::Module* module = c->module();
+
+  // Process each function in the module
+  for (ir::Function& f : *module) {
+    modified |= ProcessFunction(&f);
+  }
+
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+bool LoopUnswitchPass::ProcessFunction(ir::Function* f) {
+  bool modified = false;
+  std::unordered_set<ir::Loop*> processed_loop;
+
+  ir::LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);
+
+  bool loop_changed = true;
+  while (loop_changed) {
+    loop_changed = false;
+    for (ir::Loop& loop :
+         ir::make_range(++opt::TreeDFIterator<ir::Loop>(
+                            loop_descriptor.GetDummyRootLoop()),
+                        opt::TreeDFIterator<ir::Loop>())) {
+      if (processed_loop.count(&loop)) continue;
+      processed_loop.insert(&loop);
+
+      LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor);
+      while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) {
+        if (!loop.IsLCSSA()) {
+          LoopUtils(context(), &loop).MakeLoopClosedSSA();
+        }
+        modified = true;
+        loop_changed = true;
+        unswitcher.PerformUnswitch();
+      }
+      if (loop_changed) break;
+    }
+  }
+
+  return modified;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/loop_unswitch_pass.h b/source/opt/loop_unswitch_pass.h
new file mode 100644 (file)
index 0000000..dbe5814
--- /dev/null
@@ -0,0 +1,43 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_
+#define LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_
+
+#include "opt/loop_descriptor.h"
+#include "opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Implements the loop unswitch optimization.
+// The loop unswitch hoists invariant "if" statements if the conditions are
+// constant within the loop and clones the loop for each branch.
+class LoopUnswitchPass : public Pass {
+ public:
+  const char* name() const override { return "loop-unswitch"; }
+
+  // Processes the given |module|. Returns Status::Failure if errors occur when
+  // processing. Returns the corresponding Status::Success if processing is
+  // succesful to indicate whether changes have been made to the modue.
+  Pass::Status Process(ir::IRContext* context) override;
+
+ private:
+  bool ProcessFunction(ir::Function* f);
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // !LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_
index 6c2a15f..8532679 100644 (file)
@@ -18,6 +18,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "cfa.h"
 #include "opt/cfg.h"
 #include "opt/ir_builder.h"
 #include "opt/ir_context.h"
@@ -481,5 +482,114 @@ void LoopUtils::MakeLoopClosedSSA() {
       ir::IRContext::Analysis::kAnalysisLoopAnalysis);
 }
 
+ir::Loop* LoopUtils::CloneLoop(
+    LoopCloningResult* cloning_result,
+    const std::vector<ir::BasicBlock*>& ordered_loop_blocks) const {
+  analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+  std::unique_ptr<ir::Loop> new_loop = MakeUnique<ir::Loop>(context_);
+  if (loop_->HasParent()) new_loop->SetParent(loop_->GetParent());
+
+  ir::CFG& cfg = *context_->cfg();
+
+  // Clone and place blocks in a SPIR-V compliant order (dominators first).
+  for (ir::BasicBlock* old_bb : ordered_loop_blocks) {
+    // For each basic block in the loop, we clone it and register the mapping
+    // between old and new ids.
+    ir::BasicBlock* new_bb = old_bb->Clone(context_);
+    new_bb->SetParent(&function_);
+    new_bb->GetLabelInst()->SetResultId(context_->TakeNextId());
+    def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst());
+    context_->set_instr_block(new_bb->GetLabelInst(), new_bb);
+    cloning_result->cloned_bb_.emplace_back(new_bb);
+
+    cloning_result->old_to_new_bb_[old_bb->id()] = new_bb;
+    cloning_result->new_to_old_bb_[new_bb->id()] = old_bb;
+    cloning_result->value_map_[old_bb->id()] = new_bb->id();
+
+    if (loop_->IsInsideLoop(old_bb)) new_loop->AddBasicBlock(new_bb);
+
+    for (auto& inst : *new_bb) {
+      if (inst.HasResultId()) {
+        uint32_t old_result_id = inst.result_id();
+        inst.SetResultId(context_->TakeNextId());
+        cloning_result->value_map_[old_result_id] = inst.result_id();
+
+        // Only look at the defs for now, uses are not updated yet.
+        def_use_mgr->AnalyzeInstDef(&inst);
+      }
+    }
+  }
+
+  // All instructions (including all labels) have been cloned,
+  // remap instruction operands id with the new ones.
+  for (std::unique_ptr<ir::BasicBlock>& bb_ref : cloning_result->cloned_bb_) {
+    ir::BasicBlock* bb = bb_ref.get();
+
+    for (ir::Instruction& insn : *bb) {
+      insn.ForEachInId([cloning_result](uint32_t* old_id) {
+        // If the operand is defined in the loop, remap the id.
+        auto id_it = cloning_result->value_map_.find(*old_id);
+        if (id_it != cloning_result->value_map_.end()) {
+          *old_id = id_it->second;
+        }
+      });
+      // Only look at what the instruction uses. All defs are register, so all
+      // should be fine now.
+      def_use_mgr->AnalyzeInstUse(&insn);
+      context_->set_instr_block(&insn, bb);
+    }
+    cfg.RegisterBlock(bb);
+  }
+
+  PopulateLoopNest(new_loop.get(), *cloning_result);
+
+  return new_loop.release();
+}
+
+void LoopUtils::PopulateLoopNest(
+    ir::Loop* new_loop, const LoopCloningResult& cloning_result) const {
+  std::unordered_map<ir::Loop*, ir::Loop*> loop_mapping;
+  loop_mapping[loop_] = new_loop;
+
+  if (loop_->HasParent()) loop_->GetParent()->AddNestedLoop(new_loop);
+  PopulateLoopDesc(new_loop, loop_, cloning_result);
+
+  for (ir::Loop& sub_loop :
+       ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop_),
+                      opt::TreeDFIterator<ir::Loop>())) {
+    ir::Loop* cloned = new ir::Loop(context_);
+    if (ir::Loop* parent = loop_mapping[sub_loop.GetParent()])
+      parent->AddNestedLoop(cloned);
+    loop_mapping[&sub_loop] = cloned;
+    PopulateLoopDesc(cloned, &sub_loop, cloning_result);
+  }
+
+  loop_desc_->AddLoopNest(std::unique_ptr<ir::Loop>(new_loop));
+}
+
+// Populates |new_loop| descriptor according to |old_loop|'s one.
+void LoopUtils::PopulateLoopDesc(
+    ir::Loop* new_loop, ir::Loop* old_loop,
+    const LoopCloningResult& cloning_result) const {
+  for (uint32_t bb_id : old_loop->GetBlocks()) {
+    ir::BasicBlock* bb = cloning_result.old_to_new_bb_.at(bb_id);
+    new_loop->AddBasicBlock(bb);
+  }
+  new_loop->SetHeaderBlock(
+      cloning_result.old_to_new_bb_.at(old_loop->GetHeaderBlock()->id()));
+  if (old_loop->GetLatchBlock())
+    new_loop->SetLatchBlock(
+        cloning_result.old_to_new_bb_.at(old_loop->GetLatchBlock()->id()));
+  if (old_loop->GetMergeBlock()) {
+    ir::BasicBlock* bb =
+        cloning_result.old_to_new_bb_.at(old_loop->GetMergeBlock()->id());
+    new_loop->SetMergeBlock(bb);
+  }
+  if (old_loop->GetPreHeaderBlock())
+    new_loop->SetPreHeaderBlock(
+        cloning_result.old_to_new_bb_.at(old_loop->GetPreHeaderBlock()->id()));
+}
+
 }  // namespace opt
 }  // namespace spvtools
index 89e6936..0e77bb6 100644 (file)
 #include <list>
 #include <memory>
 #include <vector>
+#include "opt/ir_context.h"
 #include "opt/loop_descriptor.h"
 
 namespace spvtools {
 
-namespace ir {
-class Loop;
-class IRContext;
-}  // namespace ir
-
 namespace opt {
 
 // LoopUtils is used to encapsulte loop optimizations and from the passes which
@@ -33,8 +29,25 @@ namespace opt {
 // or through a pass which is using this.
 class LoopUtils {
  public:
+  // Holds a auxiliary results of the loop cloning procedure.
+  struct LoopCloningResult {
+    using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
+    using BlockMapTy = std::unordered_map<uint32_t, ir::BasicBlock*>;
+
+    // Mapping between the original loop ids and the new one.
+    ValueMapTy value_map_;
+    // Mapping between original loop blocks to the cloned one.
+    BlockMapTy old_to_new_bb_;
+    // Mapping between the cloned loop blocks to original one.
+    BlockMapTy new_to_old_bb_;
+    // List of cloned basic block.
+    std::vector<std::unique_ptr<ir::BasicBlock>> cloned_bb_;
+  };
+
   LoopUtils(ir::IRContext* context, ir::Loop* loop)
       : context_(context),
+        loop_desc_(
+            context->GetLoopDescriptor(loop->GetHeaderBlock()->GetParent())),
         loop_(loop),
         function_(*loop_->GetHeaderBlock()->GetParent()) {}
 
@@ -72,6 +85,17 @@ class LoopUtils {
   // Preserves: CFG, def/use and instruction to block mapping.
   void CreateLoopDedicatedExits();
 
+  // Clone |loop_| and remap its instructions. Newly created blocks
+  // will be added to the |cloning_result.cloned_bb_| list, correctly ordered to
+  // be inserted into a function. If the loop is structured, the merge construct
+  // will also be cloned. The function preserves the def/use, cfg and instr to
+  // block analyses.
+  // The cloned loop nest will be added to the loop descriptor and will have
+  // owner ship.
+  ir::Loop* CloneLoop(
+      LoopCloningResult* cloning_result,
+      const std::vector<ir::BasicBlock*>& ordered_loop_blocks) const;
+
   // Perfom a partial unroll of |loop| by given |factor|. This will copy the
   // body of the loop |factor| times. So a |factor| of one would give a new loop
   // with the original body plus one unrolled copy body.
@@ -103,8 +127,17 @@ class LoopUtils {
 
  private:
   ir::IRContext* context_;
+  ir::LoopDescriptor* loop_desc_;
   ir::Loop* loop_;
   ir::Function& function_;
+
+  // Populates the loop nest of |new_loop| according to |loop_| nest.
+  void PopulateLoopNest(ir::Loop* new_loop,
+                        const LoopCloningResult& cloning_result) const;
+
+  // Populates |new_loop| descriptor according to |old_loop|'s one.
+  void PopulateLoopDesc(ir::Loop* new_loop, ir::Loop* old_loop,
+                        const LoopCloningResult& cloning_result) const;
 };
 
 }  // namespace opt
index 3996477..86b84a0 100644 (file)
@@ -137,11 +137,7 @@ bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const {
 }
 
 void MemPass::KillAllInsts(ir::BasicBlock* bp, bool killLabel) {
-  bp->ForEachInst([this, killLabel](ir::Instruction* ip) {
-    if (killLabel || ip->opcode() != SpvOpLabel) {
-      context()->KillInst(ip);
-    }
-  });
+  bp->KillAllInsts(killLabel);
 }
 
 bool MemPass::HasLoads(uint32_t varId) const {
index dced5db..c52e643 100644 (file)
@@ -356,6 +356,11 @@ Optimizer::PassToken CreateLoopInvariantCodeMotionPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::LICMPass>());
 }
 
+Optimizer::PassToken CreateLoopUnswitchPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::LoopUnswitchPass>());
+}
+
 Optimizer::PassToken CreateRedundancyEliminationPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::RedundancyEliminationPass>());
index 9fb98aa..f0fb289 100644 (file)
@@ -42,6 +42,7 @@
 #include "local_single_store_elim_pass.h"
 #include "local_ssa_elim_pass.h"
 #include "loop_unroller.h"
+#include "loop_unswitch_pass.h"
 #include "merge_return_pass.h"
 #include "null_pass.h"
 #include "private_to_local_pass.h"
index a9cb499..3a1144d 100644 (file)
@@ -79,3 +79,8 @@ add_spvtools_unittest(TARGET loop_unroll_assumtion_checks
     LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET unswitch_test
+    SRCS ../function_utils.h
+        unswitch.cpp
+    LIBS SPIRV-Tools-opt
+)
index 90e585e..f53ad05 100644 (file)
@@ -203,4 +203,98 @@ TEST_F(PassClassTest, LoopWithNoPreHeader) {
   EXPECT_NE(loop->GetOrCreatePreHeaderBlock(), nullptr);
 }
 
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+in vec4 c;
+void main() {
+  int i = 0;
+  bool cond = c[0] == 0;
+  for (; i < 10; i++) {
+    if (cond) {
+      return;
+    }
+    else {
+      return;
+    }
+  }
+  bool cond2 = i == 9;
+}
+*/
+TEST_F(PassClassTest, NoLoop) {
+  const std::string text = R"(; SPIR-V
+; Version: 1.0
+; Generator: Khronos Glslang Reference Front End; 3
+; Bound: 47
+; Schema: 0
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main" %16
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %4 "main"
+               OpName %16 "c"
+               OpDecorate %16 Location 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 0
+         %10 = OpTypeBool
+         %11 = OpTypePointer Function %10
+         %13 = OpTypeFloat 32
+         %14 = OpTypeVector %13 4
+         %15 = OpTypePointer Input %14
+         %16 = OpVariable %15 Input
+         %17 = OpTypeInt 32 0
+         %18 = OpConstant %17 0
+         %19 = OpTypePointer Input %13
+         %22 = OpConstant %13 0
+         %30 = OpConstant %6 10
+         %39 = OpConstant %6 1
+         %46 = OpUndef %6
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %20 = OpAccessChain %19 %16 %18
+         %21 = OpLoad %13 %20
+         %23 = OpFOrdEqual %10 %21 %22
+               OpBranch %24
+         %24 = OpLabel
+         %45 = OpPhi %6 %9 %5 %40 %27
+               OpLoopMerge %26 %27 None
+               OpBranch %28
+         %28 = OpLabel
+         %31 = OpSLessThan %10 %45 %30
+               OpBranchConditional %31 %25 %26
+         %25 = OpLabel
+               OpSelectionMerge %34 None
+               OpBranchConditional %23 %33 %36
+         %33 = OpLabel
+               OpReturn
+         %36 = OpLabel
+               OpReturn
+         %34 = OpLabel
+               OpBranch %27
+         %27 = OpLabel
+         %40 = OpIAdd %6 %46 %39
+               OpBranch %24
+         %26 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 4);
+  ir::LoopDescriptor ld{f};
+
+  EXPECT_EQ(ld.NumLoops(), 0u);
+}
+
 }  // namespace
diff --git a/test/opt/loop_optimizations/unswitch.cpp b/test/opt/loop_optimizations/unswitch.cpp
new file mode 100644 (file)
index 0000000..d5c5209
--- /dev/null
@@ -0,0 +1,914 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#ifdef SPIRV_EFFCEE
+#include "effcee/effcee.h"
+#endif
+
+#include "../pass_fixture.h"
+
+namespace {
+
+using namespace spvtools;
+
+using UnswitchTest = PassTest<::testing::Test>;
+
+#ifdef SPIRV_EFFCEE
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 450 core
+uniform vec4 c;
+void main() {
+  int i = 0;
+  int j = 0;
+  bool cond = c[0] == 0;
+  for (; i < 10; i++, j++) {
+    if (cond) {
+      i++;
+    }
+    else {
+      j++;
+    }
+  }
+}
+*/
+TEST_F(UnswitchTest, SimpleUnswitch) {
+  const std::string text = R"(
+; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]]
+
+; Loop specialized for false.
+; CHECK: [[loop_f]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: [[phi_j:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_j:%\w+]] [[continue]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=1 and j+=2.
+; CHECK: [[phi_j:%\w+]] = OpIAdd %int [[phi_j]] %int_1
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[iv_j]] = OpIAdd %int [[phi_j]] %int_1
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; Loop specialized for true.
+; CHECK: [[loop_t]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: [[phi_j:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_j:%\w+]] [[continue]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=2 and j+=1.
+; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[iv_j]] = OpIAdd %int [[phi_j]] %int_1
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpReturn
+
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginLowerLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+               OpDecorate %c DescriptorSet 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_UniformConstant_v4float = OpTypePointer UniformConstant %v4float
+          %c = OpVariable %_ptr_UniformConstant_v4float UniformConstant
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_UniformConstant_float = OpTypePointer UniformConstant %float
+    %float_0 = OpConstant %float 0
+     %int_10 = OpConstant %int 10
+      %int_1 = OpConstant %int 1
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %21 = OpAccessChain %_ptr_UniformConstant_float %c %uint_0
+         %22 = OpLoad %float %21
+         %24 = OpFOrdEqual %bool %22 %float_0
+               OpBranch %25
+         %25 = OpLabel
+         %46 = OpPhi %int %int_0 %5 %43 %28
+         %47 = OpPhi %int %int_0 %5 %45 %28
+               OpLoopMerge %27 %28 None
+               OpBranch %29
+         %29 = OpLabel
+         %32 = OpSLessThan %bool %46 %int_10
+               OpBranchConditional %32 %26 %27
+         %26 = OpLabel
+               OpSelectionMerge %35 None
+               OpBranchConditional %24 %34 %39
+         %34 = OpLabel
+         %38 = OpIAdd %int %46 %int_1
+               OpBranch %35
+         %39 = OpLabel
+         %41 = OpIAdd %int %47 %int_1
+               OpBranch %35
+         %35 = OpLabel
+         %48 = OpPhi %int %38 %34 %46 %39
+         %49 = OpPhi %int %47 %34 %41 %39
+               OpBranch %28
+         %28 = OpLabel
+         %43 = OpIAdd %int %48 %int_1
+         %45 = OpIAdd %int %49 %int_1
+               OpBranch %25
+         %27 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::LoopUnswitchPass>(text, true);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+in vec4 c;
+void main() {
+  int i = 0;
+  bool cond = c[0] == 0;
+  for (; i < 10; i++) {
+    if (cond) {
+      i++;
+    }
+    else {
+      return;
+    }
+  }
+}
+*/
+TEST_F(UnswitchTest, UnswitchExit) {
+  const std::string text = R"(
+; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]]
+
+; Loop specialized for false.
+; CHECK: [[loop_f]] = OpLabel
+; CHECK: OpReturn
+
+; Loop specialized for true.
+; CHECK: [[loop_t]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=2.
+; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpReturn
+
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %c
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+               OpDecorate %23 Uniform
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+          %c = OpVariable %_ptr_Input_v4float Input
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Input_float = OpTypePointer Input %float
+    %float_0 = OpConstant %float 0
+     %int_10 = OpConstant %int 10
+      %int_1 = OpConstant %int 1
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %20 = OpAccessChain %_ptr_Input_float %c %uint_0
+         %21 = OpLoad %float %20
+         %23 = OpFOrdEqual %bool %21 %float_0
+               OpBranch %24
+         %24 = OpLabel
+         %42 = OpPhi %int %int_0 %5 %41 %27
+               OpLoopMerge %26 %27 None
+               OpBranch %28
+         %28 = OpLabel
+         %31 = OpSLessThan %bool %42 %int_10
+               OpBranchConditional %31 %25 %26
+         %25 = OpLabel
+               OpSelectionMerge %34 None
+               OpBranchConditional %23 %33 %38
+         %33 = OpLabel
+         %37 = OpIAdd %int %42 %int_1
+               OpBranch %34
+         %38 = OpLabel
+               OpReturn
+         %34 = OpLabel
+               OpBranch %27
+         %27 = OpLabel
+         %41 = OpIAdd %int %37 %int_1
+               OpBranch %24
+         %26 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::LoopUnswitchPass>(text, true);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+in vec4 c;
+void main() {
+  int i = 0;
+  bool cond = c[0] == 0;
+  for (; i < 10; i++) {
+    if (cond) {
+      continue;
+    }
+    else {
+      i++;
+    }
+  }
+}
+*/
+TEST_F(UnswitchTest, UnswitchContinue) {
+  const std::string text = R"(
+; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]]
+
+; Loop specialized for false.
+; CHECK: [[loop_f]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=2.
+; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; Loop specialized for true.
+; CHECK: [[loop_t]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=1.
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpReturn
+
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %c
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+               OpDecorate %23 Uniform
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+          %c = OpVariable %_ptr_Input_v4float Input
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Input_float = OpTypePointer Input %float
+    %float_0 = OpConstant %float 0
+     %int_10 = OpConstant %int 10
+      %int_1 = OpConstant %int 1
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %20 = OpAccessChain %_ptr_Input_float %c %uint_0
+         %21 = OpLoad %float %20
+         %23 = OpFOrdEqual %bool %21 %float_0
+               OpBranch %24
+         %24 = OpLabel
+         %42 = OpPhi %int %int_0 %5 %41 %27
+               OpLoopMerge %26 %27 None
+               OpBranch %28
+         %28 = OpLabel
+         %31 = OpSLessThan %bool %42 %int_10
+               OpBranchConditional %31 %25 %26
+         %25 = OpLabel
+               OpSelectionMerge %34 None
+               OpBranchConditional %23 %33 %36
+         %33 = OpLabel
+               OpBranch %27
+         %36 = OpLabel
+         %39 = OpIAdd %int %42 %int_1
+               OpBranch %34
+         %34 = OpLabel
+               OpBranch %27
+         %27 = OpLabel
+         %43 = OpPhi %int %42 %33 %39 %34
+         %41 = OpIAdd %int %43 %int_1
+               OpBranch %24
+         %26 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::LoopUnswitchPass>(text, true);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+in vec4 c;
+void main() {
+  int i = 0;
+  bool cond = c[0] == 0;
+  for (; i < 10; i++) {
+    if (cond) {
+      i++;
+    }
+    else {
+      break;
+    }
+  }
+}
+*/
+TEST_F(UnswitchTest, UnswitchKillLoop) {
+  const std::string text = R"(
+; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]]
+
+; Loop specialized for false.
+; CHECK: [[loop_f]] = OpLabel
+; CHECK: OpBranch [[if_merge]]
+
+; Loop specialized for true.
+; CHECK: [[loop_t]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=2.
+; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpReturn
+
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %c
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+               OpDecorate %23 Uniform
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+          %c = OpVariable %_ptr_Input_v4float Input
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Input_float = OpTypePointer Input %float
+    %float_0 = OpConstant %float 0
+     %int_10 = OpConstant %int 10
+      %int_1 = OpConstant %int 1
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %20 = OpAccessChain %_ptr_Input_float %c %uint_0
+         %21 = OpLoad %float %20
+         %23 = OpFOrdEqual %bool %21 %float_0
+               OpBranch %24
+         %24 = OpLabel
+         %42 = OpPhi %int %int_0 %5 %41 %27
+               OpLoopMerge %26 %27 None
+               OpBranch %28
+         %28 = OpLabel
+         %31 = OpSLessThan %bool %42 %int_10
+               OpBranchConditional %31 %25 %26
+         %25 = OpLabel
+               OpSelectionMerge %34 None
+               OpBranchConditional %23 %33 %38
+         %33 = OpLabel
+         %37 = OpIAdd %int %42 %int_1
+               OpBranch %34
+         %38 = OpLabel
+               OpBranch %26
+         %34 = OpLabel
+               OpBranch %27
+         %27 = OpLabel
+         %41 = OpIAdd %int %37 %int_1
+               OpBranch %24
+         %26 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::LoopUnswitchPass>(text, true);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+in vec4 c;
+void main() {
+  int i = 0;
+  int cond = int(c[0]);
+  for (; i < 10; i++) {
+    switch (cond) {
+      case 0:
+        return;
+      case 1:
+        discard;
+      case 2:
+        break;
+      default:
+        break;
+    }
+  }
+  bool cond2 = i == 9;
+}
+*/
+TEST_F(UnswitchTest, UnswitchSwitch) {
+  const std::string text = R"(
+; CHECK: [[cst_cond:%\w+]] = OpConvertFToS
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpSwitch [[cst_cond]] [[default:%\w+]] 0 [[loop_0:%\w+]] 1 [[loop_1:%\w+]] 2 [[loop_2:%\w+]]
+
+; Loop specialized for 2.
+; CHECK: [[loop_2]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_2]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=1.
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: OpBranch [[loop]]
+
+; Loop specialized for 1.
+; CHECK: [[loop_1]] = OpLabel
+; CHECK: OpKill
+
+; Loop specialized for 0.
+; CHECK: [[loop_0]] = OpLabel
+; CHECK: OpReturn
+
+; Loop specialized for the default case.
+; CHECK: [[default]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[default]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have i+=1.
+; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK: OpBranch [[loop]]
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpReturn
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %c
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+               OpDecorate %20 Uniform
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+          %c = OpVariable %_ptr_Input_v4float Input
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Input_float = OpTypePointer Input %float
+     %int_10 = OpConstant %int 10
+       %bool = OpTypeBool
+      %int_1 = OpConstant %int 1
+%_ptr_Function_bool = OpTypePointer Function %bool
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %18 = OpAccessChain %_ptr_Input_float %c %uint_0
+         %19 = OpLoad %float %18
+         %20 = OpConvertFToS %int %19
+               OpBranch %21
+         %21 = OpLabel
+         %49 = OpPhi %int %int_0 %5 %43 %24
+               OpLoopMerge %23 %24 None
+               OpBranch %25
+         %25 = OpLabel
+         %29 = OpSLessThan %bool %49 %int_10
+               OpBranchConditional %29 %22 %23
+         %22 = OpLabel
+               OpSelectionMerge %35 None
+               OpSwitch %20 %34 0 %31 1 %32 2 %33
+         %34 = OpLabel
+               OpBranch %35
+         %31 = OpLabel
+               OpReturn
+         %32 = OpLabel
+               OpKill
+         %33 = OpLabel
+               OpBranch %35
+         %35 = OpLabel
+               OpBranch %24
+         %24 = OpLabel
+         %43 = OpIAdd %int %49 %int_1
+               OpBranch %21
+         %23 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::LoopUnswitchPass>(text, true);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 440 core
+layout(location = 0)in vec4 c;
+void main() {
+  int i = 0;
+  int j = 0;
+  int k = 0;
+  bool cond = c[0] == 0;
+  for (; i < 10; i++) {
+    for (; j < 10; j++) {
+      if (cond) {
+        i++;
+      } else {
+        j++;
+      }
+    }
+  }
+  for (; k < 10; k++) {
+    if (cond) {
+      k++;
+    }
+  }
+}
+*/
+TEST_F(UnswitchTest, UnSwitchNested) {
+  const std::string text = R"(
+; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]]
+
+; Loop specialized for false, one loop is killed, j won't change anymore.
+; CHECK: [[loop_f]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: [[phi_j:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_j:%\w+]] [[continue]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1
+; CHECK-NEXT: OpBranch [[loop]]
+; CHECK: OpReturn
+
+; Loop specialized for true.
+; CHECK: [[loop_t]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: [[phi_j:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_j:%\w+]] [[continue]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[pre_loop_inner:%\w+]] [[merge]]
+
+; CHECK: [[pre_loop_inner]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop_inner:%\w+]]
+; CHECK-NEXT: [[loop_inner]] = OpLabel
+; CHECK-NEXT: [[phi2_i:%\w+]] = OpPhi %int [[phi_i]] [[pre_loop_inner]] [[iv2_i:%\w+]] [[continue2:%\w+]]
+; CHECK-NEXT: [[phi2_j:%\w+]] = OpPhi %int [[phi_j]] [[pre_loop_inner]] [[iv2_j:%\w+]] [[continue2]]
+; CHECK-NEXT: OpLoopMerge [[merge2:%\w+]] [[continue2]] None
+
+; CHECK: OpBranch [[continue2]]
+; CHECK: [[merge2]] = OpLabel
+; CHECK: OpBranch [[continue]]
+; CHECK: [[merge]] = OpLabel
+
+; Unswitched double nested loop is done. Test the single remaining one.
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None
+; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]]
+
+; Loop specialized for false.
+; CHECK: [[loop_f]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_k:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_k:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_k]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have k+=1
+; CHECK: [[iv_k]] = OpIAdd %int [[phi_k]] %int_1
+; CHECK: OpBranch [[loop]]
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; Loop specialized for true.
+; CHECK: [[loop_t]] = OpLabel
+; CHECK-NEXT: OpBranch [[loop:%\w+]]
+; CHECK: [[loop]] = OpLabel
+; CHECK-NEXT: [[phi_k:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_k:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_k]] {{%\w+}}
+; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]]
+; Check that we have k+=2.
+; CHECK: [[tmp_k:%\w+]] = OpIAdd %int [[phi_k]] %int_1
+; CHECK: [[iv_k]] = OpIAdd %int [[tmp_k]] %int_1
+; CHECK: OpBranch [[loop]]
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: OpBranch [[if_merge]]
+
+; CHECK: [[if_merge]] = OpLabel
+; CHECK-NEXT: OpReturn
+
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %c
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 440
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+               OpDecorate %25 Uniform
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+          %c = OpVariable %_ptr_Input_v4float Input
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Input_float = OpTypePointer Input %float
+    %float_0 = OpConstant %float 0
+     %int_10 = OpConstant %int 10
+      %int_1 = OpConstant %int 1
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %22 = OpAccessChain %_ptr_Input_float %c %uint_0
+         %23 = OpLoad %float %22
+         %25 = OpFOrdEqual %bool %23 %float_0
+               OpBranch %26
+         %26 = OpLabel
+         %67 = OpPhi %int %int_0 %5 %52 %29
+         %68 = OpPhi %int %int_0 %5 %70 %29
+               OpLoopMerge %28 %29 None
+               OpBranch %30
+         %30 = OpLabel
+         %33 = OpSLessThan %bool %67 %int_10
+               OpBranchConditional %33 %27 %28
+         %27 = OpLabel
+               OpBranch %34
+         %34 = OpLabel
+         %69 = OpPhi %int %67 %27 %46 %37
+         %70 = OpPhi %int %68 %27 %50 %37
+               OpLoopMerge %36 %37 None
+               OpBranch %38
+         %38 = OpLabel
+         %40 = OpSLessThan %bool %70 %int_10
+               OpBranchConditional %40 %35 %36
+         %35 = OpLabel
+               OpSelectionMerge %43 None
+               OpBranchConditional %25 %42 %47
+         %42 = OpLabel
+         %46 = OpIAdd %int %69 %int_1
+               OpBranch %43
+         %47 = OpLabel
+               OpReturn
+         %43 = OpLabel
+               OpBranch %37
+         %37 = OpLabel
+         %50 = OpIAdd %int %70 %int_1
+               OpBranch %34
+         %36 = OpLabel
+               OpBranch %29
+         %29 = OpLabel
+         %52 = OpIAdd %int %69 %int_1
+               OpBranch %26
+         %28 = OpLabel
+               OpBranch %53
+         %53 = OpLabel
+         %71 = OpPhi %int %int_0 %28 %66 %56
+               OpLoopMerge %55 %56 None
+               OpBranch %57
+         %57 = OpLabel
+         %59 = OpSLessThan %bool %71 %int_10
+               OpBranchConditional %59 %54 %55
+         %54 = OpLabel
+               OpSelectionMerge %62 None
+               OpBranchConditional %25 %61 %62
+         %61 = OpLabel
+         %64 = OpIAdd %int %71 %int_1
+               OpBranch %62
+         %62 = OpLabel
+         %72 = OpPhi %int %71 %54 %64 %61
+               OpBranch %56
+         %56 = OpLabel
+         %66 = OpIAdd %int %72 %int_1
+               OpBranch %53
+         %55 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<opt::LoopUnswitchPass>(text, true);
+}
+#endif  // SPIRV_EFFCEE
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+in vec4 c;
+void main() {
+  bool cond = false;
+  if (c[0] == 0) {
+     cond = c[1] == 0;
+  } else {
+     cond = c[2] == 0;
+  }
+  for (int i = 0; i < 10; i++) {
+    if (cond) {
+      i++;
+    }
+  }
+}
+*/
+TEST_F(UnswitchTest, UnswitchNotUniform) {
+  // Check that the unswitch is not triggered (condition loop invariant but not
+  // uniform)
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %c
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+               OpName %c "c"
+               OpDecorate %c Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+          %c = OpVariable %_ptr_Input_v4float Input
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Input_float = OpTypePointer Input %float
+    %float_0 = OpConstant %float 0
+     %uint_1 = OpConstant %uint 1
+     %uint_2 = OpConstant %uint 2
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_0 = OpConstant %int 0
+     %int_10 = OpConstant %int 10
+      %int_1 = OpConstant %int 1
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %17 = OpAccessChain %_ptr_Input_float %c %uint_0
+         %18 = OpLoad %float %17
+         %20 = OpFOrdEqual %bool %18 %float_0
+               OpSelectionMerge %22 None
+               OpBranchConditional %20 %21 %27
+         %21 = OpLabel
+         %24 = OpAccessChain %_ptr_Input_float %c %uint_1
+         %25 = OpLoad %float %24
+         %26 = OpFOrdEqual %bool %25 %float_0
+               OpBranch %22
+         %27 = OpLabel
+         %29 = OpAccessChain %_ptr_Input_float %c %uint_2
+         %30 = OpLoad %float %29
+         %31 = OpFOrdEqual %bool %30 %float_0
+               OpBranch %22
+         %22 = OpLabel
+         %52 = OpPhi %bool %26 %21 %31 %27
+               OpBranch %36
+         %36 = OpLabel
+         %53 = OpPhi %int %int_0 %22 %51 %39
+               OpLoopMerge %38 %39 None
+               OpBranch %40
+         %40 = OpLabel
+         %43 = OpSLessThan %bool %53 %int_10
+               OpBranchConditional %43 %37 %38
+         %37 = OpLabel
+               OpSelectionMerge %46 None
+               OpBranchConditional %52 %45 %46
+         %45 = OpLabel
+         %49 = OpIAdd %int %53 %int_1
+               OpBranch %46
+         %46 = OpLabel
+         %54 = OpPhi %int %53 %37 %49 %45
+               OpBranch %39
+         %39 = OpLabel
+         %51 = OpIAdd %int %54 %int_1
+               OpBranch %36
+         %38 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  auto result =
+      SinglePassRunAndDisassemble<opt::LoopUnswitchPass>(text, true, false);
+
+  EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
+}
+
+}  // namespace
index 3f8dd88..9c9a271 100644 (file)
@@ -187,6 +187,10 @@ Options (in lexicographical order):
   --local-redundancy-elimination
                Looks for instructions in the same basic block that compute the
                same value, and deletes the redundant ones.
+  --loop-unswitch
+               Hoists loop-invariant conditionals out of loops by duplicating
+               the loop on each branch of the conditional and adjusting each
+               copy of the loop.
   -O
                Optimize for performance. Apply a sequence of transformations
                in an attempt to improve the performance of the generated
@@ -463,6 +467,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
         optimizer->RegisterPass(CreateDeadVariableEliminationPass());
       } else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
         optimizer->RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
+      } else if (0 == strcmp(cur_arg, "--loop-unswitch")) {
+        optimizer->RegisterPass(CreateLoopUnswitchPass());
       } else if (0 == strcmp(cur_arg, "--scalar-replacement")) {
         optimizer->RegisterPass(CreateScalarReplacementPass());
       } else if (0 == strcmp(cur_arg, "--strength-reduction")) {