Validator structured flow checks: back-edge, constructs
authorUmar Arshad <umar@arrayfire.com>
Sat, 4 Jun 2016 01:24:24 +0000 (21:24 -0400)
committerDavid Neto <dneto@google.com>
Wed, 22 Jun 2016 15:51:19 +0000 (11:51 -0400)
Skip structured control flow chekcs for non-shader capability.

Fix infinite loop in dominator algorithm when there's an
unreachable block.

source/val/BasicBlock.cpp
source/val/BasicBlock.h
source/val/Construct.cpp
source/val/Construct.h
source/val/Function.cpp
source/val/Function.h
source/val/ValidationState.h
source/validate.h
source/validate_cfg.cpp
test/Validate.CFG.cpp
test/ValidateFixtures.cpp

index 55be3b8..4325736 100644 (file)
@@ -35,25 +35,38 @@ namespace libspirv {
 BasicBlock::BasicBlock(uint32_t id)
     : id_(id),
       immediate_dominator_(nullptr),
+      immediate_post_dominator_(nullptr),
       predecessors_(),
       successors_(),
+      type_(0),
       reachable_(false) {}
 
 void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) {
   immediate_dominator_ = dom_block;
 }
 
+void BasicBlock::SetImmediatePostDominator(BasicBlock* pdom_block) {
+  immediate_post_dominator_ = pdom_block;
+}
+
 const BasicBlock* BasicBlock::GetImmediateDominator() const {
   return immediate_dominator_;
 }
 
+const BasicBlock* BasicBlock::GetImmediatePostDominator() const {
+  return immediate_post_dominator_;
+}
+
 BasicBlock* BasicBlock::GetImmediateDominator() { return immediate_dominator_; }
+BasicBlock* BasicBlock::GetImmediatePostDominator() {
+  return immediate_post_dominator_;
+}
 
-void BasicBlock::RegisterSuccessors(vector<BasicBlock*> next_blocks) {
+void BasicBlock::RegisterSuccessors(const vector<BasicBlock*>& next_blocks) {
   for (auto& block : next_blocks) {
     block->predecessors_.push_back(this);
     successors_.push_back(block);
-    if (block->reachable_ == false) block->set_reachability(reachable_);
+    if (block->reachable_ == false) block->set_reachable(reachable_);
   }
 }
 
@@ -63,24 +76,29 @@ void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) {
 }
 
 BasicBlock::DominatorIterator::DominatorIterator() : current_(nullptr) {}
-BasicBlock::DominatorIterator::DominatorIterator(const BasicBlock* block)
-    : current_(block) {}
+
+BasicBlock::DominatorIterator::DominatorIterator(
+    const BasicBlock* block,
+    std::function<const BasicBlock*(const BasicBlock*)> dominator_func)
+    : current_(block), dom_func_(dominator_func) {}
 
 BasicBlock::DominatorIterator& BasicBlock::DominatorIterator::operator++() {
-  if (current_ == current_->GetImmediateDominator()) {
+  if (current_ == dom_func_(current_)) {
     current_ = nullptr;
   } else {
-    current_ = current_->GetImmediateDominator();
+    current_ = dom_func_(current_);
   }
   return *this;
 }
 
 const BasicBlock::DominatorIterator BasicBlock::dom_begin() const {
-  return DominatorIterator(this);
+  return DominatorIterator(
+      this, [](const BasicBlock* b) { return b->GetImmediateDominator(); });
 }
 
 BasicBlock::DominatorIterator BasicBlock::dom_begin() {
-  return DominatorIterator(this);
+  return DominatorIterator(
+      this, [](const BasicBlock* b) { return b->GetImmediateDominator(); });
 }
 
 const BasicBlock::DominatorIterator BasicBlock::dom_end() const {
@@ -91,6 +109,24 @@ BasicBlock::DominatorIterator BasicBlock::dom_end() {
   return DominatorIterator();
 }
 
+const BasicBlock::DominatorIterator BasicBlock::pdom_begin() const {
+  return DominatorIterator(
+      this, [](const BasicBlock* b) { return b->GetImmediatePostDominator(); });
+}
+
+BasicBlock::DominatorIterator BasicBlock::pdom_begin() {
+  return DominatorIterator(
+    this, [](const BasicBlock* b) { return b->GetImmediatePostDominator(); });
+}
+
+const BasicBlock::DominatorIterator BasicBlock::pdom_end() const {
+  return DominatorIterator();
+}
+
+BasicBlock::DominatorIterator BasicBlock::pdom_end() {
+  return DominatorIterator();
+}
+
 bool operator==(const BasicBlock::DominatorIterator& lhs,
                 const BasicBlock::DominatorIterator& rhs) {
   return lhs.current_ == rhs.current_;
index 0cdc459..8818faa 100644 (file)
 #include "spirv/1.1/spirv.h"
 
 #include <cstdint>
+
+#include <bitset>
+#include <functional>
 #include <vector>
 
 namespace libspirv {
 
+enum BlockType : uint32_t {
+  kBlockTypeUndefined,
+  kBlockTypeHeader,
+  kBlockTypeLoop,
+  kBlockTypeMerge,
+  kBlockTypeBreak,
+  kBlockTypeContinue,
+  kBlockTypeReturn,
+  kBlockTypeCOUNT  ///< Total number of block types. (must be the last element)
+};
+
 // This class represents a basic block in a SPIR-V module
 class BasicBlock {
  public:
@@ -61,27 +75,53 @@ class BasicBlock {
   /// Returns the successors of the BasicBlock
   std::vector<BasicBlock*>* get_successors() { return &successors_; }
 
-  /// Returns true if the  block should be reachable in the CFG
+  /// Returns true if the block is reachable in the CFG
   bool is_reachable() const { return reachable_; }
 
-  void set_reachability(bool reachability) { reachable_ = reachability; }
+  /// Returns true if BasicBlock is of the given type
+  bool is_type(BlockType type) const {
+    if (type == kBlockTypeUndefined) return type_.none();
+    return type_.test(type);
+  }
+
+  /// Sets the reachability of the basic block in the CFG
+  void set_reachable(bool reachability) { reachable_ = reachability; }
+
+  /// Sets the type of the BasicBlock
+  void set_type(BlockType type) {
+    if (type == kBlockTypeUndefined)
+      type_.reset();
+    else
+      type_.set(type);
+  }
 
   /// Sets the immedate dominator of this basic block
   ///
   /// @param[in] dom_block The dominator block
   void SetImmediateDominator(BasicBlock* dom_block);
 
+  /// Sets the immedate post dominator of this basic block
+  ///
+  /// @param[in] pdom_block The post dominator block
+  void SetImmediatePostDominator(BasicBlock* pdom_block);
+
   /// Returns the immedate dominator of this basic block
   BasicBlock* GetImmediateDominator();
 
   /// Returns the immedate dominator of this basic block
   const BasicBlock* GetImmediateDominator() const;
 
+  /// Returns the immedate post dominator of this basic block
+  BasicBlock* GetImmediatePostDominator();
+
+  /// Returns the immedate post dominator of this basic block
+  const BasicBlock* GetImmediatePostDominator() const;
+
   /// Ends the block without a successor
   void RegisterBranchInstruction(SpvOp branch_instruction);
 
   /// Adds @p next BasicBlocks as successors of this BasicBlock
-  void RegisterSuccessors(std::vector<BasicBlock*> next = {});
+  void RegisterSuccessors(const std::vector<BasicBlock*>& next = {});
 
   /// Returns true if the id of the BasicBlock matches
   bool operator==(const BasicBlock& other) const { return other.id_ == id_; }
@@ -91,7 +131,7 @@ class BasicBlock {
 
   /// @brief A BasicBlock dominator iterator class
   ///
-  /// This iterator will iterate over the dominators of the block
+  /// This iterator will iterate over the (post)dominators of the block
   class DominatorIterator
       : public std::iterator<std::forward_iterator_tag, BasicBlock*> {
    public:
@@ -104,8 +144,12 @@ class BasicBlock {
     /// @brief Constructs an iterator for the given block which points to
     ///        @p block
     ///
-    /// @param block The block which is referenced by the iterator
-    explicit DominatorIterator(const BasicBlock* block);
+    /// @param block          The block which is referenced by the iterator
+    /// @param dominator_func This function will be called to get the immediate
+    ///                       (post)dominator of the current block
+    DominatorIterator(
+        const BasicBlock* block,
+        std::function<const BasicBlock*(const BasicBlock*)> dominator_func);
 
     /// @brief Advances the iterator
     DominatorIterator& operator++();
@@ -118,16 +162,36 @@ class BasicBlock {
 
    private:
     const BasicBlock* current_;
+    std::function<const BasicBlock*(const BasicBlock*)> dom_func_;
   };
 
-  /// Returns an iterator which points to the current block
+  /// Returns a dominator iterator which points to the current block
   const DominatorIterator dom_begin() const;
+
+  /// Returns a dominator iterator which points to the current block
   DominatorIterator dom_begin();
 
-  /// Returns an iterator which points to one element past the first block
+  /// Returns a dominator iterator which points to one element past the first
+  /// block
   const DominatorIterator dom_end() const;
+
+  /// Returns a dominator iterator which points to one element past the first
+  /// block
   DominatorIterator dom_end();
 
+  /// Returns a post dominator iterator which points to the current block
+  const DominatorIterator pdom_begin() const;
+  /// Returns a post dominator iterator which points to the current block
+  DominatorIterator pdom_begin();
+
+  /// Returns a post dominator iterator which points to one element past the
+  /// last block
+  const DominatorIterator pdom_end() const;
+
+  /// Returns a post dominator iterator which points to one element past the
+  /// last block
+  DominatorIterator pdom_end();
+
  private:
   /// Id of the BasicBlock
   const uint32_t id_;
@@ -135,12 +199,19 @@ class BasicBlock {
   /// Pointer to the immediate dominator of the BasicBlock
   BasicBlock* immediate_dominator_;
 
+  /// Pointer to the immediate dominator of the BasicBlock
+  BasicBlock* immediate_post_dominator_;
+
   /// The set of predecessors of the BasicBlock
   std::vector<BasicBlock*> predecessors_;
 
   /// The set of successors of the BasicBlock
   std::vector<BasicBlock*> successors_;
 
+  /// The type of the block
+  std::bitset<kBlockTypeCOUNT - 1> type_;
+
+  /// True if the block is reachable in the CFG
   bool reachable_;
 };
 
index 91140bf..6bd1b5d 100644 (file)
 
 #include "val/Construct.h"
 
+#include <cassert>
+#include <cstddef>
+
 namespace libspirv {
 
-Construct::Construct(BasicBlock* header_block, BasicBlock* merge_block,
-                     BasicBlock* continue_block)
-    : header_block_(header_block),
-      merge_block_(merge_block),
-      continue_block_(continue_block) {}
+Construct::Construct(ConstructType type, BasicBlock* entry,
+                     BasicBlock* exit, std::vector<Construct*> constructs)
+    : type_(type),
+      corresponding_constructs_(constructs),
+      entry_block_(entry),
+      exit_block_(exit) {}
+
+ConstructType Construct::get_type() const { return type_; }
+
+const std::vector<Construct*>& Construct::get_corresponding_constructs() const {
+  return corresponding_constructs_;
+}
+std::vector<Construct*>& Construct::get_corresponding_constructs() {
+  return corresponding_constructs_;
+}
+
+bool ValidateConstructSize(ConstructType type, size_t size) {
+  switch (type) {
+    case ConstructType::kSelection: return size == 0;
+    case ConstructType::kContinue:  return size == 1;
+    case ConstructType::kLoop:      return size == 1;
+    case ConstructType::kCase:      return size >= 1;
+    default: assert(1 == 0 && "Type not defined");
+  }
+  return false;
+}
+
+void Construct::set_corresponding_constructs(
+    std::vector<Construct*> constructs) {
+  assert(ValidateConstructSize(type_, constructs.size()));
+  corresponding_constructs_ = constructs;
+}
+
+const BasicBlock* Construct::get_entry() const { return entry_block_; }
+BasicBlock* Construct::get_entry() { return entry_block_; }
 
-const BasicBlock* Construct::get_header() const { return header_block_; }
-const BasicBlock* Construct::get_merge() const { return merge_block_; }
-const BasicBlock* Construct::get_continue() const { return continue_block_; }
+const BasicBlock* Construct::get_exit() const { return exit_block_; }
+BasicBlock* Construct::get_exit() { return exit_block_; }
 
-BasicBlock* Construct::get_header() { return header_block_; }
-BasicBlock* Construct::get_merge() { return merge_block_; }
-BasicBlock* Construct::get_continue() { return continue_block_; }
+void Construct::set_exit(BasicBlock* exit_block) {
+  exit_block_ = exit_block;
 }
+}  /// namespace libspirv
index ef5fae4..b87c99a 100644 (file)
 #define LIBSPIRV_VAL_CONSTRUCT_H_
 
 #include <cstdint>
+#include <vector>
 
 namespace libspirv {
 
+enum class ConstructType {
+  kNone,
+  /// The set of blocks dominated by a selection header, minus the set of blocks
+  /// dominated by the header's merge block
+  kSelection,
+  /// The set of blocks dominated by an OpLoopMerge's Continue Target and post
+  /// dominated by the corresponding back
+  kContinue,
+  ///  The set of blocks dominated by a loop header, minus the set of blocks
+  ///  dominated by the loop's merge block, minus the loop's corresponding
+  ///  continue construct
+  kLoop,
+  ///  The set of blocks dominated by an OpSwitch's Target or Default, minus the
+  ///  set of blocks dominated by the OpSwitch's merge block (this construct is
+  ///  only defined for those OpSwitch Target or Default that are not equal to
+  ///  the OpSwitch's corresponding merge block)
+  kCase
+};
+
 class BasicBlock;
 
 /// @brief This class tracks the CFG constructs as defined in the SPIR-V spec
 class Construct {
  public:
-  Construct(BasicBlock* header_block, BasicBlock* merge_block,
-            BasicBlock* continue_block = nullptr);
+  Construct(ConstructType type, BasicBlock* dominator,
+            BasicBlock* exit = nullptr,
+            std::vector<Construct*> constructs = {});
+
+  /// Returns the type of the construct
+  ConstructType get_type() const;
+
+  const std::vector<Construct*>& get_corresponding_constructs() const;
+  std::vector<Construct*>& get_corresponding_constructs();
+  void set_corresponding_constructs(std::vector<Construct*> constructs);
+
+  /// Returns the dominator block of the construct.
+  ///
+  /// This is usually the header block or the first block of the construct.
+  const BasicBlock* get_entry() const;
 
-  const BasicBlock* get_header() const;
-  const BasicBlock* get_merge() const;
-  const BasicBlock* get_continue() const;
+  /// Returns the dominator block of the construct.
+  ///
+  /// This is usually the header block or the first block of the construct.
+  BasicBlock* get_entry();
 
-  BasicBlock* get_header();
-  BasicBlock* get_merge();
-  BasicBlock* get_continue();
+  /// Returns the exit block of the construct.
+  ///
+  /// For a continue construct it is  the backedge block of the corresponding
+  /// loop construct. For the case  construct it is the block that branches to
+  /// the OpSwitch merge block or  other case blocks. Otherwise it is the merge
+  /// block of the corresponding  header block
+  const BasicBlock* get_exit() const;
+
+  /// Returns the exit block of the construct.
+  ///
+  /// For a continue construct it is  the backedge block of the corresponding
+  /// loop construct. For the case  construct it is the block that branches to
+  /// the OpSwitch merge block or  other case blocks. Otherwise it is the merge
+  /// block of the corresponding  header block
+  BasicBlock* get_exit();
+
+  /// Sets the exit block for this construct. This is useful for continue
+  /// constructs which do not know the back-edge block during construction
+  void set_exit(BasicBlock* exit_block);
 
  private:
-  BasicBlock* header_block_;    ///< The header block of a loop or selection
-  BasicBlock* merge_block_;     ///< The merge block of a loop or selection
-  BasicBlock* continue_block_;  ///< The continue block of a loop block
+  /// The type of the construct
+  ConstructType type_;
+
+  /// These are the constructs that are related to this construct. These
+  /// constructs can be the continue construct, for the corresponding loop
+  /// construct, the case construct that are part of the same OpSwitch
+  /// instruction
+  ///
+  /// Here is a table that describes what constructs are included in
+  /// @p corresponding_constructs_
+  /// | this construct | corresponding construct          |
+  /// |----------------|----------------------------------|
+  /// | loop           | continue                         |
+  /// | continue       | loop                             |
+  /// | case           | other cases in the same OpSwitch |
+  ///
+  /// kContinue and kLoop constructs will always have corresponding
+  /// constructs even if they are represented by the same block
+  std::vector<Construct*> corresponding_constructs_;
+
+  /// @brief Dominator block for the construct
+  ///
+  /// The dominator block for the construct. Depending on the construct this may
+  /// be a selection header, a continue target of a loop, a loop header or a
+  /// Target or Default block of a switch
+  BasicBlock* entry_block_;
+
+  /// @brief Exiting block for the construct
+  ///
+  /// The exit block for the construct. This can be a merge block for the loop
+  /// and selection constructs, a back-edge block for a continue construct, or
+  /// the branching block for the case construct
+  BasicBlock* exit_block_;
 };
 
 }  /// namespace libspirv
index 3756949..d2c89fb 100644 (file)
 #include <cassert>
 
 #include <algorithm>
+#include <utility>
 
 #include "val/BasicBlock.h"
 #include "val/Construct.h"
 #include "val/ValidationState.h"
 
+using std::ignore;
 using std::list;
+using std::make_pair;
+using std::pair;
 using std::string;
+using std::tie;
 using std::vector;
 
 namespace libspirv {
@@ -66,6 +71,7 @@ Function::Function(uint32_t id, uint32_t result_type_id,
       declaration_type_(FunctionDecl::kFunctionDeclUnknown),
       blocks_(),
       current_block_(nullptr),
+      pseudo_exit_block_(kInvalidId),
       cfg_constructs_(),
       variable_ids_(),
       parameter_ids_() {}
@@ -93,15 +99,33 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
                                          uint32_t continue_id) {
   RegisterBlock(merge_id, false);
   RegisterBlock(continue_id, false);
-  cfg_constructs_.emplace_back(get_current_block(), &blocks_.at(merge_id),
-                               &blocks_.at(continue_id));
+  BasicBlock& merge_block = blocks_.at(merge_id);
+  BasicBlock& continue_block = blocks_.at(continue_id);
+  assert(current_block_ &&
+         "RegisterLoopMerge must be called when called within a block");
+
+  current_block_->set_type(kBlockTypeLoop);
+  merge_block.set_type(kBlockTypeMerge);
+  continue_block.set_type(kBlockTypeContinue);
+  cfg_constructs_.emplace_back(ConstructType::kLoop, current_block_,
+                               &merge_block);
+  Construct& loop_construct = cfg_constructs_.back();
+  cfg_constructs_.emplace_back(ConstructType::kContinue, &continue_block);
+  Construct& continue_construct = cfg_constructs_.back();
+  continue_construct.set_corresponding_constructs({&loop_construct});
+  loop_construct.set_corresponding_constructs({&continue_construct});
 
   return SPV_SUCCESS;
 }
 
 spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) {
   RegisterBlock(merge_id, false);
-  cfg_constructs_.emplace_back(get_current_block(), &blocks_.at(merge_id));
+  BasicBlock& merge_block = blocks_.at(merge_id);
+  current_block_->set_type(kBlockTypeHeader);
+  merge_block.set_type(kBlockTypeMerge);
+
+  cfg_constructs_.emplace_back(ConstructType::kSelection, get_current_block(),
+                               &merge_block);
   return SPV_SUCCESS;
 }
 
@@ -152,7 +176,7 @@ spv_result_t Function::RegisterBlock(uint32_t id, bool is_definition) {
     undefined_blocks_.erase(id);
     current_block_ = &inserted_block->second;
     ordered_blocks_.push_back(current_block_);
-    if (IsFirstBlock(id)) current_block_->set_reachability(true);
+    if (IsFirstBlock(id)) current_block_->set_reachable(true);
   } else if (success) {  // Block doesn't exsist but this is not a definition
     undefined_blocks_.insert(id);
   }
@@ -182,6 +206,11 @@ void Function::RegisterBlockEnd(vector<uint32_t> next_list,
     next_blocks.push_back(&inserted_block->second);
   }
 
+  if (branch_instruction == SpvOpReturn ||
+      branch_instruction == SpvOpReturnValue) {
+    assert(next_blocks.empty());
+    next_blocks.push_back(&pseudo_exit_block_);
+  }
   current_block_->RegisterBranchInstruction(branch_instruction);
   current_block_->RegisterSuccessors(next_blocks);
   current_block_ = nullptr;
@@ -202,6 +231,11 @@ vector<BasicBlock*>& Function::get_blocks() { return ordered_blocks_; }
 const BasicBlock* Function::get_current_block() const { return current_block_; }
 BasicBlock* Function::get_current_block() { return current_block_; }
 
+BasicBlock* Function::get_pseudo_exit_block() { return &pseudo_exit_block_; }
+const BasicBlock* Function::get_pseudo_exit_block() const {
+  return &pseudo_exit_block_;
+}
+
 const list<Construct>& Function::get_constructs() const {
   return cfg_constructs_;
 }
@@ -216,17 +250,32 @@ BasicBlock* Function::get_first_block() {
   return ordered_blocks_[0];
 }
 
-bool Function::IsMergeBlock(uint32_t merge_block_id) const {
-  const auto b = blocks_.find(merge_block_id);
+bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const {
+  bool ret = false;
+  const BasicBlock* block;
+  tie(block, ignore) = GetBlock(merge_block_id);
+  if (block) {
+    ret = block->is_type(type);
+  }
+  return ret;
+}
+
+pair<const BasicBlock*, bool> Function::GetBlock(uint32_t id) const {
+  const auto b = blocks_.find(id);
   if (b != end(blocks_)) {
-    return cfg_constructs_.end() !=
-           find_if(begin(cfg_constructs_), end(cfg_constructs_),
-                   [&](const Construct& construct) {
-                     return construct.get_merge() == &b->second;
-                   });
+    const BasicBlock* block = &(b->second);
+    bool defined =
+        undefined_blocks_.find(block->get_id()) == end(undefined_blocks_);
+    return make_pair(block, defined);
   } else {
-    return false;
+    return make_pair(nullptr, false);
   }
 }
 
+pair<BasicBlock*, bool> Function::GetBlock(uint32_t id) {
+  const BasicBlock* out;
+  bool defined;
+  tie(out, defined) = const_cast<const Function*>(this)->GetBlock(id);
+  return make_pair(const_cast<BasicBlock*>(out), defined);
+}
 }  /// namespace libspirv
index c9f0746..4344fe9 100644 (file)
@@ -28,9 +28,9 @@
 #define LIBSPIRV_VAL_FUNCTION_H_
 
 #include <list>
-#include <vector>
-#include <unordered_set>
 #include <unordered_map>
+#include <unordered_set>
+#include <vector>
 
 #include "spirv/1.1/spirv.h"
 #include "spirv-tools/libspirv.h"
@@ -100,12 +100,19 @@ class Function {
   void RegisterBlockEnd(std::vector<uint32_t> successors_list,
                         SpvOp branch_instruction);
 
-  /// Returns true if the \p merge_block_id is a merge block
-  bool IsMergeBlock(uint32_t merge_block_id) const;
-
-  /// Returns true if the \p id is the first block of this function
+  /// Returns true if the \p id block is the first block of this function
   bool IsFirstBlock(uint32_t id) const;
 
+  /// Returns true if the \p merge_block_id is a BlockType of \p type
+  bool IsBlockType(uint32_t merge_block_id, BlockType type) const;
+
+  /// Returns a pair consisting of the BasicBlock with \p id and a bool
+  /// which is true if the block has been defined, and false if it is
+  /// declared but not defined. This function will return nullptr if the
+  /// \p id was not declared and not defined at the current point in the binary
+  std::pair<const BasicBlock*, bool> GetBlock(uint32_t id) const;
+  std::pair<BasicBlock*, bool> GetBlock(uint32_t id);
+
   /// Returns the first block of the current function
   const BasicBlock* get_first_block() const;
 
@@ -142,6 +149,12 @@ class Function {
   /// Returns the block that is currently being parsed in the binary
   const BasicBlock* get_current_block() const;
 
+  /// Returns the psudo exit block
+  BasicBlock* get_pseudo_exit_block();
+
+  /// Returns the psudo exit block
+  const BasicBlock* get_pseudo_exit_block() const;
+
   /// Prints a GraphViz digraph of the CFG of the current funciton
   void printDotGraph() const;
 
@@ -179,6 +192,9 @@ class Function {
   /// The block that is currently being parsed
   BasicBlock* current_block_;
 
+  /// A pseudo exit block that is the successor to all return blocks
+  BasicBlock pseudo_exit_block_;
+
   /// The constructs that are available in this function
   std::list<Construct> cfg_constructs_;
 
@@ -191,5 +207,4 @@ class Function {
 
 }  /// namespace libspirv
 
-
 #endif  /// LIBSPIRV_VAL_FUNCTION_H_
index 9002634..9bfab09 100644 (file)
@@ -42,6 +42,9 @@
 
 namespace libspirv {
 
+// Universal Limit of ResultID + 1
+static const uint32_t kInvalidId = 0x400000;
+
 // Info about a result ID.
 typedef struct spv_id_info_t {
   /// Id value.
index 6f2b89e..74b3506 100644 (file)
@@ -29,6 +29,7 @@
 
 #include <algorithm>
 #include <array>
+#include <functional>
 #include <list>
 #include <map>
 #include <string>
@@ -52,16 +53,25 @@ namespace libspirv {
 
 class ValidationState_t;
 
-/// @brief Calculates dominator edges of a root basic block
+/// A function that returns a vector of BasicBlocks given a BasicBlock. Used to
+/// get the successor and predecessor nodes of a CFG block
+using get_blocks_func =
+    std::function<const std::vector<BasicBlock*>*(const BasicBlock*)>;
+
+/// @brief Calculates dominator edges for a set of blocks
 ///
-/// This function calculates the dominator edges form a root BasicBlock. Uses
-/// the dominator algorithm by Cooper et al.
+/// This function calculates the dominator edges for a set of blocks in the CFG.
+/// Uses the dominator algorithm by Cooper et al.
 ///
-/// @param[in] first_block the root or entry BasicBlock of a function
+/// @param[in] postorder        A vector of blocks in post order traversal order
+///                             in a CFG
+/// @param[in] predecessor_func Function used to get the predecessor nodes of a
+///                             block
 ///
 /// @return a set of dominator edges represented as a pair of blocks
 std::vector<std::pair<BasicBlock*, BasicBlock*>> CalculateDominators(
-    const BasicBlock& first_block);
+    const std::vector<const BasicBlock*>& postorder,
+    get_blocks_func predecessor_func);
 
 /// @brief Performs the Control Flow Graph checks
 ///
@@ -76,8 +86,11 @@ spv_result_t PerformCfgChecks(ValidationState_t& _);
 /// provided by the @p dom_edges parameter
 ///
 /// @param[in,out] dom_edges The edges of the dominator tree
+/// @param[in] set_func This function will be called to updated the Immediate
+///                     dominator
 void UpdateImmediateDominators(
-    std::vector<std::pair<BasicBlock*, BasicBlock*>>& dom_edges);
+    const std::vector<std::pair<BasicBlock*, BasicBlock*>>& dom_edges,
+    std::function<void(BasicBlock*, BasicBlock*)> set_func);
 
 /// @brief Prints all of the dominators of a BasicBlock
 ///
index a1e86bb..b687661 100644 (file)
@@ -30,6 +30,8 @@
 
 #include <algorithm>
 #include <functional>
+#include <set>
+#include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
 using std::find;
 using std::function;
 using std::get;
+using std::ignore;
 using std::make_pair;
 using std::numeric_limits;
 using std::pair;
+using std::set;
+using std::string;
+using std::tie;
 using std::transform;
 using std::unordered_map;
 using std::unordered_set;
@@ -61,8 +67,6 @@ using bb_ptr = BasicBlock*;
 using cbb_ptr = const BasicBlock*;
 using bb_iter = vector<BasicBlock*>::const_iterator;
 
-using get_blocks_func = function<const vector<BasicBlock*>*(const BasicBlock*)>;
-
 struct block_info {
   cbb_ptr block;  ///< pointer to the block
   bb_iter iter;   ///< Iterator to the current child node being processed
@@ -92,8 +96,8 @@ bool FindInWorkList(const vector<block_info>& work_list, uint32_t id) {
 /// @param[in] entry The root BasicBlock of a CFG tree
 /// @param[in] successor_func  A function which will return a pointer to the
 ///                            successor nodes
-/// @param[in] preorder   A function that will be called for every block in a CFG
-///                       following preorder traversal semantics
+/// @param[in] preorder   A function that will be called for every block in a
+///                       CFG following preorder traversal semantics
 /// @param[in] postorder  A function that will be called for every block in a
 ///                       CFG following postorder traversal semantics
 /// @param[in] backedge   A function that will be called when a backedge is
@@ -143,45 +147,44 @@ const vector<BasicBlock*>* successor(const BasicBlock* b) {
   return b->get_successors();
 }
 
+const vector<BasicBlock*>* predecessor(const BasicBlock* b) {
+  return b->get_predecessors();
+}
+
 }  // namespace
 
 vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
-    vector<cbb_ptr>& postorder) {
+    const vector<cbb_ptr>& postorder, get_blocks_func predecessor_func) {
   struct block_detail {
     size_t dominator;  ///< The index of blocks's dominator in post order array
     size_t postorder_index;  ///< The index of the block in the post order array
   };
-
-  const size_t undefined_dom = static_cast<size_t>(postorder.size());
+  const size_t undefined_dom = postorder.size();
 
   unordered_map<cbb_ptr, block_detail> idoms;
   for (size_t i = 0; i < postorder.size(); i++) {
     idoms[postorder[i]] = {undefined_dom, i};
   }
-
   idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index;
 
   bool changed = true;
   while (changed) {
     changed = false;
     for (auto b = postorder.rbegin() + 1; b != postorder.rend(); b++) {
-      size_t& b_dom = idoms[*b].dominator;
-      const vector<BasicBlock*>* predecessors = (*b)->get_predecessors();
-
-      // first processed predecessor
+      const vector<BasicBlock*>* predecessors = predecessor_func(*b);
+      // first processed/reachable predecessor
       auto res = find_if(begin(*predecessors), end(*predecessors),
                          [&idoms, undefined_dom](BasicBlock* pred) {
-                           return idoms[pred].dominator != undefined_dom;
+                           return idoms[pred].dominator != undefined_dom &&
+                                  pred->is_reachable();
                          });
-      assert(res != end(*predecessors));
+      if (res == end(*predecessors)) continue;
       BasicBlock* idom = *res;
       size_t idom_idx = idoms[idom].postorder_index;
 
       // all other predecessors
       for (auto p : *predecessors) {
-        if (idom == p || p->is_reachable() == false) {
-          continue;
-        }
+        if (idom == p || p->is_reachable() == false) continue;
         if (idoms[p].dominator != undefined_dom) {
           size_t finger1 = idoms[p].postorder_index;
           size_t finger2 = idom_idx;
@@ -196,8 +199,8 @@ vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
           idom_idx = finger1;
         }
       }
-      if (b_dom != idom_idx) {
-        b_dom = idom_idx;
+      if (idoms[*b].dominator != idom_idx) {
+        idoms[*b].dominator = idom_idx;
         changed = true;
       }
     }
@@ -213,13 +216,15 @@ vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
   return out;
 }
 
-void UpdateImmediateDominators(vector<pair<bb_ptr, bb_ptr>>& dom_edges) {
+void UpdateImmediateDominators(
+    const vector<pair<bb_ptr, bb_ptr>>& dom_edges,
+    function<void(BasicBlock*, BasicBlock*)> set_func) {
   for (auto& edge : dom_edges) {
-    get<0>(edge)->SetImmediateDominator(get<1>(edge));
+    set_func(get<0>(edge), get<1>(edge));
   }
 }
 
-void printDominatorList(BasicBlock& b) {
+void printDominatorList(const BasicBlock& b) {
   std::cout << b.get_id() << " is dominated by: ";
   const BasicBlock* bb = &b;
   while (bb->GetImmediateDominator() != bb) {
@@ -244,7 +249,7 @@ spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
 }
 
 spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
-  if (_.get_current_function().IsMergeBlock(merge_block)) {
+  if (_.get_current_function().IsBlockType(merge_block, kBlockTypeMerge)) {
     return _.diag(SPV_ERROR_INVALID_CFG)
            << "Block " << _.getIdName(merge_block)
            << " is already a merge block for another header";
@@ -252,21 +257,188 @@ spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
   return SPV_SUCCESS;
 }
 
+/// Update the continue construct's exit blocks once the backedge blocks are
+/// identified in the CFG.
+void UpdateContinueConstructExitBlocks(
+    Function& function, const vector<pair<uint32_t, uint32_t>>& back_edges) {
+  auto& constructs = function.get_constructs();
+  // TODO(umar): Think of a faster way to do this
+  for (auto& edge : back_edges) {
+    uint32_t back_edge_block_id;
+    uint32_t loop_header_block_id;
+    tie(back_edge_block_id, loop_header_block_id) = edge;
+
+    auto is_this_header = [=](Construct& c) {
+      return c.get_type() == ConstructType::kLoop &&
+             c.get_entry()->get_id() == loop_header_block_id;
+    };
+
+    for (auto construct : constructs) {
+      if (is_this_header(construct)) {
+        Construct* continue_construct =
+            construct.get_corresponding_constructs().back();
+        assert(continue_construct->get_type() == ConstructType::kContinue);
+
+        BasicBlock* back_edge_block;
+        tie(back_edge_block, ignore) = function.GetBlock(back_edge_block_id);
+        continue_construct->set_exit(back_edge_block);
+      }
+    }
+  }
+}
+
+/// Constructs an error message for construct validation errors
+string ConstructErrorString(const Construct& construct,
+                            const string& header_string,
+                            const string& exit_string,
+                            bool post_dominate = false) {
+  string construct_name;
+  string header_name;
+  string exit_name;
+  string dominate_text;
+  if (post_dominate) {
+    dominate_text = "is not post dominated by";
+  } else {
+    dominate_text = "does not dominate";
+  }
+
+  switch (construct.get_type()) {
+    case ConstructType::kSelection:
+      construct_name = "selection";
+      header_name = "selection header";
+      exit_name = "merge block";
+      break;
+    case ConstructType::kLoop:
+      construct_name = "loop";
+      header_name = "loop header";
+      exit_name = "merge block";
+      break;
+    case ConstructType::kContinue:
+      construct_name = "continue";
+      header_name = "continue target";
+      exit_name = "back-edge block";
+      break;
+    case ConstructType::kCase:
+      construct_name = "case";
+      header_name = "case block";
+      exit_name = "exit block";  // TODO(umar): there has to be a better name
+      break;
+    default:
+      assert(1 == 0 && "Not defined type");
+  }
+  // TODO(umar): Add header block for continue constructs to error message
+  return "The " + construct_name + " construct with the " + header_name + " " +
+         header_string + " " + dominate_text + " the " + exit_name + " " +
+         exit_string;
+}
+
+spv_result_t StructuredControlFlowChecks(
+    const ValidationState_t& _, const Function& function,
+    const vector<pair<uint32_t, uint32_t>>& back_edges) {
+  /// Check all backedges target only loop headers and have exactly one
+  /// back-edge branching to it
+  set<uint32_t> loop_headers;
+  for (auto back_edge : back_edges) {
+    uint32_t back_edge_block;
+    uint32_t header_block;
+    tie(back_edge_block, header_block) = back_edge;
+    if (!function.IsBlockType(header_block, kBlockTypeLoop)) {
+      return _.diag(SPV_ERROR_INVALID_CFG)
+             << "Back-edges (" << _.getIdName(back_edge_block) << " -> "
+             << _.getIdName(header_block)
+             << ") can only be formed between a block and a loop header.";
+    }
+    bool success;
+    tie(ignore, success) = loop_headers.insert(header_block);
+    if (!success) {
+      // TODO(umar): List the back-edge blocks that are branching to loop
+      // header
+      return _.diag(SPV_ERROR_INVALID_CFG)
+             << "Loop header " << _.getIdName(header_block)
+             << " targeted by multiple back-edges";
+    }
+  }
+
+  // Check construct rules
+  for (const Construct& construct : function.get_constructs()) {
+    auto header = construct.get_entry();
+    auto merge = construct.get_exit();
+
+    // if the merge block is reachable then it's dominated by the header
+    if (merge->is_reachable() &&
+        find(merge->dom_begin(), merge->dom_end(), header) ==
+            merge->dom_end()) {
+      return _.diag(SPV_ERROR_INVALID_CFG)
+             << ConstructErrorString(construct, _.getIdName(header->get_id()),
+                                     _.getIdName(merge->get_id()));
+    }
+    if (construct.get_type() == ConstructType::kContinue) {
+      if (find(header->pdom_begin(), header->pdom_end(), merge) ==
+          merge->pdom_end()) {
+        return _.diag(SPV_ERROR_INVALID_CFG)
+               << ConstructErrorString(construct, _.getIdName(header->get_id()),
+                                       _.getIdName(merge->get_id()), true);
+      }
+    }
+    // TODO(umar):  an OpSwitch block dominates all its defined case
+    // constructs
+    // TODO(umar):  each case construct has at most one branch to another
+    // case construct
+    // TODO(umar):  each case construct is branched to by at most one other
+    // case construct
+    // TODO(umar):  if Target T1 branches to Target T2, or if Target T1
+    // branches to the Default and the Default branches to Target T2, then
+    // T1 must immediately precede T2 in the list of the OpSwitch Target
+    // operands
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t PerformCfgChecks(ValidationState_t& _) {
   for (auto& function : _.get_functions()) {
+    // Check all referenced blocks are defined within a function
+    if (function.get_undefined_block_count() != 0) {
+      string undef_blocks("{");
+      for (auto undefined_block : function.get_undefined_blocks()) {
+        undef_blocks += _.getIdName(undefined_block) + " ";
+      }
+      return _.diag(SPV_ERROR_INVALID_CFG)
+             << "Block(s) " << undef_blocks << "\b}"
+             << " are referenced but not defined in function "
+             << _.getIdName(function.get_id());
+    }
+
     // Updates each blocks immediate dominators
     vector<const BasicBlock*> postorder;
+    vector<const BasicBlock*> postdom_postorder;
     vector<pair<uint32_t, uint32_t>> back_edges;
     if (auto* first_block = function.get_first_block()) {
+      /// calculate dominators
       DepthFirstTraversal(*first_block, successor, [](cbb_ptr) {},
                           [&](cbb_ptr b) { postorder.push_back(b); },
                           [&](cbb_ptr from, cbb_ptr to) {
                             back_edges.emplace_back(from->get_id(),
                                                     to->get_id());
                           });
-      auto edges = libspirv::CalculateDominators(postorder);
-      libspirv::UpdateImmediateDominators(edges);
+      auto edges = libspirv::CalculateDominators(postorder, predecessor);
+      libspirv::UpdateImmediateDominators(
+          edges, [](bb_ptr block, bb_ptr dominator) {
+            block->SetImmediateDominator(dominator);
+          });
+
+      /// calculate post dominators
+      auto exit_block = function.get_pseudo_exit_block();
+      DepthFirstTraversal(*exit_block, predecessor, [](cbb_ptr) {},
+                          [&](cbb_ptr b) { postdom_postorder.push_back(b); },
+                          [&](cbb_ptr, cbb_ptr) {});
+      auto postdom_edges =
+          libspirv::CalculateDominators(postdom_postorder, successor);
+      libspirv::UpdateImmediateDominators(
+          postdom_edges, [](bb_ptr block, bb_ptr dominator) {
+            block->SetImmediatePostDominator(dominator);
+          });
     }
+    UpdateContinueConstructExitBlocks(function, back_edges);
 
     // Check if the order of blocks in the binary appear before the blocks they
     // dominate
@@ -284,41 +456,10 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
       }
     }
 
-    // Check all referenced blocks are defined within a function
-    if (function.get_undefined_block_count() != 0) {
-      std::stringstream ss;
-      ss << "{";
-      for (auto undefined_block : function.get_undefined_blocks()) {
-        ss << _.getIdName(undefined_block) << " ";
-      }
-      return _.diag(SPV_ERROR_INVALID_CFG)
-             << "Block(s) " << ss.str() << "\b}"
-             << " are referenced but not defined in function "
-             << _.getIdName(function.get_id());
+    /// Structured control flow checks are only required for shader capabilities
+    if (_.hasCapability(SpvCapabilityShader)) {
+      spvCheckReturn(StructuredControlFlowChecks(_, function, back_edges));
     }
-
-    // Check all headers dominate their merge blocks
-    for (Construct& construct : function.get_constructs()) {
-      auto header = construct.get_header();
-      auto merge = construct.get_merge();
-      // auto cont = construct.get_continue();
-
-      if (merge->is_reachable() &&
-          find(merge->dom_begin(), merge->dom_end(), header) ==
-              merge->dom_end()) {
-        return _.diag(SPV_ERROR_INVALID_CFG)
-               << "Header block " << _.getIdName(header->get_id())
-               << " doesn't dominate its merge block "
-               << _.getIdName(merge->get_id());
-      }
-    }
-
-    // TODO(umar): All CFG back edges must branch to a loop header, with each
-    // loop header having exactly one back edge branching to it
-
-    // TODO(umar): For a given loop, its back-edge block must post dominate the
-    // OpLoopMerge's Continue Target, and that Continue Target must dominate the
-    // back-edge block
   }
   return SPV_SUCCESS;
 }
@@ -331,7 +472,6 @@ spv_result_t CfgPass(ValidationState_t& _,
       spvCheckReturn(_.get_current_function().RegisterBlock(inst->result_id));
       break;
     case SpvOpLoopMerge: {
-      // TODO(umar): mark current block as a loop header
       uint32_t merge_block = inst->words[inst->operands[0].offset];
       uint32_t continue_block = inst->words[inst->operands[1].offset];
       CFG_ASSERT(MergeBlockAssert, merge_block);
index 28229cc..5a713c8 100644 (file)
@@ -56,7 +56,7 @@ using ::testing::MatchesRegex;
 using libspirv::BasicBlock;
 using libspirv::ValidationState_t;
 
-using ValidateCFG = spvtest::ValidateBase<bool>;
+using ValidateCFG = spvtest::ValidateBase<SpvCapability>;
 using spvtest::ScopedContext;
 
 namespace {
@@ -160,34 +160,52 @@ Block& operator>>(Block& lhs, Block& successor) {
   return lhs;
 }
 
-string header =
-    "OpCapability Shader\n"
-    "OpMemoryModel Logical GLSL450\n";
+const char* header(SpvCapability cap) {
+  static const char* shader_header =
+      "OpCapability Shader\n"
+      "OpMemoryModel Logical GLSL450\n";
 
-string types_consts =
-    "%voidt   = OpTypeVoid\n"
-    "%boolt   = OpTypeBool\n"
-    "%intt    = OpTypeInt 32 1\n"
-    "%one     = OpConstant %intt 1\n"
-    "%two     = OpConstant %intt 2\n"
-    "%ptrt    = OpTypePointer Function %intt\n"
-    "%funct   = OpTypeFunction %voidt\n";
+  static const char* kernel_header =
+      "OpCapability Kernel\n"
+      "OpMemoryModel Logical OpenCL\n";
 
-TEST_F(ValidateCFG, Simple) {
-  Block first("first");
+  return (cap == SpvCapabilityShader) ? shader_header : kernel_header;
+}
+
+const char* types_consts() {
+  static const char* types =
+      "%voidt   = OpTypeVoid\n"
+      "%boolt   = OpTypeBool\n"
+      "%intt    = OpTypeInt 32 1\n"
+      "%one     = OpConstant %intt 1\n"
+      "%two     = OpConstant %intt 2\n"
+      "%ptrt    = OpTypePointer Function %intt\n"
+      "%funct   = OpTypeFunction %voidt\n";
+
+  return types;
+}
+
+INSTANTIATE_TEST_CASE_P(StructuredControlFlow, ValidateCFG,
+                        ::testing::Values(SpvCapabilityShader,
+                                          SpvCapabilityKernel));
+
+TEST_P(ValidateCFG, Simple) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
   Block loop("loop", SpvOpBranchConditional);
   Block cont("cont");
   Block merge("merge", SpvOpReturn);
 
-  loop.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpLoopMerge %merge %cont None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) {
+    loop.setBody("OpLoopMerge %merge %cont None\n");
+  }
 
-  string str = header + nameOps("loop", "first", "cont", "merge",
-                                make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) + nameOps("loop", "entry", "cont", "merge",
+                                            make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
-  str += first >> loop;
+  str += entry >> loop;
   str += loop >> vector<Block>({cont, merge});
   str += cont >> loop;
   str += merge;
@@ -197,15 +215,15 @@ TEST_F(ValidateCFG, Simple) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, Variable) {
+TEST_P(ValidateCFG, Variable) {
   Block entry("entry");
   Block cont("cont");
   Block exit("exit", SpvOpReturn);
 
   entry.setBody("%var = OpVariable %ptrt Function\n");
 
-  string str = header + nameOps(make_pair("func", "Main")) + types_consts +
-               " %func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) + nameOps(make_pair("func", "Main")) +
+               types_consts() + " %func    = OpFunction %voidt None %funct\n";
   str += entry >> cont;
   str += cont >> exit;
   str += exit;
@@ -215,7 +233,7 @@ TEST_F(ValidateCFG, Variable) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, VariableNotInFirstBlockBad) {
+TEST_P(ValidateCFG, VariableNotInFirstBlockBad) {
   Block entry("entry");
   Block cont("cont");
   Block exit("exit", SpvOpReturn);
@@ -223,8 +241,8 @@ TEST_F(ValidateCFG, VariableNotInFirstBlockBad) {
   // This operation should only be performed in the entry block
   cont.setBody("%var = OpVariable %ptrt Function\n");
 
-  string str = header + nameOps(make_pair("func", "Main")) + types_consts +
-               " %func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) + nameOps(make_pair("func", "Main")) +
+               types_consts() + " %func    = OpFunction %voidt None %funct\n";
 
   str += entry >> cont;
   str += cont >> exit;
@@ -239,18 +257,19 @@ TEST_F(ValidateCFG, VariableNotInFirstBlockBad) {
           "Variables can only be defined in the first block of a function"));
 }
 
-TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) {
+TEST_P(ValidateCFG, BlockAppearsBeforeDominatorBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block cont("cont");
   Block branch("branch", SpvOpBranchConditional);
   Block merge("merge", SpvOpReturn);
 
-  branch.setBody(
-      " %cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %merge None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) branch.setBody("OpSelectionMerge %merge None\n");
 
-  string str = header + nameOps("cont", "branch", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("cont", "branch", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> branch;
   str += cont >> merge;  // cont appears before its dominator
@@ -265,20 +284,22 @@ TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) {
                            "before its dominator .\\[branch\\]"));
 }
 
-TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) {
+TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block loop("loop");
   Block selection("selection", SpvOpBranchConditional);
   Block merge("merge", SpvOpReturn);
 
-  loop.setBody(
-      " %cond   = OpSLessThan %intt %one %two\n"
-      " OpLoopMerge %merge %loop None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) loop.setBody(" OpLoopMerge %merge %loop None\n");
+
   // cannot share the same merge
-  selection.setBody("OpSelectionMerge %merge None\n");
+  if (is_shader) selection.setBody("OpSelectionMerge %merge None\n");
 
-  string str = header + nameOps("merge", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("merge", make_pair("func", "Main")) + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
   str += loop >> selection;
@@ -287,26 +308,32 @@ TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) {
   str += "OpFunctionEnd\n";
 
   CompileSuccessfully(str);
-  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
-  EXPECT_THAT(getDiagnosticString(),
-              MatchesRegex("Block .\\[merge\\] is already a merge block "
-                           "for another header"));
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex("Block .\\[merge\\] is already a merge block "
+                             "for another header"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
 }
 
-TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) {
+TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block loop("loop", SpvOpBranchConditional);
   Block selection("selection", SpvOpBranchConditional);
   Block merge("merge", SpvOpReturn);
 
-  selection.setBody(
-      " %cond   = OpSLessThan %intt %one %two\n"
-      " OpSelectionMerge %merge None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) selection.setBody(" OpSelectionMerge %merge None\n");
+
   // cannot share the same merge
-  loop.setBody(" OpLoopMerge %merge %loop None\n");
+  if (is_shader) loop.setBody(" OpLoopMerge %merge %loop None\n");
 
-  string str = header + nameOps("merge", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("merge", make_pair("func", "Main")) + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> selection;
   str += selection >> vector<Block>({merge, loop});
@@ -315,18 +342,23 @@ TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) {
   str += "OpFunctionEnd\n";
 
   CompileSuccessfully(str);
-  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
-  EXPECT_THAT(getDiagnosticString(),
-              MatchesRegex("Block .\\[merge\\] is already a merge block "
-                           "for another header"));
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex("Block .\\[merge\\] is already a merge block "
+                             "for another header"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
 }
 
-TEST_F(ValidateCFG, BranchTargetFirstBlockBad) {
+TEST_P(ValidateCFG, BranchTargetFirstBlockBad) {
   Block entry("entry");
   Block bad("bad");
   Block end("end", SpvOpReturn);
-  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> bad;
   str += bad >> entry;  // Cannot target entry block
@@ -340,17 +372,17 @@ TEST_F(ValidateCFG, BranchTargetFirstBlockBad) {
                            "is targeted by block .\\[bad\\]"));
 }
 
-TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
+TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
   Block entry("entry");
   Block bad("bad", SpvOpBranchConditional);
   Block exit("exit", SpvOpReturn);
 
-  bad.setBody(
-      " %cond    = OpSLessThan %intt %one %two\n"
-      " OpLoopMerge %entry %exit None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  bad.setBody(" OpLoopMerge %entry %exit None\n");
 
-  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> bad;
   str += bad >> vector<Block>({entry, exit});  // cannot target entry block
@@ -364,19 +396,19 @@ TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
                            "is targeted by block .\\[bad\\]"));
 }
 
-TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
+TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
   Block entry("entry");
   Block bad("bad", SpvOpBranchConditional);
   Block t("t");
   Block merge("merge");
   Block end("end", SpvOpReturn);
 
-  bad.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpLoopMerge %merge %cont None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  bad.setBody("OpLoopMerge %merge %cont None\n");
 
-  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> bad;
   str += bad >> vector<Block>({t, entry});
@@ -391,7 +423,7 @@ TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
                            "is targeted by block .\\[bad\\]"));
 }
 
-TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) {
+TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) {
   Block entry("entry");
   Block bad("bad", SpvOpSwitch);
   Block block1("block1");
@@ -401,12 +433,12 @@ TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) {
   Block merge("merge");
   Block end("end", SpvOpReturn);
 
-  bad.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %merge None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  bad.setBody("OpSelectionMerge %merge None\n");
 
-  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> bad;
   str += bad >> vector<Block>({def, block1, block2, block3, entry});
@@ -425,21 +457,21 @@ TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) {
                            "is targeted by block .\\[bad\\]"));
 }
 
-TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) {
+TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) {
   Block entry("entry");
   Block middle("middle", SpvOpBranchConditional);
   Block end("end", SpvOpReturn);
 
-  middle.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %end None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  middle.setBody("OpSelectionMerge %end None\n");
 
   Block entry2("entry2");
   Block middle2("middle2");
   Block end2("end2", SpvOpReturn);
 
-  string str = header + nameOps("middle2", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("middle2", make_pair("func", "Main")) + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> middle;
   str += middle >> vector<Block>({end, middle2});
@@ -460,7 +492,8 @@ TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) {
                    "defined in function .\\[Main\\]"));
 }
 
-TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) {
+TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block head("head", SpvOpBranchConditional);
   Block f("f");
@@ -468,10 +501,11 @@ TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) {
 
   entry.setBody("%cond = OpSLessThan %intt %one %two\n");
 
-  head.setBody("OpSelectionMerge %merge None\n");
+  if (is_shader) head.setBody("OpSelectionMerge %merge None\n");
 
-  string str = header + nameOps("head", "merge", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("head", "merge", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> merge;
   str += head >> vector<Block>({merge, f});
@@ -479,26 +513,33 @@ TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) {
   str += merge;
 
   CompileSuccessfully(str);
-  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
-  EXPECT_THAT(
-      getDiagnosticString(),
-      MatchesRegex("Header block .\\[head\\] doesn't dominate its merge block "
-                   ".\\[merge\\]"));
+
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(
+        getDiagnosticString(),
+        MatchesRegex("The selection construct with the selection header "
+                     ".\\[head\\] does not dominate the merge block "
+                     ".\\[merge\\]"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
 }
 
-TEST_F(ValidateCFG, UnreachableMerge) {
+TEST_P(ValidateCFG, UnreachableMerge) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block branch("branch", SpvOpBranchConditional);
   Block t("t", SpvOpReturn);
   Block f("f", SpvOpReturn);
   Block merge("merge", SpvOpReturn);
 
-  branch.setBody(
-      " %cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %merge None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) branch.setBody("OpSelectionMerge %merge None\n");
 
-  string str = header + nameOps("branch", "merge", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("branch", "merge", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> branch;
   str += branch >> vector<Block>({t, f});
@@ -511,19 +552,20 @@ TEST_F(ValidateCFG, UnreachableMerge) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
+TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block branch("branch", SpvOpBranchConditional);
   Block t("t", SpvOpReturn);
   Block f("f", SpvOpReturn);
   Block merge("merge", SpvOpUnreachable);
 
-  branch.setBody(
-      " %cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %merge None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) branch.setBody("OpSelectionMerge %merge None\n");
 
-  string str = header + nameOps("branch", "merge", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("branch", "merge", make_pair("func", "Main")) +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> branch;
   str += branch >> vector<Block>({t, f});
@@ -536,14 +578,14 @@ TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, UnreachableBlock) {
+TEST_P(ValidateCFG, UnreachableBlock) {
   Block entry("entry");
   Block unreachable("unreachable");
   Block exit("exit", SpvOpReturn);
 
-  string str = header +
+  string str = header(GetParam()) +
                nameOps("unreachable", "exit", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> exit;
   str += unreachable >> exit;
@@ -554,7 +596,8 @@ TEST_F(ValidateCFG, UnreachableBlock) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, UnreachableBranch) {
+TEST_P(ValidateCFG, UnreachableBranch) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block unreachable("unreachable", SpvOpBranchConditional);
   Block unreachablechildt("unreachablechildt");
@@ -562,12 +605,11 @@ TEST_F(ValidateCFG, UnreachableBranch) {
   Block merge("merge");
   Block exit("exit", SpvOpReturn);
 
-  unreachable.setBody(
-      " %cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %merge None\n");
-  string str = header +
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) unreachable.setBody("OpSelectionMerge %merge None\n");
+  string str = header(GetParam()) +
                nameOps("unreachable", "exit", make_pair("func", "Main")) +
-               types_consts + "%func    = OpFunction %voidt None %funct\n";
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> exit;
   str += unreachable >> vector<Block>({unreachablechildt, unreachablechildf});
@@ -581,25 +623,25 @@ TEST_F(ValidateCFG, UnreachableBranch) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, EmptyFunction) {
-  string str = header + types_consts +
+TEST_P(ValidateCFG, EmptyFunction) {
+  string str = header(GetParam()) + string(types_consts()) +
                "%func    = OpFunction %voidt None %funct\n" + "OpFunctionEnd\n";
 
   CompileSuccessfully(str);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, SingleBlockLoop) {
+TEST_P(ValidateCFG, SingleBlockLoop) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block loop("loop", SpvOpBranchConditional);
   Block exit("exit", SpvOpReturn);
 
-  loop.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpLoopMerge %exit %loop None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) loop.setBody("OpLoopMerge %exit %loop None\n");
 
-  string str =
-      header + types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) + string(types_consts()) +
+               "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
   str += loop >> vector<Block>({loop, exit});
@@ -610,7 +652,8 @@ TEST_F(ValidateCFG, SingleBlockLoop) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, NestedLoops) {
+TEST_P(ValidateCFG, NestedLoops) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block loop1("loop1");
   Block loop1_cont_break_block("loop1_cont_break_block",
@@ -620,14 +663,14 @@ TEST_F(ValidateCFG, NestedLoops) {
   Block loop1_merge("loop1_merge");
   Block exit("exit", SpvOpReturn);
 
-  loop1.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpLoopMerge %loop1_merge %loop2 None\n");
-
-  loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) {
+    loop1.setBody("OpLoopMerge %loop1_merge %loop2 None\n");
+    loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n");
+  }
 
-  string str =
-      header + types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) + nameOps("loop2", "loop2_merge") +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop1;
   str += loop1 >> loop1_cont_break_block;
@@ -641,29 +684,33 @@ TEST_F(ValidateCFG, NestedLoops) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, NestedSelection) {
+TEST_P(ValidateCFG, NestedSelection) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   const int N = 256;
   vector<Block> if_blocks;
   vector<Block> merge_blocks;
   Block inner("inner");
 
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+
   if_blocks.emplace_back("if0", SpvOpBranchConditional);
-  if_blocks[0].setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpSelectionMerge %if_merge0 None\n");
+
+  if (is_shader) if_blocks[0].setBody("OpSelectionMerge %if_merge0 None\n");
   merge_blocks.emplace_back("if_merge0", SpvOpReturn);
 
   for (int i = 1; i < N; i++) {
     stringstream ss;
     ss << i;
     if_blocks.emplace_back("if" + ss.str(), SpvOpBranchConditional);
-    if_blocks[i].setBody("OpSelectionMerge %if_merge" + ss.str() + " None\n");
+    if (is_shader)
+      if_blocks[i].setBody("OpSelectionMerge %if_merge" + ss.str() + " None\n");
     merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch);
   }
-  string str =
-      header + types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) + string(types_consts()) +
+               "%func    = OpFunction %voidt None %funct\n";
 
+  str += entry >> if_blocks[0];
   for (int i = 0; i < N - 1; i++) {
     str += if_blocks[i] >> vector<Block>({if_blocks[i + 1], merge_blocks[i]});
   }
@@ -679,37 +726,282 @@ TEST_F(ValidateCFG, NestedSelection) {
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-// TODO(umar): enable this test
-TEST_F(ValidateCFG, DISABLED_BackEdgeBlockDoesntPostDominateContinueTargetBad) {
+TEST_P(ValidateCFG, BackEdgeBlockDoesntPostDominateContinueTargetBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
   Block entry("entry");
   Block loop1("loop1", SpvOpBranchConditional);
   Block loop2("loop2", SpvOpBranchConditional);
-  Block loop2_merge("loop2_merge");
-  Block loop1_merge("loop1_merge", SpvOpBranchConditional);
+  Block loop2_merge("loop2_merge", SpvOpBranchConditional);
+  Block be_block("be_block");
   Block exit("exit", SpvOpReturn);
 
-  loop1.setBody(
-      "%cond    = OpSLessThan %intt %one %two\n"
-      "OpLoopMerge %loop1_merge %loop2 None\n");
-
-  loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n");
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) {
+    loop1.setBody("OpLoopMerge %exit %loop2_merge None\n");
+    loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n");
+  }
 
-  string str =
-      header + types_consts + "%func    = OpFunction %voidt None %funct\n";
+  string str = header(GetParam()) +
+               nameOps("loop1", "loop2", "be_block", "loop2_merge") +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop1;
-  str += loop1 >> vector<Block>({loop2, loop1_merge});
+  str += loop1 >> vector<Block>({loop2, exit});
   str += loop2 >> vector<Block>({loop2, loop2_merge});
-  str += loop2_merge >> loop1_merge;
-  str += loop1_merge >> vector<Block>({loop1, exit});
+  str += loop2_merge >> vector<Block>({be_block, exit});
+  str += be_block >> loop1;
   str += exit;
   str += "OpFunctionEnd";
 
   CompileSuccessfully(str);
-  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  if (GetParam() == SpvCapabilityShader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex("The continue construct with the continue target "
+                             ".\\[loop2_merge\\] is not post dominated by the "
+                             "back-edge block .\\[be_block\\]"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, BranchingToNonLoopHeaderBlockBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
+  Block split("split", SpvOpBranchConditional);
+  Block t("t");
+  Block f("f");
+  Block exit("exit", SpvOpReturn);
+
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) split.setBody("OpSelectionMerge %exit None\n");
+
+  string str = header(GetParam()) + nameOps("split", "f") + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> split;
+  str += split >> vector<Block>({t, f});
+  str += t >> exit;
+  str += f >> split;
+  str += exit;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(
+        getDiagnosticString(),
+        MatchesRegex("Back-edges \\(.\\[f\\] -> .\\[split\\]\\) can only "
+                     "be formed between a block and a loop header."));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, BranchingToSameNonLoopHeaderBlockBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
+  Block split("split", SpvOpBranchConditional);
+  Block exit("exit", SpvOpReturn);
+
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) split.setBody("OpSelectionMerge %exit None\n");
+
+  string str = header(GetParam()) + nameOps("split") + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> split;
+  str += split >> vector<Block>({split, exit});
+  str += exit;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex(
+                    "Back-edges \\(.\\[split\\] -> .\\[split\\]\\) can only be "
+                    "formed between a block and a loop header."));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, MultipleBackEdgesToLoopHeaderBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
+  Block loop("loop", SpvOpBranchConditional);
+  Block cont("cont", SpvOpBranchConditional);
+  Block merge("merge", SpvOpReturn);
+
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n");
+
+  string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> loop;
+  str += loop >> vector<Block>({cont, merge});
+  str += cont >> vector<Block>({loop, loop});
+  str += merge;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex(
+                    "Loop header .\\[loop\\] targeted by multiple back-edges"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, ContinueTargetMustBePostDominatedByBackEdge) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
+  Block loop("loop", SpvOpBranchConditional);
+  Block cheader("cheader", SpvOpBranchConditional);
+  Block be_block("be_block");
+  Block merge("merge", SpvOpReturn);
+  Block exit("exit", SpvOpReturn);
+
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) loop.setBody("OpLoopMerge %merge %cheader None\n");
+
+  string str = header(GetParam()) + nameOps("cheader", "be_block") +
+               types_consts() + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> loop;
+  str += loop >> vector<Block>({cheader, merge});
+  str += cheader >> vector<Block>({exit, be_block});
+  str += exit;  //  Branches out of a continue construct
+  str += be_block >> loop;
+  str += merge;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex("The continue construct with the continue target "
+                             ".\\[cheader\\] is not post dominated by the "
+                             "back-edge block .\\[be_block\\]"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
+  Block loop("loop", SpvOpBranchConditional);
+  Block cont("cont", SpvOpBranchConditional);
+  Block merge("merge", SpvOpReturn);
+
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n");
+
+  string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> loop;
+  str += loop >> vector<Block>({cont, merge});
+  str += cont >> vector<Block>({loop, merge});
+  str += merge;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex("The continue construct with the continue target "
+                             ".\\[loop\\] is not post dominated by the "
+                             "back-edge block .\\[cont\\]"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, BranchOutOfConstructBad) {
+  bool is_shader = GetParam() == SpvCapabilityShader;
+  Block entry("entry");
+  Block loop("loop", SpvOpBranchConditional);
+  Block cont("cont", SpvOpBranchConditional);
+  Block merge("merge");
+  Block exit("exit", SpvOpReturn);
+
+  entry.setBody("%cond    = OpSLessThan %intt %one %two\n");
+  if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n");
+
+  string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() +
+               "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> loop;
+  str += loop >> vector<Block>({cont, merge});
+  str += cont >> vector<Block>({loop, exit});
+  str += merge >> exit;
+  str += exit;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  if (is_shader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                MatchesRegex("The continue construct with the continue target "
+                             ".\\[loop\\] is not post dominated by the "
+                             "back-edge block .\\[cont\\]"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_F(ValidateCFG, OpSwitchToUnreachableBlock) {
+  Block entry("entry", SpvOpSwitch);
+  Block case0("case0");
+  Block case1("case1");
+  Block case2("case2");
+  Block def("default", SpvOpUnreachable);
+  Block phi("phi", SpvOpReturn);
+
+  string str = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main" %id
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 430
+OpName %main "main"
+OpDecorate %id BuiltIn GlobalInvocationId
+%void      = OpTypeVoid
+%voidf     = OpTypeFunction %void
+%u32       = OpTypeInt 32 0
+%f32       = OpTypeFloat 32
+%uvec3     = OpTypeVector %u32 3
+%fvec3     = OpTypeVector %f32 3
+%uvec3ptr  = OpTypePointer Input %uvec3
+%id        = OpVariable %uvec3ptr Input
+%one       = OpConstant %u32 1
+%three     = OpConstant %u32 3
+%main      = OpFunction %void None %voidf
+)";
+
+  entry.setBody(
+    "%idval    = OpLoad %uvec3 %id\n"
+    "%x        = OpCompositeExtract %u32 %idval 0\n"
+    "%selector = OpUMod %u32 %x %three\n"
+    "OpSelectionMerge %phi None\n");
+  str += entry >> vector<Block>({def, case0, case1, case2});
+  str += case1 >> phi;
+  str += def;
+  str += phi;
+  str += case0 >> phi;
+  str += case2 >> phi;
+  str += "OpFunctionEnd";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
 /// TODO(umar): Switch instructions
-/// TODO(umar): CFG branching outside of CFG construct
 /// TODO(umar): Nested CFG constructs
-}
+}  /// namespace
index 7467960..4092636 100644 (file)
@@ -90,4 +90,5 @@ template class spvtest::ValidateBase<
 template class spvtest::ValidateBase<
     std::tuple<int, std::tuple<std::string, std::function<spv_result_t(int)>,
                                std::function<spv_result_t(int)>>>>;
+template class spvtest::ValidateBase<SpvCapability>;
 }