Adding early exit versions of several ForEach* methods
authorAlan Baker <alanbaker@google.com>
Fri, 12 Jan 2018 20:05:53 +0000 (15:05 -0500)
committerDavid Neto <dneto@google.com>
Fri, 12 Jan 2018 22:05:09 +0000 (17:05 -0500)
* Looked through code for instances where code would benefit from early
exit
 * Added a corresponding WhileEach* method and updated the code

20 files changed:
source/opt/basic_block.h
source/opt/block_merge_pass.cpp
source/opt/ccp_pass.cpp
source/opt/common_uniform_elim_pass.cpp
source/opt/dead_branch_elim_pass.cpp
source/opt/decoration_manager.cpp
source/opt/decoration_manager.h
source/opt/def_use_manager.cpp
source/opt/def_use_manager.h
source/opt/inline_opaque_pass.cpp
source/opt/instruction.h
source/opt/ir_context.cpp
source/opt/local_access_chain_convert_pass.cpp
source/opt/local_single_block_elim_pass.cpp
source/opt/local_single_store_elim_pass.cpp
source/opt/mem_pass.cpp
source/opt/private_to_local_pass.cpp
source/opt/propagator.cpp
source/opt/scalar_replacement_pass.cpp
source/opt/set_spec_constant_default_value_pass.cpp

index 2e27a4b..ec32537 100644 (file)
@@ -110,6 +110,14 @@ class BasicBlock {
   inline void ForEachInst(const std::function<void(const Instruction*)>& f,
                           bool run_on_debug_line_insts = false) const;
 
+  // Runs the given function |f| on each instruction in this basic block, and
+  // optionally on the debug line instructions that might precede them. If |f|
+  // returns false, iteration is terminated and this function returns false.
+  inline bool WhileEachInst(const std::function<bool(Instruction*)>& f,
+                            bool run_on_debug_line_insts = false);
+  inline bool WhileEachInst(const std::function<bool(const Instruction*)>& f,
+                            bool run_on_debug_line_insts = false) const;
+
   // Runs the given function |f| on each Phi instruction in this basic block,
   // and optionally on the debug line instructions that might precede them.
   inline void ForEachPhiInst(const std::function<void(Instruction*)>& f,
@@ -181,30 +189,59 @@ inline void BasicBlock::AddInstructions(BasicBlock* bp) {
   (void)bEnd.MoveBefore(&bp->insts_);
 }
 
-inline void BasicBlock::ForEachInst(const std::function<void(Instruction*)>& f,
-                                    bool run_on_debug_line_insts) {
-  if (label_) label_->ForEachInst(f, run_on_debug_line_insts);
+inline bool BasicBlock::WhileEachInst(
+    const std::function<bool(Instruction*)>& f, bool run_on_debug_line_insts) {
+  if (label_) {
+    if (!label_->WhileEachInst(f, run_on_debug_line_insts)) return false;
+  }
   if (insts_.empty()) {
-    return;
+    return true;
   }
 
   Instruction* inst = &insts_.front();
   while (inst != nullptr) {
     Instruction* next_instruction = inst->NextNode();
-    inst->ForEachInst(f, run_on_debug_line_insts);
+    if (!inst->WhileEachInst(f, run_on_debug_line_insts)) return false;
     inst = next_instruction;
   }
+  return true;
+}
+
+inline bool BasicBlock::WhileEachInst(
+    const std::function<bool(const Instruction*)>& f,
+    bool run_on_debug_line_insts) const {
+  if (label_) {
+    if (!static_cast<const Instruction*>(label_.get())
+             ->WhileEachInst(f, run_on_debug_line_insts))
+      return false;
+  }
+  for (const auto& inst : insts_) {
+    if (!static_cast<const Instruction*>(&inst)->WhileEachInst(
+            f, run_on_debug_line_insts))
+      return false;
+  }
+  return true;
+}
+
+inline void BasicBlock::ForEachInst(const std::function<void(Instruction*)>& f,
+                                    bool run_on_debug_line_insts) {
+  WhileEachInst(
+      [&f](Instruction* inst) {
+        f(inst);
+        return true;
+      },
+      run_on_debug_line_insts);
 }
 
 inline void BasicBlock::ForEachInst(
     const std::function<void(const Instruction*)>& f,
     bool run_on_debug_line_insts) const {
-  if (label_)
-    static_cast<const Instruction*>(label_.get())
-        ->ForEachInst(f, run_on_debug_line_insts);
-  for (const auto& inst : insts_)
-    static_cast<const Instruction*>(&inst)->ForEachInst(
-        f, run_on_debug_line_insts);
+  WhileEachInst(
+      [&f](const Instruction* inst) {
+        f(inst);
+        return true;
+      },
+      run_on_debug_line_insts);
 }
 
 inline void BasicBlock::ForEachPhiInst(
index 24d19f6..c6a391e 100644 (file)
@@ -23,13 +23,15 @@ namespace spvtools {
 namespace opt {
 
 bool BlockMergePass::HasMultipleRefs(uint32_t labId) {
-  int rcnt = 0;
-  get_def_use_mgr()->ForEachUser(labId, [&rcnt](ir::Instruction* user) {
-    if (user->opcode() != SpvOpName) {
-      ++rcnt;
-    }
-  });
-  return rcnt > 1;
+  bool multiple_refs = false;
+  return !get_def_use_mgr()->WhileEachUser(
+      labId, [&multiple_refs](ir::Instruction* user) {
+        if (user->opcode() != SpvOpName) {
+          if (multiple_refs) return false;
+          multiple_refs = true;
+        }
+        return true;
+      });
 }
 
 void BlockMergePass::KillInstAndName(ir::Instruction* inst) {
index 077bf78..2f10f3a 100644 (file)
@@ -140,15 +140,11 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) {
 
   // If not, see if there is a least one unknown operand to the instruction.  If
   // so, we might be able to fold it later.
-  bool could_be_improved = false;
-  instr->ForEachInId([this, &could_be_improved](uint32_t* op_id) {
-    auto it = values_.find(*op_id);
-    if (it == values_.end()) {
-      could_be_improved = true;
-      return;
-    }
-  });
-  if (could_be_improved) {
+  if (!instr->WhileEachInId([this](uint32_t* op_id) {
+        auto it = values_.find(*op_id);
+        if (it == values_.end()) return false;
+        return true;
+      })) {
     return SSAPropagator::kNotInteresting;
   }
 
index fac8257..551929d 100644 (file)
@@ -52,12 +52,13 @@ bool CommonUniformElimPass::IsSamplerOrImageType(
   }
   if (typeInst->opcode() != SpvOpTypeStruct) return false;
   // Return true if any member is a sampler or image
-  int samplerOrImageCnt = 0;
-  typeInst->ForEachInId([&samplerOrImageCnt, this](const uint32_t* tid) {
+  return !typeInst->WhileEachInId([this](const uint32_t* tid) {
     const ir::Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
-    if (IsSamplerOrImageType(compTypeInst)) ++samplerOrImageCnt;
+    if (IsSamplerOrImageType(compTypeInst)) {
+      return false;
+    }
+    return true;
   });
-  return samplerOrImageCnt > 0;
 }
 
 bool CommonUniformElimPass::IsSamplerOrImageVar(uint32_t varId) const {
@@ -98,13 +99,9 @@ ir::Instruction* CommonUniformElimPass::GetPtr(ir::Instruction* ip,
 
 bool CommonUniformElimPass::IsVolatileStruct(uint32_t type_id) {
   assert(get_def_use_mgr()->GetDef(type_id)->opcode() == SpvOpTypeStruct);
-  bool has_volatile_deco = false;
-  get_decoration_mgr()->ForEachDecoration(
+  return !get_decoration_mgr()->WhileEachDecoration(
       type_id, SpvDecorationVolatile,
-      [&has_volatile_deco](const ir::Instruction&) {
-        has_volatile_deco = true;
-      });
-  return has_volatile_deco;
+      [](const ir::Instruction&) { return false; });
 }
 
 bool CommonUniformElimPass::IsAccessChainToVolatileStructType(
@@ -177,26 +174,18 @@ bool CommonUniformElimPass::IsUniformVar(uint32_t varId) {
 }
 
 bool CommonUniformElimPass::HasUnsupportedDecorates(uint32_t id) const {
-  bool nonTypeDecorate = false;
-  get_def_use_mgr()->ForEachUser(
-      id, [this, &nonTypeDecorate](ir::Instruction* user) {
-        if (this->IsNonTypeDecorate(user->opcode())) {
-          nonTypeDecorate = true;
-        }
-      });
-  return nonTypeDecorate;
+  return !get_def_use_mgr()->WhileEachUser(id, [this](ir::Instruction* user) {
+    if (IsNonTypeDecorate(user->opcode())) return false;
+    return true;
+  });
 }
 
 bool CommonUniformElimPass::HasOnlyNamesAndDecorates(uint32_t id) const {
-  bool onlyNameAndDecorates = true;
-  get_def_use_mgr()->ForEachUser(
-      id, [this, &onlyNameAndDecorates](ir::Instruction* user) {
-        SpvOp op = user->opcode();
-        if (op != SpvOpName && !this->IsNonTypeDecorate(op)) {
-          onlyNameAndDecorates = false;
-        }
-      });
-  return onlyNameAndDecorates;
+  return get_def_use_mgr()->WhileEachUser(id, [this](ir::Instruction* user) {
+    SpvOp op = user->opcode();
+    if (op != SpvOpName && !IsNonTypeDecorate(op)) return false;
+    return true;
+  });
 }
 
 void CommonUniformElimPass::DeleteIfUseless(ir::Instruction* inst) {
@@ -267,15 +256,13 @@ void CommonUniformElimPass::GenACLoadRepl(
 
 bool CommonUniformElimPass::IsConstantIndexAccessChain(ir::Instruction* acp) {
   uint32_t inIdx = 0;
-  uint32_t nonConstCnt = 0;
-  acp->ForEachInId([&inIdx, &nonConstCnt, this](uint32_t* tid) {
+  return acp->WhileEachInId([&inIdx, this](uint32_t* tid) {
     if (inIdx > 0) {
       ir::Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
-      if (opInst->opcode() != SpvOpConstant) ++nonConstCnt;
+      if (opInst->opcode() != SpvOpConstant) return false;
     }
-    ++inIdx;
+    return true;
   });
-  return nonConstCnt == 0;
 }
 
 bool CommonUniformElimPass::UniformAccessChainConvert(ir::Function* func) {
index 0ac61cc..bdbf385 100644 (file)
@@ -117,7 +117,7 @@ bool DeadBranchElimPass::MarkLiveBlocks(
         // corresponding label, use default if not found.
         uint32_t icnt = 0;
         uint32_t case_val;
-        terminator->ForEachInOperand(
+        terminator->WhileEachInOperand(
             [&icnt, &case_val, &sel_val, &live_lab_id](const uint32_t* idp) {
               if (icnt == 1) {
                 // Start with default label.
@@ -126,10 +126,14 @@ bool DeadBranchElimPass::MarkLiveBlocks(
                 if (icnt % 2 == 0) {
                   case_val = *idp;
                 } else {
-                  if (case_val == sel_val) live_lab_id = *idp;
+                  if (case_val == sel_val) {
+                    live_lab_id = *idp;
+                    return false;
+                  }
                 }
               }
               ++icnt;
+              return true;
             });
       }
     }
index c4c63dd..e59280c 100644 (file)
@@ -225,26 +225,36 @@ std::vector<T> DecorationManager::InternalGetDecorationsFor(
   return decorations;
 }
 
-void DecorationManager::ForEachDecoration(
+bool DecorationManager::WhileEachDecoration(
     uint32_t id, uint32_t decoration,
-    std::function<void(const ir::Instruction&)> f) {
+    std::function<bool(const ir::Instruction&)> f) {
   for (const ir::Instruction* inst : GetDecorationsFor(id, true)) {
     switch (inst->opcode()) {
       case SpvOpMemberDecorate:
         if (inst->GetSingleWordInOperand(2) == decoration) {
-          f(*inst);
+          if (!f(*inst)) return false;
         }
         break;
       case SpvOpDecorate:
       case SpvOpDecorateId:
         if (inst->GetSingleWordInOperand(1) == decoration) {
-          f(*inst);
+          if (!f(*inst)) return false;
         }
         break;
       default:
         assert(false && "Unexpected decoration instruction");
     }
   }
+  return true;
+}
+
+void DecorationManager::ForEachDecoration(
+    uint32_t id, uint32_t decoration,
+    std::function<void(const ir::Instruction&)> f) {
+  WhileEachDecoration(id, decoration, [&f](const ir::Instruction& inst) {
+    f(inst);
+    return true;
+  });
 }
 
 void DecorationManager::CloneDecorations(
index d0c3da8..713b1e8 100644 (file)
@@ -69,6 +69,13 @@ class DecorationManager {
   void ForEachDecoration(uint32_t id, uint32_t decoration,
                          std::function<void(const ir::Instruction&)> f);
 
+  // |f| is run on each decoration instruction for |id| with decoration
+  // |decoration|. Processes all decoration which target |id| either directly or
+  // indirectly through decoration groups. If |f| returns false, iteration is
+  // terminated and this function returns false.
+  bool WhileEachDecoration(uint32_t id, uint32_t decoration,
+                           std::function<bool(const ir::Instruction&)> f);
+
   // Clone all decorations from one id |from|.
   // The cloned decorations are assigned to the given id |to| and are
   // added to the module. The purpose is to decorate cloned instructions.
index 3dbc51a..2b1b00c 100644 (file)
@@ -96,16 +96,31 @@ bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
   return UsersNotEnd(iter, id_to_users_.end(), inst);
 }
 
-void DefUseManager::ForEachUser(
+bool DefUseManager::WhileEachUser(
     const ir::Instruction* def,
-    const std::function<void(ir::Instruction*)>& f) const {
+    const std::function<bool(ir::Instruction*)>& f) const {
   // Ensure that |def| has been registered.
   assert(def && def == GetDef(def->result_id()) &&
          "Definition is not registered.");
   auto end = id_to_users_.end();
   for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
-    f(iter->second);
+    if (!f(iter->second)) return false;
   }
+  return true;
+}
+
+bool DefUseManager::WhileEachUser(
+    uint32_t id, const std::function<bool(ir::Instruction*)>& f) const {
+  return WhileEachUser(GetDef(id), f);
+}
+
+void DefUseManager::ForEachUser(
+    const ir::Instruction* def,
+    const std::function<void(ir::Instruction*)>& f) const {
+  WhileEachUser(def, [&f](ir::Instruction* user) {
+    f(user);
+    return true;
+  });
 }
 
 void DefUseManager::ForEachUser(
@@ -113,9 +128,9 @@ void DefUseManager::ForEachUser(
   ForEachUser(GetDef(id), f);
 }
 
-void DefUseManager::ForEachUse(
+bool DefUseManager::WhileEachUse(
     const ir::Instruction* def,
-    const std::function<void(ir::Instruction*, uint32_t)>& f) const {
+    const std::function<bool(ir::Instruction*, uint32_t)>& f) const {
   // Ensure that |def| has been registered.
   assert(def && def == GetDef(def->result_id()) &&
          "Definition is not registered.");
@@ -125,10 +140,28 @@ void DefUseManager::ForEachUse(
     for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) {
       const ir::Operand& op = user->GetOperand(idx);
       if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) {
-        if (def->result_id() == op.words[0]) f(user, idx);
+        if (def->result_id() == op.words[0]) {
+          if (!f(user, idx)) return false;
+        }
       }
     }
   }
+  return true;
+}
+
+bool DefUseManager::WhileEachUse(
+    uint32_t id,
+    const std::function<bool(ir::Instruction*, uint32_t)>& f) const {
+  return WhileEachUse(GetDef(id), f);
+}
+
+void DefUseManager::ForEachUse(
+    const ir::Instruction* def,
+    const std::function<void(ir::Instruction*, uint32_t)>& f) const {
+  WhileEachUse(def, [&f](ir::Instruction* user, uint32_t index) {
+    f(user, index);
+    return true;
+  });
 }
 
 void DefUseManager::ForEachUse(
index 1a8d989..69047e2 100644 (file)
@@ -137,6 +137,19 @@ class DefUseManager {
   void ForEachUser(uint32_t id,
                    const std::function<void(ir::Instruction*)>& f) const;
 
+  // Runs the given function |f| on each unique user instruction of |def| (or
+  // |id|). If |f| returns false, iteration is terminated and this function
+  // returns false.
+  //
+  // If one instruction uses |def| in multiple operands, that instruction will
+  // be only be visited once.
+  //
+  // |def| (or |id|) must be registered as a definition.
+  bool WhileEachUser(const ir::Instruction* def,
+                     const std::function<bool(ir::Instruction*)>& f) const;
+  bool WhileEachUser(uint32_t id,
+                     const std::function<bool(ir::Instruction*)>& f) const;
+
   // Runs the given function |f| on each unique use of |def| (or
   // |id|).
   //
@@ -151,6 +164,21 @@ class DefUseManager {
                   const std::function<void(ir::Instruction*,
                                            uint32_t operand_index)>& f) const;
 
+  // Runs the given function |f| on each unique use of |def| (or
+  // |id|). If |f| returns false, iteration is terminated and this function
+  // returns false.
+  //
+  // If one instruction uses |def| in multiple operands, each operand will be
+  // visited separately.
+  //
+  // |def| (or |id|) must be registered as a definition.
+  bool WhileEachUse(const ir::Instruction* def,
+                    const std::function<bool(ir::Instruction*,
+                                             uint32_t operand_index)>& f) const;
+  bool WhileEachUse(uint32_t id,
+                    const std::function<bool(ir::Instruction*,
+                                             uint32_t operand_index)>& f) const;
+
   // Returns the number of users of |def| (or |id|).
   uint32_t NumUsers(const ir::Instruction* def) const;
   uint32_t NumUsers(uint32_t id) const;
index 2f6c669..e3f9f21 100644 (file)
@@ -41,11 +41,10 @@ bool InlineOpaquePass::IsOpaqueType(uint32_t typeId) {
   // TODO(greg-lunarg): Handle arrays containing opaque type
   if (typeInst->opcode() != SpvOpTypeStruct) return false;
   // Return true if any member is opaque
-  int ocnt = 0;
-  typeInst->ForEachInId([&ocnt, this](const uint32_t* tid) {
-    if (ocnt == 0 && IsOpaqueType(*tid)) ++ocnt;
+  return !typeInst->WhileEachInId([this](const uint32_t* tid) {
+    if (IsOpaqueType(*tid)) return false;
+    return true;
   });
-  return ocnt > 0;
 }
 
 bool InlineOpaquePass::HasOpaqueArgsOrReturn(const ir::Instruction* callInst) {
@@ -53,15 +52,14 @@ bool InlineOpaquePass::HasOpaqueArgsOrReturn(const ir::Instruction* callInst) {
   if (IsOpaqueType(callInst->type_id())) return true;
   // Check args
   int icnt = 0;
-  int ocnt = 0;
-  callInst->ForEachInId([&icnt, &ocnt, this](const uint32_t* iid) {
+  return !callInst->WhileEachInId([&icnt, this](const uint32_t* iid) {
     if (icnt > 0) {
       const ir::Instruction* argInst = get_def_use_mgr()->GetDef(*iid);
-      if (IsOpaqueType(argInst->type_id())) ++ocnt;
+      if (IsOpaqueType(argInst->type_id())) return false;
     }
     ++icnt;
+    return true;
   });
-  return ocnt > 0;
 }
 
 bool InlineOpaquePass::InlineOpaque(ir::Function* func) {
index 27a6cf9..f1c98ed 100644 (file)
@@ -235,21 +235,42 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   inline void ForEachInst(const std::function<void(const Instruction*)>& f,
                           bool run_on_debug_line_insts = false) const;
 
+  // Runs the given function |f| on this instruction and optionally on the
+  // preceding debug line instructions.  The function will always be run
+  // if this is itself a debug line instruction. If |f| returns false,
+  // iteration is terminated and this function returns false.
+  inline bool WhileEachInst(const std::function<bool(Instruction*)>& f,
+                            bool run_on_debug_line_insts = false);
+  inline bool WhileEachInst(const std::function<bool(const Instruction*)>& f,
+                            bool run_on_debug_line_insts = false) const;
+
   // Runs the given function |f| on all operand ids.
   //
   // |f| should not transform an ID into 0, as 0 is an invalid ID.
   inline void ForEachId(const std::function<void(uint32_t*)>& f);
   inline void ForEachId(const std::function<void(const uint32_t*)>& f) const;
 
-  // Runs the given function |f| on all "in" operand ids
+  // Runs the given function |f| on all "in" operand ids.
   inline void ForEachInId(const std::function<void(uint32_t*)>& f);
   inline void ForEachInId(const std::function<void(const uint32_t*)>& f) const;
 
-  // Runs the given function |f| on all "in" operands
+  // Runs the given function |f| on all "in" operand ids. If |f| returns false,
+  // iteration is terminated and this function returns false.
+  inline bool WhileEachInId(const std::function<bool(uint32_t*)>& f);
+  inline bool WhileEachInId(
+      const std::function<bool(const uint32_t*)>& f) const;
+
+  // Runs the given function |f| on all "in" operands.
   inline void ForEachInOperand(const std::function<void(uint32_t*)>& f);
   inline void ForEachInOperand(
       const std::function<void(const uint32_t*)>& f) const;
 
+  // Runs the given function |f| on all "in" operands. If |f| returns false,
+  // iteration is terminated and this function return false.
+  inline bool WhileEachInOperand(const std::function<bool(uint32_t*)>& f);
+  inline bool WhileEachInOperand(
+      const std::function<bool(const uint32_t*)>& f) const;
+
   // Returns true if any operands can be labels
   inline bool HasLabels() const;
 
@@ -460,19 +481,46 @@ inline void Instruction::ToNop() {
   operands_.clear();
 }
 
+inline bool Instruction::WhileEachInst(
+    const std::function<bool(Instruction*)>& f, bool run_on_debug_line_insts) {
+  if (run_on_debug_line_insts) {
+    for (auto& dbg_line : dbg_line_insts_) {
+      if (!f(&dbg_line)) return false;
+    }
+  }
+  return f(this);
+}
+
+inline bool Instruction::WhileEachInst(
+    const std::function<bool(const Instruction*)>& f,
+    bool run_on_debug_line_insts) const {
+  if (run_on_debug_line_insts) {
+    for (auto& dbg_line : dbg_line_insts_) {
+      if (!f(&dbg_line)) return false;
+    }
+  }
+  return f(this);
+}
+
 inline void Instruction::ForEachInst(const std::function<void(Instruction*)>& f,
                                      bool run_on_debug_line_insts) {
-  if (run_on_debug_line_insts)
-    for (auto& dbg_line : dbg_line_insts_) f(&dbg_line);
-  f(this);
+  WhileEachInst(
+      [&f](Instruction* inst) {
+        f(inst);
+        return true;
+      },
+      run_on_debug_line_insts);
 }
 
 inline void Instruction::ForEachInst(
     const std::function<void(const Instruction*)>& f,
     bool run_on_debug_line_insts) const {
-  if (run_on_debug_line_insts)
-    for (auto& dbg_line : dbg_line_insts_) f(&dbg_line);
-  f(this);
+  WhileEachInst(
+      [&f](const Instruction* inst) {
+        f(inst);
+        return true;
+      },
+      run_on_debug_line_insts);
 }
 
 inline void Instruction::ForEachId(const std::function<void(uint32_t*)>& f) {
@@ -489,59 +537,99 @@ inline void Instruction::ForEachId(
     if (spvIsIdType(opnd.type)) f(&opnd.words[0]);
 }
 
-inline void Instruction::ForEachInId(const std::function<void(uint32_t*)>& f) {
+inline bool Instruction::WhileEachInId(
+    const std::function<bool(uint32_t*)>& f) {
   for (auto& opnd : operands_) {
     switch (opnd.type) {
       case SPV_OPERAND_TYPE_RESULT_ID:
       case SPV_OPERAND_TYPE_TYPE_ID:
         break;
       default:
-        if (spvIsIdType(opnd.type)) f(&opnd.words[0]);
+        if (spvIsIdType(opnd.type)) {
+          if (!f(&opnd.words[0])) return false;
+        }
         break;
     }
   }
+  return true;
 }
 
-inline void Instruction::ForEachInId(
-    const std::function<void(const uint32_t*)>& f) const {
+inline bool Instruction::WhileEachInId(
+    const std::function<bool(const uint32_t*)>& f) const {
   for (const auto& opnd : operands_) {
     switch (opnd.type) {
       case SPV_OPERAND_TYPE_RESULT_ID:
       case SPV_OPERAND_TYPE_TYPE_ID:
         break;
       default:
-        if (spvIsIdType(opnd.type)) f(&opnd.words[0]);
+        if (spvIsIdType(opnd.type)) {
+          if (!f(&opnd.words[0])) return false;
+        }
         break;
     }
   }
+  return true;
 }
 
-inline void Instruction::ForEachInOperand(
-    const std::function<void(uint32_t*)>& f) {
+inline void Instruction::ForEachInId(const std::function<void(uint32_t*)>& f) {
+  WhileEachInId([&f](uint32_t* id) {
+    f(id);
+    return true;
+  });
+}
+
+inline void Instruction::ForEachInId(
+    const std::function<void(const uint32_t*)>& f) const {
+  WhileEachInId([&f](const uint32_t* id) {
+    f(id);
+    return true;
+  });
+}
+
+inline bool Instruction::WhileEachInOperand(
+    const std::function<bool(uint32_t*)>& f) {
   for (auto& opnd : operands_) {
     switch (opnd.type) {
       case SPV_OPERAND_TYPE_RESULT_ID:
       case SPV_OPERAND_TYPE_TYPE_ID:
         break;
       default:
-        f(&opnd.words[0]);
+        if (!f(&opnd.words[0])) return false;
         break;
     }
   }
+  return true;
 }
 
-inline void Instruction::ForEachInOperand(
-    const std::function<void(const uint32_t*)>& f) const {
+inline bool Instruction::WhileEachInOperand(
+    const std::function<bool(const uint32_t*)>& f) const {
   for (const auto& opnd : operands_) {
     switch (opnd.type) {
       case SPV_OPERAND_TYPE_RESULT_ID:
       case SPV_OPERAND_TYPE_TYPE_ID:
         break;
       default:
-        f(&opnd.words[0]);
+        if (!f(&opnd.words[0])) return false;
         break;
     }
   }
+  return true;
+}
+
+inline void Instruction::ForEachInOperand(
+    const std::function<void(uint32_t*)>& f) {
+  WhileEachInOperand([&f](uint32_t* op) {
+    f(op);
+    return true;
+  });
+}
+
+inline void Instruction::ForEachInOperand(
+    const std::function<void(const uint32_t*)>& f) const {
+  WhileEachInOperand([&f](const uint32_t* op) {
+    f(op);
+    return true;
+  });
 }
 
 inline bool Instruction::HasLabels() const {
index 1e2da41..c332d8c 100644 (file)
@@ -184,13 +184,11 @@ bool IRContext::IsConsistent() {
   if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) {
     for (auto& func : *module()) {
       for (auto& block : func) {
-        bool ok = true;
-        block.ForEachInst([this, &block, &ok](ir::Instruction* inst) {
-          if (get_instr_block(inst) != &block) {
-            ok = false;
-          }
-        });
-        if (!ok) return false;
+        if (!block.WhileEachInst([this, &block](ir::Instruction* inst) {
+              if (get_instr_block(inst) != &block) return false;
+              return true;
+            }))
+          return false;
       }
     }
   }
index 81854f4..f686225 100644 (file)
@@ -117,36 +117,34 @@ void LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
 bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
     const ir::Instruction* acp) const {
   uint32_t inIdx = 0;
-  uint32_t nonConstCnt = 0;
-  acp->ForEachInId([&inIdx, &nonConstCnt, this](const uint32_t* tid) {
+  return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
     if (inIdx > 0) {
       ir::Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
-      if (opInst->opcode() != SpvOpConstant) ++nonConstCnt;
+      if (opInst->opcode() != SpvOpConstant) return false;
     }
     ++inIdx;
+    return true;
   });
-  return nonConstCnt == 0;
 }
 
 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
-  bool hasOnlySupportedRefs = true;
-  get_def_use_mgr()->ForEachUser(
-      ptrId, [this, &hasOnlySupportedRefs](ir::Instruction* user) {
+  if (get_def_use_mgr()->WhileEachUser(ptrId, [this](ir::Instruction* user) {
         SpvOp op = user->opcode();
         if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
           if (!HasOnlySupportedRefs(user->result_id())) {
-            hasOnlySupportedRefs = false;
+            return false;
           }
         } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
                    !IsNonTypeDecorate(op)) {
-          hasOnlySupportedRefs = false;
+          return false;
         }
-      });
-  if (hasOnlySupportedRefs) {
+        return true;
+      })) {
     supported_ref_ptrs_.insert(ptrId);
+    return true;
   }
-  return hasOnlySupportedRefs;
+  return false;
 }
 
 void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) {
index 6bb8ffd..b872535 100644 (file)
@@ -29,23 +29,22 @@ const uint32_t kStoreValIdInIdx = 1;
 
 bool LocalSingleBlockLoadStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) {
   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
-  bool hasOnlySupportedRefs = true;
-  get_def_use_mgr()->ForEachUser(
-      ptrId, [this, &hasOnlySupportedRefs](ir::Instruction* user) {
+  if (get_def_use_mgr()->WhileEachUser(ptrId, [this](ir::Instruction* user) {
         SpvOp op = user->opcode();
         if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
           if (!HasOnlySupportedRefs(user->result_id())) {
-            hasOnlySupportedRefs = false;
+            return false;
           }
         } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
-                   !this->IsNonTypeDecorate(op)) {
-          hasOnlySupportedRefs = false;
+                   !IsNonTypeDecorate(op)) {
+          return false;
         }
-      });
-  if (hasOnlySupportedRefs) {
+        return true;
+      })) {
     supported_ref_ptrs_.insert(ptrId);
+    return true;
   }
-  return hasOnlySupportedRefs;
+  return false;
 }
 
 bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
index c59a9dd..a460492 100644 (file)
@@ -32,23 +32,22 @@ const uint32_t kVariableInitIdInIdx = 1;
 
 bool LocalSingleStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) {
   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
-  bool hasOnlySupportedRefs = true;
-  get_def_use_mgr()->ForEachUser(
-      ptrId, [this, &hasOnlySupportedRefs](ir::Instruction* user) {
+  if (get_def_use_mgr()->WhileEachUser(ptrId, [this](ir::Instruction* user) {
         SpvOp op = user->opcode();
         if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
           if (!HasOnlySupportedRefs(user->result_id())) {
-            hasOnlySupportedRefs = false;
+            return false;
           }
         } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
                    !IsNonTypeDecorate(op)) {
-          hasOnlySupportedRefs = false;
+          return false;
         }
-      });
-  if (hasOnlySupportedRefs) {
+        return true;
+      })) {
     supported_ref_ptrs_.insert(ptrId);
+    return true;
   }
-  return hasOnlySupportedRefs;
+  return false;
 }
 
 void LocalSingleStoreElimPass::SingleStoreAnalyze(ir::Function* func) {
index 7177939..d0487ce 100644 (file)
@@ -66,12 +66,11 @@ bool MemPass::IsTargetType(const ir::Instruction* typeInst) const {
   }
   if (typeInst->opcode() != SpvOpTypeStruct) return false;
   // All struct members must be math type
-  int nonMathComp = 0;
-  typeInst->ForEachInId([&nonMathComp, this](const uint32_t* tid) {
+  return typeInst->WhileEachInId([this](const uint32_t* tid) {
     ir::Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
-    if (!IsTargetType(compTypeInst)) ++nonMathComp;
+    if (!IsTargetType(compTypeInst)) return false;
+    return true;
   });
-  return nonMathComp == 0;
 }
 
 bool MemPass::IsNonPtrAccessChain(const SpvOp opcode) const {
@@ -127,15 +126,13 @@ ir::Instruction* MemPass::GetPtr(ir::Instruction* ip, uint32_t* varId) {
 }
 
 bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const {
-  bool hasOnlyNamesAndDecorates = true;
-  get_def_use_mgr()->ForEachUser(
-      id, [this, &hasOnlyNamesAndDecorates](ir::Instruction* user) {
-        SpvOp op = user->opcode();
-        if (op != SpvOpName && !IsNonTypeDecorate(op)) {
-          hasOnlyNamesAndDecorates = false;
-        }
-      });
-  return hasOnlyNamesAndDecorates;
+  return get_def_use_mgr()->WhileEachUser(id, [this](ir::Instruction* user) {
+    SpvOp op = user->opcode();
+    if (op != SpvOpName && !IsNonTypeDecorate(op)) {
+      return false;
+    }
+    return true;
+  });
 }
 
 void MemPass::KillAllInsts(ir::BasicBlock* bp, bool killLabel) {
@@ -147,21 +144,20 @@ void MemPass::KillAllInsts(ir::BasicBlock* bp, bool killLabel) {
 }
 
 bool MemPass::HasLoads(uint32_t varId) const {
-  bool hasLoads = false;
-  get_def_use_mgr()->ForEachUser(varId, [this,
-                                         &hasLoads](ir::Instruction* user) {
+  return !get_def_use_mgr()->WhileEachUser(varId, [this](
+                                                      ir::Instruction* user) {
     SpvOp op = user->opcode();
     // TODO(): The following is slightly conservative. Could be
     // better handling of non-store/name.
     if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
       if (HasLoads(user->result_id())) {
-        hasLoads = true;
+        return false;
       }
     } else if (op != SpvOpStore && op != SpvOpName && !IsNonTypeDecorate(op)) {
-      hasLoads = true;
+      return false;
     }
+    return true;
   });
-  return hasLoads;
 }
 
 bool MemPass::IsLiveVar(uint32_t varId) const {
@@ -238,16 +234,14 @@ MemPass::MemPass() {}
 
 bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
   if (supported_ref_vars_.find(varId) != supported_ref_vars_.end()) return true;
-  bool hasOnlySupportedRefs = true;
-  get_def_use_mgr()->ForEachUser(
-      varId, [this, &hasOnlySupportedRefs](ir::Instruction* user) {
-        SpvOp op = user->opcode();
-        if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
-            !IsNonTypeDecorate(op)) {
-          hasOnlySupportedRefs = false;
-        }
-      });
-  return hasOnlySupportedRefs;
+  return get_def_use_mgr()->WhileEachUser(varId, [this](ir::Instruction* user) {
+    SpvOp op = user->opcode();
+    if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
+        !IsNonTypeDecorate(op)) {
+      return false;
+    }
+    return true;
+  });
 }
 
 void MemPass::InitSSARewrite(ir::Function* func) {
index 3cdf13f..3af7720 100644 (file)
@@ -126,14 +126,12 @@ bool PrivateToLocalPass::IsValidUse(const ir::Instruction* inst) const {
     case SpvOpLoad:
     case SpvOpStore:
       return true;
-    case SpvOpAccessChain: {
-      bool valid = true;
-      context()->get_def_use_mgr()->ForEachUser(
-          inst->result_id(), [this, &valid](const ir::Instruction* use) {
-            valid &= IsValidUse(use);
+    case SpvOpAccessChain:
+      return context()->get_def_use_mgr()->WhileEachUser(
+          inst, [this](const ir::Instruction* user) {
+            if (!IsValidUse(user)) return false;
+            return true;
           });
-      return valid;
-    }
     case SpvOpName:
       return true;
     default:
index 0fe4d6a..a046e55 100644 (file)
@@ -140,13 +140,14 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) {
     // For regular instructions, check if the defining instruction of each
     // operand needs to be simulated again.  If so, then this instruction should
     // also be simulated again.
-    instr->ForEachInId([&has_operands_to_simulate, this](const uint32_t* use) {
-      ir::Instruction* def_instr = get_def_use_mgr()->GetDef(*use);
-      if (ShouldSimulateAgain(def_instr)) {
-        has_operands_to_simulate = true;
-        return;
-      }
-    });
+    has_operands_to_simulate =
+        !instr->WhileEachInId([this](const uint32_t* use) {
+          ir::Instruction* def_instr = get_def_use_mgr()->GetDef(*use);
+          if (ShouldSimulateAgain(def_instr)) {
+            return false;
+          }
+          return true;
+        });
   }
 
   if (!has_operands_to_simulate) {
index 23b0ce5..8188a41 100644 (file)
@@ -77,38 +77,36 @@ bool ScalarReplacementPass::ReplaceVariable(
   std::vector<ir::Instruction*> replacements;
   CreateReplacementVariables(inst, &replacements);
 
-  bool ok = true;
   std::vector<ir::Instruction*> dead;
   dead.push_back(inst);
-  get_def_use_mgr()->ForEachUser(
-      inst, [this, &ok, &replacements, &dead](ir::Instruction* user) {
-        if (!ir::IsAnnotationInst(user->opcode())) {
-          switch (user->opcode()) {
-            case SpvOpLoad:
-              ReplaceWholeLoad(user, replacements);
-              dead.push_back(user);
-              break;
-            case SpvOpStore:
-              ReplaceWholeStore(user, replacements);
-              dead.push_back(user);
-              break;
-            case SpvOpAccessChain:
-            case SpvOpInBoundsAccessChain:
-              ok &= ReplaceAccessChain(user, replacements);
-              dead.push_back(user);
-              break;
-            case SpvOpName:
-            case SpvOpMemberName:
-              break;
-            default:
-              assert(false && "Unexpected opcode");
-              break;
-          }
-        }
-      });
-
-  // There was an illegal access.
-  if (!ok) return false;
+  if (!get_def_use_mgr()->WhileEachUser(
+          inst, [this, &replacements, &dead](ir::Instruction* user) {
+            if (!ir::IsAnnotationInst(user->opcode())) {
+              switch (user->opcode()) {
+                case SpvOpLoad:
+                  ReplaceWholeLoad(user, replacements);
+                  dead.push_back(user);
+                  break;
+                case SpvOpStore:
+                  ReplaceWholeStore(user, replacements);
+                  dead.push_back(user);
+                  break;
+                case SpvOpAccessChain:
+                case SpvOpInBoundsAccessChain:
+                  if (!ReplaceAccessChain(user, replacements)) return false;
+                  dead.push_back(user);
+                  break;
+                case SpvOpName:
+                case SpvOpMemberName:
+                  break;
+                default:
+                  assert(false && "Unexpected opcode");
+                  break;
+              }
+            }
+            return true;
+          }))
+    return false;
 
   // Clean up some dead code.
   while (!dead.empty()) {
@@ -127,7 +125,7 @@ bool ScalarReplacementPass::ReplaceVariable(
     }
   }
 
-  return ok;
+  return true;
 }
 
 void ScalarReplacementPass::ReplaceWholeLoad(
index 1b0cccd..3c66260 100644 (file)
@@ -137,13 +137,16 @@ ir::Instruction* GetSpecIdTargetFromDecorationGroup(
   // the first OpGroupDecoration instruction that uses the given decoration
   // group.
   ir::Instruction* group_decorate_inst = nullptr;
-  def_use_mgr->ForEachUser(&decoration_group_defining_inst,
-                           [&group_decorate_inst](ir::Instruction* user) {
-                             if (user->opcode() == SpvOp::SpvOpGroupDecorate) {
-                               group_decorate_inst = user;
-                             }
-                           });
-  if (!group_decorate_inst) return nullptr;
+  if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst,
+                                 [&group_decorate_inst](ir::Instruction* user) {
+                                   if (user->opcode() ==
+                                       SpvOp::SpvOpGroupDecorate) {
+                                     group_decorate_inst = user;
+                                     return false;
+                                   }
+                                   return true;
+                                 }))
+    return nullptr;
 
   // Scan through the target ids of the OpGroupDecorate instruction. There
   // should be only one spec constant target consumes the SpecId decoration.