Control Flow Graph Validation
authorUmar Arshad <umar@arrayfire.com>
Wed, 16 Mar 2016 21:20:02 +0000 (17:20 -0400)
committerDejan Mircevski <deki@google.com>
Thu, 2 Jun 2016 16:11:52 +0000 (12:11 -0400)
source/validate.cpp
source/validate.h
source/validate_cfg.cpp
source/validate_instruction.cpp
source/validate_layout.cpp
source/validate_types.cpp
test/CMakeLists.txt
test/UnitSPIRV.h
test/Validate.CFG.cpp [new file with mode: 0644]
test/ValidationState.cpp

index ab4de21..800cc51 100644 (file)
 // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
 
+#include <cassert>
+#include <cstdio>
+
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <sstream>
+#include <string>
+#include <vector>
+
 #include "validate.h"
 #include "validate_passes.h"
 
 #include "spirv_constant.h"
 #include "spirv_endian.h"
 
-#include <algorithm>
-#include <cassert>
-#include <cstdio>
-#include <functional>
-#include <iterator>
-#include <sstream>
-#include <string>
-#include <vector>
-
 using std::function;
 using std::ostream_iterator;
 using std::placeholders::_1;
@@ -184,7 +185,6 @@ spv_result_t spvValidate(const spv_const_context context,
   // TODO(umar): Add validation checks which require the parsing of the entire
   // module. Use the information from the ProcessInstruction pass to make the
   // checks.
-
   if (vstate.unresolvedForwardIdCount() > 0) {
     stringstream ss;
     vector<uint32_t> ids = vstate.unresolvedForwardIds();
@@ -198,6 +198,10 @@ spv_result_t spvValidate(const spv_const_context context,
            << id_str.substr(0, id_str.size() - 1);
   }
 
+  // CFG checks are performed after the binary has been parsed
+  // and the CFGPass has collected information about the control flow
+  spvCheckReturn(PerformCfgChecks(vstate));
+
   // NOTE: Copy each instruction for easier processing
   std::vector<spv_instruction_t> instructions;
   uint64_t index = SPV_INDEX_INSTRUCTION;
index ae0abb3..dd93a55 100644 (file)
@@ -28,6 +28,8 @@
 #define LIBSPIRV_VALIDATE_H_
 
 #include <algorithm>
+#include <array>
+#include <list>
 #include <map>
 #include <string>
 #include <unordered_map>
 #include "spirv_definition.h"
 #include "table.h"
 
+#define MSG(msg)                                        \
+  do {                                                  \
+    libspirv::message(__FILE__, size_t(__LINE__), msg); \
+  } while (0)
+
+#define SHOW(exp)                                               \
+  do {                                                          \
+    libspirv::message(__FILE__, size_t(__LINE__), #exp, (exp)); \
+  } while (0)
+
 // Structures
 
 // Info about a result ID.
@@ -59,9 +71,16 @@ typedef struct spv_id_info_t {
 
 namespace libspirv {
 
-// This enum represents the sections of a SPIRV module. See section 2.4
-// of the SPIRV spec for additional details of the order. The enumerant values
-// are in the same order as the vector returned by GetModuleOrder
+void message(std::string file, size_t line, std::string name);
+
+template <typename T>
+void message(std::string file, size_t line, std::string name, T val) {
+  std::cout << file << ":" << line << ": " << name << " " << val << std::endl;
+}
+
+/// This enum represents the sections of a SPIRV module. See section 2.4
+/// of the SPIRV spec for additional details of the order. The enumerant values
+/// are in the same order as the vector returned by GetModuleOrder
 enum ModuleLayoutSection {
   kLayoutCapabilities,          // < Section 2.4 #1
   kLayoutExtensions,            // < Section 2.4 #2
@@ -84,91 +103,298 @@ enum class FunctionDecl {
 };
 
 class ValidationState_t;
+class Function;
+
+// This class represents a basic block in a SPIR-V module
+class BasicBlock {
+ public:
+  /// Constructor for a BasicBlock
+  ///
+  /// @param[in] id The ID of the basic block
+  explicit BasicBlock(uint32_t id);
+
+  /// Returns the id of the BasicBlock
+  uint32_t get_id() const { return id_; }
+
+  /// Returns the predecessors of the BasicBlock
+  const std::vector<BasicBlock*>& get_predecessors() const {
+    return predecessors_;
+  }
+
+  /// Returns the predecessors of the BasicBlock
+  std::vector<BasicBlock*>& get_predecessors() { return predecessors_; }
+
+  /// Returns the successors of the BasicBlock
+  const std::vector<BasicBlock*>& get_successors() const { return successors_; }
+
+  /// Returns the successors of the BasicBlock
+  std::vector<BasicBlock*>& get_successors() { return successors_; }
+
+  /// Returns true if the  block should be reachable in the CFG
+  bool is_reachable() const { return reachable_; }
+
+  void set_reachability(bool reachability) { reachable_ = reachability; }
+
+  /// Sets the immedate dominator of this basic block
+  ///
+  /// @param[in] dom_block The dominator block
+  void SetImmediateDominator(BasicBlock* dom_block);
+
+  /// Returns the immedate dominator of this basic block
+  BasicBlock* GetImmediateDominator();
+
+  /// Returns the immedate dominator of this basic block
+  const BasicBlock* GetImmediateDominator() const;
+
+  /// Ends the block without a successor
+  void RegisterBranchInstruction(SpvOp branch_instruction);
+
+  /// Adds @p next BasicBlocks as successors of this BasicBlock
+  void RegisterSuccessors(std::vector<BasicBlock*> next = {});
+
+  /// Returns true if the id of the BasicBlock matches
+  bool operator==(const BasicBlock& other) const { return other.id_ == id_; }
+
+  /// Returns true if the id of the BasicBlock matches
+  bool operator==(const uint32_t& id) const { return id == id_; }
+
+  /// @brief A BasicBlock dominator iterator class
+  ///
+  /// This iterator will iterate over the dominators of the block
+  class DominatorIterator
+      : public std::iterator<std::forward_iterator_tag, BasicBlock*> {
+   public:
+    /// @brief Constructs the end of dominator iterator
+    ///
+    /// This will create an iterator which will represent the element
+    /// before the root node of the dominator tree
+    DominatorIterator();
+
+    /// @brief Constructs an iterator for the given block which points to
+    ///        @p block
+    ///
+    /// @param block The block which is referenced by the iterator
+    explicit DominatorIterator(const BasicBlock* block);
+
+    /// @brief Advances the iterator
+    DominatorIterator& operator++();
+
+    /// @brief Returns the current element
+    const BasicBlock*& operator*();
+
+    friend bool operator==(const DominatorIterator& lhs,
+                           const DominatorIterator& rhs);
+
+   private:
+    const BasicBlock* current_;
+  };
+
+  /// Returns an iterator which points to the current block
+  const DominatorIterator dom_begin() const;
+  DominatorIterator dom_begin();
+
+  /// Returns an iterator which points to one element past the first block
+  const DominatorIterator dom_end() const;
+  DominatorIterator dom_end();
+
+ private:
+  /// Id of the BasicBlock
+  const uint32_t id_;
+
+  /// Pointer to the immediate dominator of the BasicBlock
+  BasicBlock* immediate_dominator_;
+
+  /// The set of predecessors of the BasicBlock
+  std::vector<BasicBlock*> predecessors_;
+
+  /// The set of successors of the BasicBlock
+  std::vector<BasicBlock*> successors_;
+
+  SpvOp branch_instruction_;
+
+  bool reachable_;
+};
+
+/// @brief Returns true if the iterators point to the same element or if both
+///        iterators point to the @p dom_end block
+bool operator==(const BasicBlock::DominatorIterator& lhs,
+                const BasicBlock::DominatorIterator& rhs);
+
+/// @brief Returns true if the iterators point to different elements and they
+///        do not both point to the @p dom_end block
+bool operator!=(const BasicBlock::DominatorIterator& lhs,
+                const BasicBlock::DominatorIterator& rhs);
+
+/// @brief This class tracks the CFG constructs as defined in the SPIR-V spec
+class CFConstruct {
+  // Universal Limit of ResultID + 1
+  static const uint32_t kInitialValue = 0x400000;
+
+ public:
+  CFConstruct(BasicBlock* header_block, BasicBlock* merge_block,
+              BasicBlock* continue_block = nullptr)
+      : header_block_(header_block),
+        merge_block_(merge_block),
+        continue_block_(continue_block) {}
+
+  const BasicBlock* get_header() const { return header_block_; }
+  const BasicBlock* get_merge() const { return merge_block_; }
+  const BasicBlock* get_continue() const { return continue_block_; }
+
+  BasicBlock* get_header() { return header_block_; }
+  BasicBlock* get_merge() { return merge_block_; }
+  BasicBlock* get_continue() { return continue_block_; }
+
+ private:
+  BasicBlock* header_block_;    ///< The header block of a loop or selection
+  BasicBlock* merge_block_;     ///< The merge block of a loop or selection
+  BasicBlock* continue_block_;  ///< The continue block of a loop block
+};
 
 // This class manages all function declaration and definitions in a module. It
 // handles the state and id information while parsing a function in the SPIR-V
 // binary.
-//
-// NOTE: This class is designed to be a Structure of Arrays. Therefore each
-// member variable is a vector whose elements represent the values for the
-// corresponding function in a SPIR-V module. Variables that are not vector
-// types are used to manage the state while parsing the function.
-class Functions {
+class Function {
  public:
-  explicit Functions(ValidationState_t& module);
-
-  // Registers the function in the module. Subsequent instructions will be
-  // called against this function
-  spv_result_t RegisterFunction(uint32_t id, uint32_t ret_type_id,
-                                uint32_t function_control,
-                                uint32_t function_type_id);
+  Function(uint32_t id, uint32_t result_type_id,
+           SpvFunctionControlMask function_control, uint32_t function_type_id,
+           ValidationState_t& module);
 
-  // Registers a function parameter in the current function
+  /// Registers a function parameter in the current function
+  /// @return Returns SPV_SUCCESS if the call was successful
   spv_result_t RegisterFunctionParameter(uint32_t id, uint32_t type_id);
 
-  // Register a function end instruction
-  spv_result_t RegisterFunctionEnd();
-
-  // Sets the declaration type of the current function
+  /// Sets the declaration type of the current function
+  /// @return Returns SPV_SUCCESS if the call was successful
   spv_result_t RegisterSetFunctionDeclType(FunctionDecl type);
 
   // Registers a block in the current function. Subsequent block instructions
   // will target this block
   // @param id The ID of the label of the block
-  spv_result_t RegisterBlock(uint32_t id);
-
-  // Registers a variable in the current block
+  /// @return Returns SPV_SUCCESS if the call was successful
+  spv_result_t RegisterBlock(uint32_t id, bool is_definition = true);
+
+  /// Registers a variable in the current block
+  ///
+  /// @param[in] type_id The type ID of the varaible
+  /// @param[in] id      The ID of the varaible
+  /// @param[in] storage The storage of the variable
+  /// @param[in] init_id The initializer ID of the variable
+  ///
+  /// @return Returns SPV_SUCCESS if the call was successful
   spv_result_t RegisterBlockVariable(uint32_t type_id, uint32_t id,
                                      SpvStorageClass storage, uint32_t init_id);
 
-  spv_result_t RegisterBlockLoopMerge(uint32_t merge_id, uint32_t continue_id,
-                                      SpvLoopControlMask control);
+  /// Registers a loop merge construct in the function
+  ///
+  /// @param[in] merge_id The merge block ID of the loop
+  /// @param[in] continue_id The continue block ID of the loop
+  ///
+  /// @return Returns SPV_SUCCESS if the call was successful
+  spv_result_t RegisterLoopMerge(uint32_t merge_id, uint32_t continue_id);
+
+  /// Registers a selection merge construct in the function
+  /// @return Returns SPV_SUCCESS if the call was successful
+  spv_result_t RegisterSelectionMerge(uint32_t merge_id);
+
+  /// Registers the end of the block
+  ///
+  /// @param[in] successors_list A list of ids to the blocks successors
+  /// @param[in] branch_instruction the branch instruction that ended the block
+  void RegisterBlockEnd(std::vector<uint32_t> successors_list,
+                        SpvOp branch_instruction);
+
+  /// Returns true if the \p merge_block_id is a merge block
+  bool IsMergeBlock(uint32_t merge_block_id) const;
 
-  spv_result_t RegisterBlockSelectionMerge(uint32_t merge_id,
-                                           SpvSelectionControlMask control);
+  /// Returns true if the \p id is the first block of this function
+  bool IsFirstBlock(uint32_t id) const;
 
-  // Registers the end of the block
-  spv_result_t RegisterBlockEnd();
+  /// Returns the first block of the current function
+  const BasicBlock* get_first_block() const;
+
+  /// Returns the first block of the current function
+  BasicBlock* get_first_block();
+
+  /// Returns a vector of all the blocks in the function
+  const std::vector<BasicBlock*>& get_blocks() const;
+
+  /// Returns a vector of all the blocks in the function
+  std::vector<BasicBlock*>& get_blocks();
+
+  /// Returns a list of all the cfg constructs in the function
+  const std::list<CFConstruct>& get_constructs() const;
+
+  /// Returns a list of all the cfg constructs in the function
+  std::list<CFConstruct>& get_constructs();
 
   // Returns the number of blocks in the current function being parsed
   size_t get_block_count() const;
 
-  // Returns true if called after a function instruction but before the
-  // function end instruction
-  bool in_function_body() const;
+  /// Returns the id of the funciton
+  uint32_t get_id() const { return id_; }
 
-  // Returns true if called after a label instruction but before a branch
-  // instruction
+  // Returns the number of blocks in the current function being parsed
+  size_t get_undefined_block_count() const;
+  const std::unordered_set<uint32_t>& get_undefined_blocks() const {
+    return undefined_blocks_;
+  }
+
+  /// Returns true if called after a label instruction but before a branch
+  /// instruction
   bool in_block() const;
 
-  libspirv::DiagnosticStream diag(spv_result_t error_code) const;
+  /// Returns the block that is currently being parsed in the binary
+  BasicBlock& get_current_block();
+
+  /// Returns the block that is currently being parsed in the binary
+  const BasicBlock& get_current_block() const;
+
+  /// Prints a GraphViz digraph of the CFG of the current funciton
+  void printDotGraph() const;
+
+  /// Prints a directed graph of the CFG of the current funciton
+  void printBlocks() const;
 
  private:
-  // Parent module
+  /// Parent module
   ValidationState_t& module_;
 
-  // Function IDs in a module
-  std::vector<uint32_t> id_;
+  /// The result id of the OpLabel that defined this block
+  uint32_t id_;
 
-  // OpTypeFunction IDs of each of the id_ functions
-  std::vector<uint32_t> type_id_;
+  /// The type of the function
+  uint32_t function_type_id_;
 
-  // The type of declaration of each function
-  std::vector<FunctionDecl> declaration_type_;
+  /// The type of the return value
+  uint32_t result_type_id_;
 
-  // TODO(umar): Probably needs better abstractions
-  // The beginning of the block of functions
-  std::vector<std::vector<uint32_t>> block_ids_;
+  /// The control fo the funciton
+  SpvFunctionControlMask function_control_;
 
-  // The variable IDs of the functions
-  std::vector<std::vector<uint32_t>> variable_ids_;
+  /// The type of declaration of each function
+  FunctionDecl declaration_type_;
 
-  // The function parameter ids of the functions
-  std::vector<std::vector<uint32_t>> parameter_ids_;
+  /// The blocks in the function mapped by block ID
+  std::unordered_map<uint32_t, BasicBlock> blocks_;
 
-  // NOTE: See correspoding getter functions
-  bool in_function_;
-  bool in_block_;
+  /// A list of blocks in the order they appeared in the binary
+  std::vector<BasicBlock*> ordered_blocks_;
+
+  /// Blocks which are forward referenced by blocks but not defined
+  std::unordered_set<uint32_t> undefined_blocks_;
+
+  /// The block that is currently being parsed
+  BasicBlock* current_block_;
+
+  /// The constructs that are available in this function
+  std::list<CFConstruct> cfg_constructs_;
+
+  /// The variable IDs of the functions
+  std::vector<uint32_t> variable_ids_;
+
+  /// The function parameter ids of the functions
+  std::vector<uint32_t> parameter_ids_;
 };
 
 class ValidationState_t {
@@ -190,6 +416,9 @@ class ValidationState_t {
   // the OpName instruction
   std::string getIdName(uint32_t id) const;
 
+  /// Like getIdName but does not display the id if the \p id has a name
+  std::string getIdOrName(uint32_t id) const;
+
   // Returns the number of ID which have been forward referenced but not defined
   size_t unresolvedForwardIdCount() const;
 
@@ -214,7 +443,10 @@ class ValidationState_t {
   libspirv::DiagnosticStream diag(spv_result_t error_code) const;
 
   // Returns the function states
-  Functions& get_functions();
+  std::list<Function>& get_functions();
+
+  // Returns the function states
+  Function& get_current_function();
 
   // Returns true if the called after a function instruction but before the
   // function end instruction
@@ -263,7 +495,16 @@ class ValidationState_t {
   const std::vector<uint32_t>& entry_points() const { return entry_points_; }
 
   // Registers the capability and its dependent capabilities
-  void registerCapability(SpvCapability cap);
+  void RegisterCapability(SpvCapability cap);
+
+  // Registers the function in the module. Subsequent instructions will be
+  // called against this function
+  spv_result_t RegisterFunction(uint32_t id, uint32_t ret_type_id,
+                                SpvFunctionControlMask function_control,
+                                uint32_t function_type_id);
+
+  // Register a function end instruction
+  spv_result_t RegisterFunctionEnd();
 
   // Returns true if the capability is enabled in the module.
   bool hasCapability(SpvCapability cap) const;
@@ -299,9 +540,10 @@ class ValidationState_t {
   // The section of the code being processed
   ModuleLayoutSection current_layout_section_;
 
-  Functions module_functions_;
+  std::list<Function> module_functions_;
 
-  spv_capability_mask_t module_capabilities_;  // Module's declared capabilities.
+  spv_capability_mask_t
+      module_capabilities_;  // Module's declared capabilities.
 
   // Definitions and uses of all the IDs in the module.
   UseDefTracker usedefs_;
@@ -314,11 +556,43 @@ class ValidationState_t {
   SpvAddressingModel addressing_model_;
   SpvMemoryModel memory_model_;
 
+  // NOTE: See correspoding getter functions
+  bool in_function_;
 };
 
-}  // namespace libspirv
+/// @brief Calculates dominator edges of a root basic block
+///
+/// This function calculates the dominator edges form a root BasicBlock. Uses
+/// the dominator algorithm by Cooper et al.
+///
+/// @param[in] first_block the root or entry BasicBlock of a function
+///
+/// @return a set of dominator edges represented as a pair of blocks
+std::vector<std::pair<BasicBlock*, BasicBlock*> > CalculateDominators(
+    const BasicBlock& first_block);
+
+/// @brief Performs the Control Flow Graph checks
+///
+/// @param[in] _ the validation state of the module
+///
+/// @return SPV_SUCCESS if no errors are found. SPV_ERROR_INVALID_CFG otherwise
+spv_result_t PerformCfgChecks(ValidationState_t& _);
 
-// Functions
+// @brief Updates the immediate dominator for each of the block edges
+//
+// Updates the immediate dominator of the blocks for each of the edges
+// provided by the @p dom_edges parameter
+//
+// @param[in,out] dom_edges The edges of the dominator tree
+void UpdateImmediateDominators(
+    std::vector<std::pair<BasicBlock*, BasicBlock*> >& dom_edges);
+
+// @brief Prints all of the dominators of a BasicBlock
+//
+// @param[in] block The dominators of this block will be printed
+void printDominatorList(BasicBlock& block);
+
+}  // namespace libspirv
 
 /// @brief Validate the ID usage of the instruction stream
 ///
index 0f8e980..f9a1d90 100644 (file)
 // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
 
+#include "validate.h"
 #include "validate_passes.h"
 
+#include <algorithm>
+#include <cassert>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+using std::find;
+using std::get;
+using std::make_pair;
+using std::numeric_limits;
+using std::pair;
+using std::transform;
+using std::unordered_map;
+using std::unordered_set;
+using std::vector;
+
+using libspirv::BasicBlock;
+
 namespace libspirv {
 
-// TODO(umar): Support for merge instructions
-// TODO(umar): Structured control flow checks
-spv_result_t CfgPass(ValidationState_t& _,
-                     const spv_parsed_instruction_t* inst) {
-  if (_.getLayoutSection() == kLayoutFunctionDefinitions) {
-    SpvOp opcode = static_cast<SpvOp>(inst->opcode);
-    switch (opcode) {
-      case SpvOpLabel:
-        spvCheckReturn(_.get_functions().RegisterBlock(inst->result_id));
-        break;
-      case SpvOpBranch:
-      case SpvOpBranchConditional:
-      case SpvOpSwitch:
-      case SpvOpKill:
-      case SpvOpReturn:
-      case SpvOpReturnValue:
-      case SpvOpUnreachable:
-        spvCheckReturn(_.get_functions().RegisterBlockEnd());
-        break;
-      default:
-        break;
+namespace {
+
+using bb_ptr = BasicBlock*;
+using cbb_ptr = const BasicBlock*;
+using bb_iter = vector<BasicBlock*>::const_iterator;
+
+/// @brief Sorts the blocks in a CFG given the entry node
+///
+/// Returns a vector of basic block pointers in a Control Flow Graph(CFG) which
+/// are sorted in the order they were accessed in a post order traversal.
+///
+/// @param[in] entry the first block of a CFG
+/// @param[in] depth_hint a hint about the depth of the CFG
+///
+/// @return A vector of pointers in the order they were access in a post order
+/// traversal
+vector<const BasicBlock*> PostOrderSort(const BasicBlock& entry, size_t size) {
+  struct block_info {
+    cbb_ptr block;
+    bb_iter iter;
+  };
+
+  vector<cbb_ptr> out;
+  vector<block_info> staged;
+  unordered_set<uint32_t> processed;
+
+  staged.reserve(size);
+  staged.emplace_back(block_info{&entry, begin(entry.get_successors())});
+  processed.insert(entry.get_id());
+
+  while (!staged.empty()) {
+    block_info& top = staged.back();
+    if (top.iter == end(top.block->get_successors())) {
+      out.push_back(top.block);
+      staged.pop_back();
+    } else {
+      BasicBlock* child = *top.iter;
+      if (processed.find(child->get_id()) == end(processed)) {
+        staged.emplace_back(block_info{child, begin(child->get_successors())});
+        processed.insert(child->get_id());
+      }
+      top.iter++;
     }
   }
+  return out;
+}
+}  // namespace
+
+vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
+    const BasicBlock& first_block) {
+  struct block_detail {
+    size_t dominator;  ///< The index of blocks's dominator in post order array
+    size_t postorder_index;  ///< The index of the block in the post order array
+  };
+
+  vector<cbb_ptr> postorder = PostOrderSort(first_block, 10);
+  const size_t undefined_dom = static_cast<size_t>(postorder.size());
+
+  unordered_map<cbb_ptr, block_detail> idoms;
+  for (size_t i = 0; i < postorder.size(); i++) {
+    idoms[postorder[i]] = {undefined_dom, i};
+  }
+
+  idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index;
+
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (auto b = postorder.rbegin() + 1; b != postorder.rend(); b++) {
+      size_t& b_dom = idoms[*b].dominator;
+      const vector<BasicBlock*>& predecessors = (*b)->get_predecessors();
+
+      // first processed predecessor
+      auto res = find_if(begin(predecessors), end(predecessors),
+                         [&idoms, undefined_dom](BasicBlock* pred) {
+                           return idoms[pred].dominator != undefined_dom;
+                         });
+      assert(res != end(predecessors));
+      BasicBlock* idom = *res;
+      size_t idom_idx = idoms[idom].postorder_index;
+
+      // all other predecessors
+      for (auto p : predecessors) {
+        if (idom == p || p->is_reachable() == false) {
+          continue;
+        }
+        if (idoms[p].dominator != undefined_dom) {
+          size_t finger1 = idoms[p].postorder_index;
+          size_t finger2 = idom_idx;
+          while (finger1 != finger2) {
+            while (finger1 < finger2) {
+              finger1 = idoms[postorder[finger1]].dominator;
+            }
+            while (finger2 < finger1) {
+              finger2 = idoms[postorder[finger2]].dominator;
+            }
+          }
+          idom_idx = finger1;
+        }
+      }
+      if (b_dom != idom_idx) {
+        b_dom = idom_idx;
+        changed = true;
+      }
+    }
+  }
+
+  vector<pair<bb_ptr, bb_ptr>> out;
+  for (auto idom : idoms) {
+    // NOTE: performing a const cast for convenient usage with
+    // UpdateImmediateDominators
+    out.push_back({const_cast<BasicBlock*>(get<0>(idom)),
+                   const_cast<BasicBlock*>(postorder[get<1>(idom).dominator])});
+  }
+  return out;
+}
+
+void UpdateImmediateDominators(vector<pair<bb_ptr, bb_ptr>>& dom_edges) {
+  for (auto& edge : dom_edges) {
+    get<0>(edge)->SetImmediateDominator(get<1>(edge));
+  }
+}
+
+void printDominatorList(BasicBlock& b) {
+  std::cout << b.get_id() << " is dominated by: ";
+  const BasicBlock* bb = &b;
+  while (bb->GetImmediateDominator() != bb) {
+    bb = bb->GetImmediateDominator();
+    std::cout << bb->get_id() << " ";
+  }
+}
+
+#define CFG_ASSERT(ASSERT_FUNC, TARGET) \
+  if (spv_result_t rcode = ASSERT_FUNC(_, TARGET)) return rcode
+
+spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
+  if (_.get_current_function().IsFirstBlock(target)) {
+    return _.diag(SPV_ERROR_INVALID_CFG)
+           << "First block " << _.getIdName(target) << " of funciton "
+           << _.getIdName(_.get_current_function().get_id())
+           << " is targeted by block "
+           << _.getIdName(
+                  _.get_current_function().get_current_block().get_id());
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
+  if (_.get_current_function().IsMergeBlock(merge_block)) {
+    return _.diag(SPV_ERROR_INVALID_CFG)
+           << "Block " << _.getIdName(merge_block)
+           << " is already a merge block for another header";
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t PerformCfgChecks(ValidationState_t& _) {
+  for (auto& function : _.get_functions()) {
+    // Updates each blocks immediate dominators
+    if (auto* first_block = function.get_first_block()) {
+      auto edges = libspirv::CalculateDominators(*first_block);
+      libspirv::UpdateImmediateDominators(edges);
+    }
+
+    // Check if the order of blocks in the binary appear before the blocks they
+    // dominate
+    auto& blocks = function.get_blocks();
+    if (blocks.empty() == false) {
+      for (auto block = begin(blocks) + 1; block != end(blocks); block++) {
+        if (auto idom = (*block)->GetImmediateDominator()) {
+          if (block == std::find(begin(blocks), block, idom)) {
+            return _.diag(SPV_ERROR_INVALID_CFG)
+                   << "Block " << _.getIdName((*block)->get_id())
+                   << " appears in the binary before its dominator "
+                   << _.getIdName(idom->get_id());
+          }
+        }
+      }
+    }
+
+    // Check all referenced blocks are defined within a function
+    if (function.get_undefined_block_count() != 0) {
+      std::stringstream ss;
+      ss << "{";
+      for (auto undefined_block : function.get_undefined_blocks()) {
+        ss << _.getIdName(undefined_block) << " ";
+      }
+      return _.diag(SPV_ERROR_INVALID_CFG)
+             << "Block(s) " << ss.str() << "\b}"
+             << " are referenced but not defined in function "
+             << _.getIdName(function.get_id());
+    }
+
+    // Check all headers dominate their merge blocks
+    for (CFConstruct& construct : function.get_constructs()) {
+      auto header = construct.get_header();
+      auto merge = construct.get_merge();
+      // auto cont = construct.get_continue();
+
+      if (merge->is_reachable() &&
+          find(merge->dom_begin(), merge->dom_end(), header) ==
+              merge->dom_end()) {
+        return _.diag(SPV_ERROR_INVALID_CFG)
+               << "Header block " << _.getIdName(header->get_id())
+               << " doesn't dominate its merge block "
+               << _.getIdName(merge->get_id());
+      }
+    }
+
+    // TODO(umar): All CFG back edges must branch to a loop header, with each
+    // loop header having exactly one back edge branching to it
+
+    // TODO(umar): For a given loop, its back-edge block must post dominate the
+    // OpLoopMerge's Continue Target, and that Continue Target must dominate the
+    // back-edge block
+  }
   return SPV_SUCCESS;
 }
+
+spv_result_t CfgPass(ValidationState_t& _,
+                     const spv_parsed_instruction_t* inst) {
+  SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+  switch (opcode) {
+    case SpvOpLabel:
+      spvCheckReturn(_.get_current_function().RegisterBlock(inst->result_id));
+      break;
+    case SpvOpLoopMerge: {
+      // TODO(umar): mark current block as a loop header
+      uint32_t merge_block = inst->words[inst->operands[0].offset];
+      uint32_t continue_block = inst->words[inst->operands[1].offset];
+      CFG_ASSERT(MergeBlockAssert, merge_block);
+
+      spvCheckReturn(_.get_current_function().RegisterLoopMerge(
+          merge_block, continue_block));
+    } break;
+    case SpvOpSelectionMerge: {
+      uint32_t merge_block = inst->words[inst->operands[0].offset];
+      CFG_ASSERT(MergeBlockAssert, merge_block);
+
+      spvCheckReturn(
+          _.get_current_function().RegisterSelectionMerge(merge_block));
+    } break;
+    case SpvOpBranch: {
+      uint32_t target = inst->words[inst->operands[0].offset];
+      CFG_ASSERT(FirstBlockAssert, target);
+
+      _.get_current_function().RegisterBlockEnd({target}, opcode);
+    } break;
+    case SpvOpBranchConditional: {
+      uint32_t tlabel = inst->words[inst->operands[1].offset];
+      uint32_t flabel = inst->words[inst->operands[2].offset];
+      CFG_ASSERT(FirstBlockAssert, tlabel);
+      CFG_ASSERT(FirstBlockAssert, flabel);
+
+      _.get_current_function().RegisterBlockEnd({tlabel, flabel}, opcode);
+    } break;
+
+    case SpvOpSwitch: {
+      vector<uint32_t> cases;
+      for (int i = 1; i < inst->num_operands; i += 2) {
+        uint32_t target = inst->words[inst->operands[i].offset];
+        CFG_ASSERT(FirstBlockAssert, target);
+        cases.push_back(target);
+      }
+      _.get_current_function().RegisterBlockEnd({cases}, opcode);
+    } break;
+    case SpvOpKill:
+    case SpvOpReturn:
+    case SpvOpReturnValue:
+    case SpvOpUnreachable:
+      _.get_current_function().RegisterBlockEnd({}, opcode);
+      break;
+    default:
+      break;
+  }
+  return SPV_SUCCESS;
 }
+}  // namespace libspirv
index 26053e3..1a17bd0 100644 (file)
@@ -120,7 +120,7 @@ spv_result_t InstructionPass(ValidationState_t& _,
                              const spv_parsed_instruction_t* inst) {
   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
   if (opcode == SpvOpCapability)
-    _.registerCapability(
+    _.RegisterCapability(
         static_cast<SpvCapability>(inst->words[inst->operands[0].offset]));
   if (opcode == SpvOpMemoryModel) {
     _.setAddressingModel(
@@ -140,6 +140,10 @@ spv_result_t InstructionPass(ValidationState_t& _,
                << "Variables must have a function[7] storage class inside"
                   " of a function";
       }
+      if(_.get_current_function().IsFirstBlock(_.get_current_function().get_current_block().get_id()) == false) {
+        return _.diag(SPV_ERROR_INVALID_CFG)
+          << "Variables can only be defined in the first block of a function";
+      }
     } else {
       if (storage_class == SpvStorageClassFunction) {
         return _.diag(SPV_ERROR_INVALID_LAYOUT)
index 2bb6db2..0b0b609 100644 (file)
@@ -81,19 +81,20 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _,
                                         SpvOp opcode) {
   if (_.isOpcodeInCurrentLayoutSection(opcode)) {
     switch (opcode) {
-      case SpvOpFunction:
+      case SpvOpFunction: {
         if (_.in_function_body()) {
           return _.diag(SPV_ERROR_INVALID_LAYOUT)
                  << "Cannot declare a function in a function body";
         }
-        spvCheckReturn(_.get_functions().RegisterFunction(
+        auto control_mask = static_cast<SpvFunctionControlMask>(inst->words[inst->operands[2].offset]);
+        spvCheckReturn(_.RegisterFunction(
             inst->result_id, inst->type_id,
-            inst->words[inst->operands[2].offset],
+            control_mask,
             inst->words[inst->operands[3].offset]));
         if (_.getLayoutSection() == kLayoutFunctionDefinitions)
-          spvCheckReturn(_.get_functions().RegisterSetFunctionDeclType(
+          spvCheckReturn(_.get_current_function().RegisterSetFunctionDeclType(
               FunctionDecl::kFunctionDeclDefinition));
-        break;
+      } break;
 
       case SpvOpFunctionParameter:
         if (_.in_function_body() == false) {
@@ -101,12 +102,12 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _,
                                                      "instructions must be in "
                                                      "a function body";
         }
-        if (_.get_functions().get_block_count() != 0) {
+        if (_.get_current_function().get_block_count() != 0) {
           return _.diag(SPV_ERROR_INVALID_LAYOUT)
                  << "Function parameters must only appear immediately after the "
                     "function definition";
         }
-        spvCheckReturn(_.get_functions().RegisterFunctionParameter(
+        spvCheckReturn(_.get_current_function().RegisterFunctionParameter(
             inst->result_id, inst->type_id));
         break;
 
@@ -119,17 +120,17 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _,
           return _.diag(SPV_ERROR_INVALID_LAYOUT)
                  << "Function end cannot be called in blocks";
         }
-        if (_.get_functions().get_block_count() == 0 &&
+        if (_.get_current_function().get_block_count() == 0 &&
             _.getLayoutSection() == kLayoutFunctionDefinitions) {
           return _.diag(SPV_ERROR_INVALID_LAYOUT) << "Function declarations "
                                                      "must appear before "
                                                      "function definitions.";
         }
-        spvCheckReturn(_.get_functions().RegisterFunctionEnd());
         if (_.getLayoutSection() == kLayoutFunctionDeclarations) {
-          spvCheckReturn(_.get_functions().RegisterSetFunctionDeclType(
-              FunctionDecl::kFunctionDeclDeclaration));
+          spvCheckReturn(_.get_current_function().RegisterSetFunctionDeclType(
+                                                    FunctionDecl::kFunctionDeclDeclaration));
         }
+        spvCheckReturn(_.RegisterFunctionEnd());
         break;
 
       case SpvOpLine:
@@ -149,7 +150,7 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _,
         }
         if (_.getLayoutSection() == kLayoutFunctionDeclarations) {
           _.progressToNextLayoutSectionOrder();
-          spvCheckReturn(_.get_functions().RegisterSetFunctionDeclType(
+          spvCheckReturn(_.get_current_function().RegisterSetFunctionDeclType(
               FunctionDecl::kFunctionDeclDefinition));
         }
         break;
index 9ad74be..88f6ad1 100644 (file)
 
 #include <algorithm>
 #include <cassert>
+#include <iterator>
+#include <limits>
+#include <list>
 #include <map>
 #include <string>
 #include <unordered_set>
 #include <vector>
 
 #include "spirv/spirv.h"
-
 #include "spirv_definition.h"
 #include "validate.h"
 
 using std::find;
+using std::list;
+using std::numeric_limits;
 using std::string;
 using std::unordered_set;
 using std::vector;
@@ -209,6 +213,10 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
 
 namespace libspirv {
 
+void message(std::string file, size_t line, std::string name) {
+  std::cout << file << ":" << line << ": " << name << std::endl;
+}
+
 ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic,
                                      const spv_const_context context)
     : diagnostic_(diagnostic),
@@ -216,11 +224,12 @@ ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic,
       unresolved_forward_ids_{},
       operand_names_{},
       current_layout_section_(kLayoutCapabilities),
-      module_functions_(*this),
+      module_functions_(),
       module_capabilities_(0u),
       grammar_(context),
       addressing_model_(SpvAddressingModelLogical),
-      memory_model_(SpvMemoryModelSimple) {}
+      memory_model_(SpvMemoryModelSimple),
+      in_function_(false) {}
 
 spv_result_t ValidationState_t::forwardDeclareId(uint32_t id) {
   unresolved_forward_ids_.insert(id);
@@ -245,6 +254,16 @@ string ValidationState_t::getIdName(uint32_t id) const {
   return out.str();
 }
 
+string ValidationState_t::getIdOrName(uint32_t id) const {
+  std::stringstream out;
+  if (operand_names_.find(id) != end(operand_names_)) {
+    out << operand_names_.at(id);
+  } else {
+    out << id;
+  }
+  return out.str();
+}
+
 size_t ValidationState_t::unresolvedForwardIdCount() const {
   return unresolved_forward_ids_.size();
 }
@@ -286,23 +305,26 @@ DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const {
       error_code);
 }
 
-Functions& ValidationState_t::get_functions() { return module_functions_; }
+list<Function>& ValidationState_t::get_functions() { return module_functions_; }
 
-bool ValidationState_t::in_function_body() const {
-  return module_functions_.in_function_body();
+Function& ValidationState_t::get_current_function() {
+  assert(in_function_body());
+  return module_functions_.back();
 }
 
+bool ValidationState_t::in_function_body() const { return in_function_; }
+
 bool ValidationState_t::in_block() const {
-  return module_functions_.in_block();
+  return module_functions_.back().in_block();
 }
 
-void ValidationState_t::registerCapability(SpvCapability cap) {
+void ValidationState_t::RegisterCapability(SpvCapability cap) {
   module_capabilities_ |= SPV_CAPABILITY_AS_MASK(cap);
   spv_operand_desc desc;
   if (SPV_SUCCESS ==
       grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc))
     libspirv::ForEach(desc->capabilities,
-                      [this](SpvCapability c) { registerCapability(c); });
+                      [this](SpvCapability c) { RegisterCapability(c); });
 }
 
 bool ValidationState_t::hasCapability(SpvCapability cap) const {
@@ -318,7 +340,7 @@ bool ValidationState_t::HasAnyOf(spv_capability_mask_t capabilities) const {
   });
   return found;
 }
-       
+
 void ValidationState_t::setAddressingModel(SpvAddressingModel am) {
   addressing_model_ = am;
 }
@@ -335,40 +357,61 @@ SpvMemoryModel ValidationState_t::getMemoryModel() const {
   return memory_model_;
 }
 
-Functions::Functions(ValidationState_t& module)
-    : module_(module), in_function_(false), in_block_(false) {}
-
-bool Functions::in_function_body() const { return in_function_; }
-
-bool Functions::in_block() const { return in_block_; }
+Function::Function(uint32_t id, uint32_t result_type_id,
+                   SpvFunctionControlMask function_control,
+                   uint32_t function_type_id, ValidationState_t& module)
+    : module_(module),
+      id_(id),
+      function_type_id_(function_type_id),
+      result_type_id_(result_type_id),
+      function_control_(function_control),
+      declaration_type_(FunctionDecl::kFunctionDeclUnknown),
+      blocks_(),
+      current_block_(nullptr),
+      cfg_constructs_(),
+      variable_ids_(),
+      parameter_ids_() {}
+
+bool Function::in_block() const { return static_cast<bool>(current_block_); }
+
+bool Function::IsFirstBlock(uint32_t id) const {
+  return !ordered_blocks_.empty() && *get_first_block() == id;
+}
 
-spv_result_t Functions::RegisterFunction(uint32_t id, uint32_t ret_type_id,
-                                         uint32_t function_control,
-                                         uint32_t function_type_id) {
-  assert(in_function_ == false &&
-         "Function instructions can not be declared in a function");
+spv_result_t ValidationState_t::RegisterFunction(
+    uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
+    uint32_t function_type_id) {
+  assert(in_function_body() == false &&
+         "RegisterFunction can only be called when parsing the binary outside "
+         "of another function");
   in_function_ = true;
-  id_.emplace_back(id);
-  type_id_.emplace_back(function_type_id);
-  declaration_type_.emplace_back(FunctionDecl::kFunctionDeclUnknown);
-  block_ids_.emplace_back();
-  variable_ids_.emplace_back();
-  parameter_ids_.emplace_back();
+  module_functions_.emplace_back(id, ret_type_id, function_control,
+                                 function_type_id, *this);
 
   // TODO(umar): validate function type and type_id
-  (void)ret_type_id;
-  (void)function_control;
 
   return SPV_SUCCESS;
 }
 
-spv_result_t Functions::RegisterFunctionParameter(uint32_t id,
-                                                  uint32_t type_id) {
-  assert(in_function_ == true &&
-         "Function parameter instructions cannot be declared outside of a "
-         "function");
+spv_result_t ValidationState_t::RegisterFunctionEnd() {
+  assert(in_function_body() == true &&
+         "RegisterFunctionEnd can only be called when parsing the binary "
+         "inside of another function");
+  assert(in_block() == false &&
+         "RegisterFunctionParameter can only be called when parsing the binary "
+         "ouside of a block");
+  in_function_ = false;
+  return SPV_SUCCESS;
+}
+
+spv_result_t Function::RegisterFunctionParameter(uint32_t id,
+                                                 uint32_t type_id) {
+  assert(module_.in_function_body() == true &&
+         "RegisterFunctionParameter can only be called when parsing the binary "
+         "outside of another function");
   assert(in_block() == false &&
-         "Function parameters cannot be called in blocks");
+         "RegisterFunctionParameter can only be called when parsing the binary "
+         "ouside of a block");
   // TODO(umar): Validate function parameter type order and count
   // TODO(umar): Use these variables to validate parameter type
   (void)id;
@@ -376,42 +419,231 @@ spv_result_t Functions::RegisterFunctionParameter(uint32_t id,
   return SPV_SUCCESS;
 }
 
-spv_result_t Functions::RegisterSetFunctionDeclType(FunctionDecl type) {
-  assert(declaration_type_.back() == FunctionDecl::kFunctionDeclUnknown);
-  declaration_type_.back() = type;
+spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
+                                         uint32_t continue_id) {
+  RegisterBlock(merge_id, false);
+  RegisterBlock(continue_id, false);
+  cfg_constructs_.emplace_back(&get_current_block(), &blocks_.at(merge_id),
+                               &blocks_.at(continue_id));
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) {
+  RegisterBlock(merge_id, false);
+  cfg_constructs_.emplace_back(&get_current_block(), &blocks_.at(merge_id));
   return SPV_SUCCESS;
 }
 
-spv_result_t Functions::RegisterBlock(uint32_t id) {
-  assert(in_function_ == true && "Blocks can only exsist in functions");
-  assert(in_block_ == false && "Blocks cannot be nested");
+void printDot(const BasicBlock& other, const ValidationState_t& module) {
+  string block_string;
+  if (other.get_successors().empty()) {
+    block_string += "end ";
+  } else {
+    for (auto& block : other.get_successors()) {
+      block_string += module.getIdOrName(block->get_id()) + " ";
+    }
+  }
+  printf("%10s -> {%s\b}\n", module.getIdOrName(other.get_id()).c_str(),
+         block_string.c_str());
+}
+
+void Function::printDotGraph() const {
+  if (get_first_block()) {
+    string func_name(module_.getIdOrName(id_));
+    printf("digraph %s {\n", func_name.c_str());
+    printBlocks();
+    printf("}\n");
+  }
+}
+
+void Function::printBlocks() const {
+  if (get_first_block()) {
+    printf("%10s -> %s\n", module_.getIdOrName(id_).c_str(),
+           module_.getIdOrName(get_first_block()->get_id()).c_str());
+    for (const auto& block : blocks_) {
+      printDot(block.second, module_);
+    }
+  }
+}
+
+spv_result_t Function::RegisterSetFunctionDeclType(FunctionDecl type) {
+  assert(declaration_type_ == FunctionDecl::kFunctionDeclUnknown);
+  declaration_type_ = type;
+  return SPV_SUCCESS;
+}
+
+spv_result_t Function::RegisterBlock(uint32_t id, bool is_definition) {
+  assert(module_.in_function_body() == true &&
+         "RegisterBlocks can only be called when parsing a binary inside of a "
+         "function");
   assert(module_.getLayoutSection() !=
              ModuleLayoutSection::kLayoutFunctionDeclarations &&
-         "Function declartions must appear before function definitions");
-  assert(declaration_type_.back() == FunctionDecl::kFunctionDeclDefinition &&
-         "Function declaration type should have already been defined");
+         "RegisterBlocks cannot be called within a function declaration");
+  assert(
+      declaration_type_ == FunctionDecl::kFunctionDeclDefinition &&
+      "RegisterBlocks can only be called after declaration_type_ is defined");
+
+  std::unordered_map<uint32_t, BasicBlock>::iterator inserted_block;
+  bool success = false;
+  tie(inserted_block, success) = blocks_.insert({id, BasicBlock(id)});
+  if (is_definition) {  // new block definition
+    assert(in_block() == false &&
+           "Register Block can only be called when parsing a binary outside of "
+           "a BasicBlock");
+
+    undefined_blocks_.erase(id);
+    current_block_ = &inserted_block->second;
+    ordered_blocks_.push_back(current_block_);
+    if (IsFirstBlock(id)) current_block_->set_reachability(true);
+  } else if (success) {  // Block doesn't exsist but this is not a definition
+    undefined_blocks_.insert(id);
+  }
 
-  block_ids_.back().push_back(id);
-  in_block_ = true;
   return SPV_SUCCESS;
 }
 
-spv_result_t Functions::RegisterFunctionEnd() {
-  assert(in_function_ == true &&
-         "Function end can only be called in functions");
-  assert(in_block_ == false && "Function end cannot be called inside a block");
-  in_function_ = false;
-  return SPV_SUCCESS;
+void Function::RegisterBlockEnd(vector<uint32_t> next_list,
+                                SpvOp branch_instruction) {
+  assert(module_.in_function_body() == true &&
+         "RegisterBlockEnd can only be called when parsing a binary in a "
+         "function");
+  assert(
+      in_block() == true &&
+      "RegisterBlockEnd can only be called when parsing a binary in a block");
+
+  vector<BasicBlock*> next_blocks;
+  next_blocks.reserve(next_list.size());
+
+  std::unordered_map<uint32_t, BasicBlock>::iterator inserted_block;
+  bool success;
+  for (uint32_t id : next_list) {
+    tie(inserted_block, success) = blocks_.insert({id, BasicBlock(id)});
+    if (success) {
+      undefined_blocks_.insert(id);
+    }
+    next_blocks.push_back(&inserted_block->second);
+  }
+
+  current_block_->RegisterBranchInstruction(branch_instruction);
+  current_block_->RegisterSuccessors(next_blocks);
+  current_block_ = nullptr;
+  return;
 }
 
-spv_result_t Functions::RegisterBlockEnd() {
-  assert(in_function_ == true &&
-         "Branch instruction can only be called in a function");
-  assert(in_block_ == true &&
-         "Branch instruction can only be called in a block");
-  in_block_ = false;
-  return SPV_SUCCESS;
+size_t Function::get_block_count() const { return blocks_.size(); }
+
+size_t Function::get_undefined_block_count() const {
+  return undefined_blocks_.size();
+}
+
+const vector<BasicBlock*>& Function::get_blocks() const {
+  return ordered_blocks_;
+}
+vector<BasicBlock*>& Function::get_blocks() { return ordered_blocks_; }
+
+const BasicBlock& Function::get_current_block() const {
+  return *current_block_;
+}
+BasicBlock& Function::get_current_block() { return *current_block_; }
+
+const list<CFConstruct>& Function::get_constructs() const {
+  return cfg_constructs_;
+}
+list<CFConstruct>& Function::get_constructs() { return cfg_constructs_; }
+
+const BasicBlock* Function::get_first_block() const {
+  if (ordered_blocks_.empty()) return nullptr;
+  return ordered_blocks_[0];
+}
+BasicBlock* Function::get_first_block() {
+  if (ordered_blocks_.empty()) return nullptr;
+  return ordered_blocks_[0];
+}
+
+BasicBlock::BasicBlock(uint32_t id)
+    : id_(id),
+      immediate_dominator_(nullptr),
+      predecessors_(),
+      successors_(),
+      reachable_(false) {}
+
+void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) {
+  immediate_dominator_ = dom_block;
+}
+
+const BasicBlock* BasicBlock::GetImmediateDominator() const {
+  return immediate_dominator_;
+}
+
+BasicBlock* BasicBlock::GetImmediateDominator() { return immediate_dominator_; }
+
+void BasicBlock::RegisterSuccessors(vector<BasicBlock*> next_blocks) {
+  for (auto& block : next_blocks) {
+    block->predecessors_.push_back(this);
+    successors_.push_back(block);
+    if (block->reachable_ == false) block->set_reachability(reachable_);
+  }
+}
+
+void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) {
+  if (branch_instruction == SpvOpUnreachable) reachable_ = false;
+  return;
+}
+
+bool Function::IsMergeBlock(uint32_t merge_block_id) const {
+  const auto b = blocks_.find(merge_block_id);
+  if (b != end(blocks_)) {
+    return cfg_constructs_.end() !=
+           find_if(begin(cfg_constructs_), end(cfg_constructs_),
+                   [&](const CFConstruct& construct) {
+                     return construct.get_merge() == &b->second;
+                   });
+  } else {
+    return false;
+  }
+}
+
+BasicBlock::DominatorIterator::DominatorIterator() : current_(nullptr) {}
+BasicBlock::DominatorIterator::DominatorIterator(const BasicBlock* block)
+    : current_(block) {}
+
+BasicBlock::DominatorIterator& BasicBlock::DominatorIterator::operator++() {
+  if (current_ == current_->GetImmediateDominator()) {
+    current_ = nullptr;
+  } else {
+    current_ = current_->GetImmediateDominator();
+  }
+  return *this;
+}
+
+const BasicBlock::DominatorIterator BasicBlock::dom_begin() const {
+  return DominatorIterator(this);
+}
+
+BasicBlock::DominatorIterator BasicBlock::dom_begin() {
+  return DominatorIterator(this);
+}
+
+const BasicBlock::DominatorIterator BasicBlock::dom_end() const {
+  return DominatorIterator();
+}
+
+BasicBlock::DominatorIterator BasicBlock::dom_end() {
+  return DominatorIterator();
+}
+
+bool operator==(const BasicBlock::DominatorIterator& lhs,
+                const BasicBlock::DominatorIterator& rhs) {
+  return lhs.current_ == rhs.current_;
+}
+
+bool operator!=(const BasicBlock::DominatorIterator& lhs,
+                const BasicBlock::DominatorIterator& rhs) {
+  return !(lhs == rhs);
 }
 
-size_t Functions::get_block_count() const { return block_ids_.back().size(); }
+const BasicBlock*& BasicBlock::DominatorIterator::operator*() {
+  return current_;
 }
+}  // namespace libspirv
index 3e43fd3..2261b5b 100644 (file)
@@ -109,6 +109,7 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
       ${CMAKE_CURRENT_SOURCE_DIR}/UnitSPIRV.cpp
       ${CMAKE_CURRENT_SOURCE_DIR}/ValidateFixtures.cpp
       ${CMAKE_CURRENT_SOURCE_DIR}/Validate.Capability.cpp
+      ${CMAKE_CURRENT_SOURCE_DIR}/Validate.CFG.cpp
       ${CMAKE_CURRENT_SOURCE_DIR}/Validate.Layout.cpp
       ${CMAKE_CURRENT_SOURCE_DIR}/Validate.Storage.cpp
       ${CMAKE_CURRENT_SOURCE_DIR}/Validate.SSA.cpp
index 93bf791..967d998 100644 (file)
 
 #include <iomanip>
 
-#include "source/assembly_grammar.h"
-#include "source/binary.h"
-#include "source/diagnostic.h"
-#include "source/opcode.h"
-#include "source/spirv_endian.h"
-#include "source/text.h"
-#include "source/text_handler.h"
-#include "source/validate.h"
-#include "spirv-tools/libspirv.h"
+#include <source/assembly_grammar.h>
+#include <source/binary.h>
+#include <source/diagnostic.h>
+#include <source/opcode.h>
+#include <source/spirv_endian.h>
+#include <source/text.h>
+#include <source/text_handler.h>
+#include <source/validate.h>
+#include <spirv-tools/libspirv.h>
 
 #include <gtest/gtest.h>
 
diff --git a/test/Validate.CFG.cpp b/test/Validate.CFG.cpp
new file mode 100644 (file)
index 0000000..ea1b6e8
--- /dev/null
@@ -0,0 +1,599 @@
+
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a
+// copy of this software and/or associated documentation files (the
+// "Materials"), to deal in the Materials without restriction, including
+// without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Materials, and to
+// permit persons to whom the Materials are furnished to do so, subject to
+// the following conditions:
+//
+// The above copyright notice and this permission notice shall be included
+// in all copies or substantial portions of the Materials.
+//
+// MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
+// KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
+// SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
+//    https://www.khronos.org/registry/
+//
+// THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+// MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
+
+// Validation tests for Control Flow Graph
+
+#include <array>
+#include <functional>
+#include <iterator>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+
+#include "TestFixture.h"
+#include "UnitSPIRV.h"
+#include "ValidateFixtures.h"
+#include "source/diagnostic.h"
+#include "source/validate.h"
+
+using std::array;
+using std::make_pair;
+using std::pair;
+using std::string;
+using std::stringstream;
+using std::vector;
+
+using ::testing::HasSubstr;
+using ::testing::MatchesRegex;
+
+using libspirv::BasicBlock;
+using libspirv::ValidationState_t;
+
+using ValidateCFG = spvtest::ValidateBase<bool>;
+using spvtest::ScopedContext;
+
+namespace {
+
+string nameOps() { return ""; }
+
+template <typename... Args>
+string nameOps(pair<string, string> head, Args... names) {
+  return "OpName %" + head.first + " \"" + head.second + "\"\n" +
+         nameOps(names...);
+}
+
+template <typename... Args>
+string nameOps(string head, Args... names) {
+  return "OpName %" + head + " \"" + head + "\"\n" + nameOps(names...);
+}
+
+/// This class allows the easy creation of complex control flow without writing
+/// SPIR-V. This class is used in the test cases below.
+class Block {
+  string label_;
+  string body_;
+  SpvOp type_;
+  vector<Block> successors_;
+
+ public:
+  /// Creates a Block with a given label
+  ///
+  /// @param[in]: label the label id of the block
+  /// @param[in]: type the branch instruciton that ends the block
+  Block(string label, SpvOp type = SpvOpBranch)
+      : label_(label), body_(), type_(type), successors_() {}
+
+  /// Sets the instructions which will appear in the body of the block
+  Block& setBody(std::string body) {
+    body_ = body;
+    return *this;
+  }
+
+  /// Converts the block into a SPIR-V string
+  operator string() {
+    stringstream out;
+    out << std::setw(8) << "%" + label_ + "  = OpLabel \n";
+    if (!body_.empty()) {
+      out << body_;
+    }
+
+    switch (type_) {
+      case SpvOpBranchConditional:
+        out << "OpBranchConditional %cond ";
+        for (Block& b : successors_) {
+          out << "%" + b.label_ + " ";
+        }
+        break;
+      case SpvOpSwitch: {
+        out << "OpSwitch %one %" + successors_.front().label_;
+        stringstream ss;
+        for (size_t i = 1; i < successors_.size(); i++) {
+          ss << " " << i << " %" << successors_[i].label_;
+        }
+        out << ss.str();
+      } break;
+      case SpvOpReturn:
+        out << "OpReturn\n";
+        break;
+      case SpvOpUnreachable:
+        out << "OpUnreachable\n";
+        break;
+      case SpvOpBranch:
+        out << "OpBranch %" + successors_.front().label_;
+        break;
+      default:
+        assert(1 != 1 && "Unhandled");
+    }
+    out << "\n";
+
+    return out.str();
+  }
+  friend Block& operator>>(Block& curr, vector<Block> successors);
+  friend Block& operator>>(Block& lhs, Block& successor);
+};
+
+/// Assigns the successors for the Block on the lhs
+Block& operator>>(Block& lhs, vector<Block> successors) {
+  if (lhs.type_ == SpvOpBranchConditional) {
+    assert(successors.size() == 2);
+  } else if (lhs.type_ == SpvOpSwitch) {
+    assert(successors.size() > 1);
+  }
+  lhs.successors_ = successors;
+  return lhs;
+}
+
+/// Assigns the successor for the Block on the lhs
+Block& operator>>(Block& lhs, Block& successor) {
+  assert(lhs.type_ == SpvOpBranch);
+  lhs.successors_.push_back(successor);
+  return lhs;
+}
+
+string header =
+    "OpCapability Shader\n"
+    "OpMemoryModel Logical GLSL450\n";
+
+string types_consts =
+    "%voidt   = OpTypeVoid\n"
+    "%boolt   = OpTypeBool\n"
+    "%intt    = OpTypeInt 32 1\n"
+    "%one     = OpConstant %intt 1\n"
+    "%two     = OpConstant %intt 2\n"
+    "%ptrt    = OpTypePointer Function %intt\n"
+    "%funct   = OpTypeFunction %voidt\n";
+
+TEST_F(ValidateCFG, Simple) {
+  Block first("first");
+  Block loop("loop", SpvOpBranchConditional);
+  Block cont("cont");
+  Block merge("merge", SpvOpReturn);
+
+  loop.setBody(
+      "%cond    = OpSLessThan %intt %one %two\n"
+      "OpLoopMerge %merge %cont None\n");
+
+  string str = header + nameOps("loop", "first", "cont", "merge",
+                                make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += first >> loop;
+  str += loop >> vector<Block>({cont, merge});
+  str += cont >> loop;
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, Variable) {
+  Block entry("entry");
+  Block cont("cont");
+  Block exit("exit", SpvOpReturn);
+
+  entry.setBody("%var = OpVariable %ptrt Function\n");
+
+  string str = header + nameOps(make_pair("func", "Main")) + types_consts +
+               " %func    = OpFunction %voidt None %funct\n";
+  str += entry >> cont;
+  str += cont >> exit;
+  str += exit;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, VariableNotInFirstBlockBad) {
+  Block entry("entry");
+  Block cont("cont");
+  Block exit("exit", SpvOpReturn);
+
+  // This operation should only be performed in the entry block
+  cont.setBody("%var = OpVariable %ptrt Function\n");
+
+  string str = header + nameOps(make_pair("func", "Main")) + types_consts +
+               " %func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> cont;
+  str += cont >> exit;
+  str += exit;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Variables can only be defined in the first block of a function"));
+}
+
+TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) {
+  Block entry("entry");
+  Block cont("cont");
+  Block branch("branch", SpvOpBranchConditional);
+  Block merge("merge", SpvOpReturn);
+
+  branch.setBody(
+      " %cond    = OpSLessThan %intt %one %two\n"
+      "OpSelectionMerge %merge None\n");
+
+  string str = header + nameOps("cont", "branch", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> branch;
+  str += cont >> merge;  // cont appears before its dominator
+  str += branch >> vector<Block>({cont, merge});
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              MatchesRegex("Block [0-9]\\[cont\\] appears in the binary "
+                           "before its dominator [0-9]\\[branch\\]"));
+}
+
+TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) {
+  Block entry("entry");
+  Block loop("loop");
+  Block selection("selection", SpvOpBranchConditional);
+  Block merge("merge", SpvOpReturn);
+
+  loop.setBody(
+    " %cond   = OpSLessThan %intt %one %two\n"
+    " OpLoopMerge %merge %loop None\n");
+  // cannot share the same merge
+  selection.setBody(
+      "OpSelectionMerge %merge None\n");
+
+  string str = header + nameOps("merge", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> loop;
+  str += loop >> selection;
+  str += selection >> vector<Block>({loop, merge});
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              MatchesRegex("Block [0-9]\\[merge\\] is already a merge block "
+                           "for another header"));
+}
+
+TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) {
+  Block entry("entry");
+  Block loop("loop", SpvOpBranchConditional);
+  Block selection("selection", SpvOpBranchConditional);
+  Block merge("merge", SpvOpReturn);
+
+  selection.setBody(
+      " %cond   = OpSLessThan %intt %one %two\n"
+      " OpSelectionMerge %merge None\n");
+  // cannot share the same merge
+  loop.setBody(" OpLoopMerge %merge %loop None\n");
+
+
+  string str = header + nameOps("merge", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> selection;
+  str += selection >> vector<Block>({merge, loop});
+  str += loop >> vector<Block>({loop, merge});
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              MatchesRegex("Block [0-9]\\[merge\\] is already a merge block "
+                           "for another header"));
+}
+
+TEST_F(ValidateCFG, BranchTargetFirstBlockBad) {
+  Block entry("entry");
+  Block bad("bad");
+  Block end("end", SpvOpReturn);
+  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> bad;
+  str += bad >> entry;  // Cannot target entry block
+  str += end;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      MatchesRegex("First block [0-9]\\[entry\\] of funciton [0-9]\\[Main\\] "
+                   "is targeted by block [0-9]\\[bad\\]"));
+}
+
+TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
+  Block entry("entry");
+  Block bad("bad", SpvOpBranchConditional);
+  Block exit("exit", SpvOpReturn);
+
+  bad.setBody(
+      " %cond    = OpSLessThan %intt %one %two\n"
+      " OpLoopMerge %entry %exit None\n");
+
+  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> bad;
+  str += bad >> vector<Block>({entry, exit});  // cannot target entry block
+  str += exit;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      MatchesRegex("First block [0-9]\\[entry\\] of funciton [0-9]\\[Main\\] "
+                   "is targeted by block [0-9]\\[bad\\]"));
+}
+
+TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
+  Block entry("entry");
+  Block bad("bad", SpvOpBranchConditional);
+  Block t("t");
+  Block merge("merge");
+  Block end("end", SpvOpReturn);
+
+  bad.setBody(
+      "%cond    = OpSLessThan %intt %one %two\n"
+      "OpLoopMerge %merge %cont None\n");
+
+  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> bad;
+  str += bad >> vector<Block>({t, entry});
+  str += merge >> end;
+  str += end;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      MatchesRegex("First block [0-9]\\[entry\\] of funciton [0-9]\\[Main\\] "
+                   "is targeted by block [0-9]\\[bad\\]"));
+}
+
+TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) {
+  Block entry("entry");
+  Block bad("bad", SpvOpSwitch);
+  Block block1("block1");
+  Block block2("block2");
+  Block block3("block3");
+  Block def("def");  // default block
+  Block merge("merge");
+  Block end("end", SpvOpReturn);
+
+  bad.setBody(
+      "%cond    = OpSLessThan %intt %one %two\n"
+      "OpSelectionMerge %merge None\n");
+
+  string str = header + nameOps("entry", "bad", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> bad;
+  str += bad >> vector<Block>({def, block1, block2, block3, entry});
+  str += def >> merge;
+  str += block1 >> merge;
+  str += block2 >> merge;
+  str += block3 >> merge;
+  str += merge >> end;
+  str += end;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      MatchesRegex("First block [0-9]\\[entry\\] of funciton [0-9]\\[Main\\] "
+                   "is targeted by block [0-9]\\[bad\\]"));
+}
+
+TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) {
+  Block entry("entry");
+  Block middle("middle", SpvOpBranchConditional);
+  Block end("end", SpvOpReturn);
+
+  middle.setBody(
+      "%cond    = OpSLessThan %intt %one %two\n"
+      "OpSelectionMerge %end None\n");
+
+  Block entry2("entry2");
+  Block middle2("middle2");
+  Block end2("end2", SpvOpReturn);
+
+  string str = header + nameOps("middle2", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> middle;
+  str += middle >> vector<Block>({end, middle2});
+  str += end;
+  str += "OpFunctionEnd\n";
+
+  str += "%func2    = OpFunction %voidt None %funct\n";
+  str += entry2 >> middle2;
+  str += middle2 >> end2;
+  str += end2;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      MatchesRegex(
+          "Block\\(s\\) \\{[0-9]\\[middle2\\] .\\} are referenced but not "
+          "defined in function [0-9]\\[Main\\]"));
+}
+
+TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) {
+  Block entry("entry");
+  Block merge("merge");
+  Block head("head", SpvOpBranchConditional);
+  Block f("f");
+  Block exit("exit", SpvOpReturn);
+
+  entry.setBody("%cond = OpSLessThan %intt %one %two\n");
+
+  head.setBody("OpSelectionMerge %merge None\n");
+
+  string str = header + nameOps("head", "merge", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> merge;
+  str += head >> vector<Block>({merge, f});
+  str += f >> merge;
+  str += merge >> exit;
+  str += exit;
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      MatchesRegex(
+          "Header block [0-9]\\[head\\] doesn't dominate its merge block "
+          "[0-9]\\[merge\\]"));
+}
+
+TEST_F(ValidateCFG, UnreachableMerge) {
+  Block entry("entry");
+  Block branch("branch", SpvOpBranchConditional);
+  Block t("t", SpvOpReturn);
+  Block f("f", SpvOpReturn);
+  Block merge("merge");
+  Block end("end", SpvOpReturn);
+
+  branch.setBody(
+      " %cond    = OpSLessThan %intt %one %two\n"
+      "OpSelectionMerge %merge None\n");
+
+  string str = header + nameOps("branch", "merge", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> branch;
+  str += branch >> vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += merge >> end;
+  str += end;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
+  Block entry("entry");
+  Block branch("branch", SpvOpBranchConditional);
+  Block t("t", SpvOpReturn);
+  Block f("f", SpvOpReturn);
+  Block merge("merge", SpvOpUnreachable);
+
+  branch.setBody(
+      " %cond    = OpSLessThan %intt %one %two\n"
+      "OpSelectionMerge %merge None\n");
+
+  string str = header + nameOps("branch", "merge", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> branch;
+  str += branch >> vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, UnreachableBlock) {
+  Block entry("entry");
+  Block unreachable("unreachable");
+  Block exit("exit", SpvOpReturn);
+
+  string str = header +
+               nameOps("unreachable", "exit", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> exit;
+  str += unreachable >> exit;
+  str += exit;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, UnreachableBranch) {
+  Block entry("entry");
+  Block unreachable("unreachable", SpvOpBranchConditional);
+  Block unreachablechildt("unreachablechildt");
+  Block unreachablechildf("unreachablechildf");
+  Block merge("merge");
+  Block exit("exit", SpvOpReturn);
+
+  unreachable.setBody(
+      " %cond    = OpSLessThan %intt %one %two\n"
+      "OpSelectionMerge %merge None\n");
+  string str = header +
+               nameOps("unreachable", "exit", make_pair("func", "Main")) +
+               types_consts + "%func    = OpFunction %voidt None %funct\n";
+
+  str += entry >> exit;
+  str += unreachable >> vector<Block>({unreachablechildt, unreachablechildf});
+  str += unreachablechildt >> merge;
+  str += unreachablechildf >> merge;
+  str += merge >> exit;
+  str += exit;
+  str += "OpFunctionEnd\n";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+/// TODO(umar): Empty function
+/// TODO(umar): Single block loops
+/// TODO(umar): Nested loops
+/// TODO(umar): Nested selection
+/// TODO(umar): Switch instructions
+/// TODO(umar): CFG branching outside of CFG construct
+/// TODO(umar): Nested CFG constructs
+}
index bbc4f60..58d15b6 100644 (file)
@@ -63,27 +63,27 @@ class ValidationState_HasAnyOfTest : public ValidationStateTest {
 
 TEST_F(ValidationState_HasAnyOfTest, EmptyMask) {
   EXPECT_TRUE(state_.HasAnyOf(0));
-  state_.registerCapability(SpvCapabilityMatrix);
+  state_.RegisterCapability(SpvCapabilityMatrix);
   EXPECT_TRUE(state_.HasAnyOf(0));
-  state_.registerCapability(SpvCapabilityImageMipmap);
+  state_.RegisterCapability(SpvCapabilityImageMipmap);
   EXPECT_TRUE(state_.HasAnyOf(0));
-  state_.registerCapability(SpvCapabilityPipes);
+  state_.RegisterCapability(SpvCapabilityPipes);
   EXPECT_TRUE(state_.HasAnyOf(0));
-  state_.registerCapability(SpvCapabilityStorageImageArrayDynamicIndexing);
+  state_.RegisterCapability(SpvCapabilityStorageImageArrayDynamicIndexing);
   EXPECT_TRUE(state_.HasAnyOf(0));
-  state_.registerCapability(SpvCapabilityClipDistance);
+  state_.RegisterCapability(SpvCapabilityClipDistance);
   EXPECT_TRUE(state_.HasAnyOf(0));
-  state_.registerCapability(SpvCapabilityStorageImageWriteWithoutFormat);
+  state_.RegisterCapability(SpvCapabilityStorageImageWriteWithoutFormat);
   EXPECT_TRUE(state_.HasAnyOf(0));
 }
 
 TEST_F(ValidationState_HasAnyOfTest, SingleCapMask) {
   EXPECT_FALSE(state_.HasAnyOf(mask({SpvCapabilityMatrix})));
   EXPECT_FALSE(state_.HasAnyOf(mask({SpvCapabilityImageMipmap})));
-  state_.registerCapability(SpvCapabilityMatrix);
+  state_.RegisterCapability(SpvCapabilityMatrix);
   EXPECT_TRUE(state_.HasAnyOf(mask({SpvCapabilityMatrix})));
   EXPECT_FALSE(state_.HasAnyOf(mask({SpvCapabilityImageMipmap})));
-  state_.registerCapability(SpvCapabilityImageMipmap);
+  state_.RegisterCapability(SpvCapabilityImageMipmap);
   EXPECT_TRUE(state_.HasAnyOf(mask({SpvCapabilityMatrix})));
   EXPECT_TRUE(state_.HasAnyOf(mask({SpvCapabilityImageMipmap})));
 }
@@ -95,7 +95,7 @@ TEST_F(ValidationState_HasAnyOfTest, MultiCapMask) {
                            SpvCapabilityGeometryStreams});
   EXPECT_FALSE(state_.HasAnyOf(mask1));
   EXPECT_FALSE(state_.HasAnyOf(mask2));
-  state_.registerCapability(SpvCapabilityImageBuffer);
+  state_.RegisterCapability(SpvCapabilityImageBuffer);
   EXPECT_TRUE(state_.HasAnyOf(mask1));
   EXPECT_FALSE(state_.HasAnyOf(mask2));
 }