Function and block layout checks. very basic CFG.
authorUmar Arshad <umar@arrayfire.com>
Wed, 16 Dec 2015 02:44:21 +0000 (21:44 -0500)
committerDavid Neto <dneto@google.com>
Wed, 13 Jan 2016 15:06:58 +0000 (10:06 -0500)
This adds function and block layout checks to the validator. Very
basic CFG code has been added to make sure labels and branches
are correctly ordered.

Also:
* MemoryModel and Variable instruction checks/tests
* Use spvCheckReturn instead of CHECK_RESULT
* Fix invalid SSA tests
* Created libspirv::spvResultToString in diagnostic.h
* Documented various functions and classes
* Fixed error messages
* Fixed using declaration for FunctionDecl enum class

source/diagnostic.cpp
source/diagnostic.h
source/validate.cpp
source/validate_types.cpp
source/validate_types.h
test/Validate.Layout.cpp
test/Validate.SSA.cpp
test/ValidateFixtures.cpp
test/ValidateFixtures.h

index 78a2267..60502c0 100644 (file)
@@ -88,5 +88,68 @@ DiagnosticStream::~DiagnosticStream() {
     *pDiagnostic_ = spvDiagnosticCreate(&position_, stream_.str().c_str());
   }
 }
+std::string
+spvResultToString(spv_result_t res) {
+  std::string out;
+  switch (res) {
+    case SPV_SUCCESS:
+      out = "SPV_SUCCESS";
+      break;
+    case SPV_UNSUPPORTED:
+      out = "SPV_UNSUPPORTED";
+      break;
+    case SPV_END_OF_STREAM:
+      out = "SPV_END_OF_STREAM";
+      break;
+    case SPV_WARNING:
+      out = "SPV_WARNING";
+      break;
+    case SPV_FAILED_MATCH:
+      out = "SPV_FAILED_MATCH";
+      break;
+    case SPV_REQUESTED_TERMINATION:
+      out = "SPV_REQUESTED_TERMINATION";
+      break;
+    case SPV_ERROR_INTERNAL:
+      out = "SPV_ERROR_INTERNAL";
+      break;
+    case SPV_ERROR_OUT_OF_MEMORY:
+      out = "SPV_ERROR_OUT_OF_MEMORY";
+      break;
+    case SPV_ERROR_INVALID_POINTER:
+      out = "SPV_ERROR_INVALID_POINTER";
+      break;
+    case SPV_ERROR_INVALID_BINARY:
+      out = "SPV_ERROR_INVALID_BINARY";
+      break;
+    case SPV_ERROR_INVALID_TEXT:
+      out = "SPV_ERROR_INVALID_TEXT";
+      break;
+    case SPV_ERROR_INVALID_TABLE:
+      out = "SPV_ERROR_INVALID_TABLE";
+      break;
+    case SPV_ERROR_INVALID_VALUE:
+      out = "SPV_ERROR_INVALID_VALUE";
+      break;
+    case SPV_ERROR_INVALID_DIAGNOSTIC:
+      out = "SPV_ERROR_INVALID_DIAGNOSTIC";
+      break;
+    case SPV_ERROR_INVALID_LOOKUP:
+      out = "SPV_ERROR_INVALID_LOOKUP";
+      break;
+    case SPV_ERROR_INVALID_ID:
+      out = "SPV_ERROR_INVALID_ID";
+      break;
+    case SPV_ERROR_INVALID_CFG:
+      out = "SPV_ERROR_INVALID_CFG";
+      break;
+    case SPV_ERROR_INVALID_LAYOUT:
+      out = "SPV_ERROR_INVALID_LAYOUT";
+      break;
+    default:
+      out = "Unknown Error";
+  }
+  return out;
+}
 
 }  // namespace libspirv
index 9ee1532..c80b77d 100644 (file)
@@ -102,6 +102,8 @@ class DiagnosticStream {
   libspirv::diagnostic_helper helper(position, pDiagnostic); \
   helper.stream()
 
+std::string spvResultToString(spv_result_t res);
+
 }  // namespace libspirv
 
 #endif  // LIBSPIRV_DIAGNOSTIC_H_
index a208e47..0447f0d 100644 (file)
@@ -58,6 +58,10 @@ using std::unordered_set;
 using std::vector;
 
 using libspirv::ValidationState_t;
+using libspirv::kLayoutFunctionDeclarations;
+using libspirv::kLayoutFunctionDefinitions;
+using libspirv::kLayoutMemoryModel;
+using libspirv::FunctionDecl;
 
 #define spvCheckReturn(expression) \
   if (spv_result_t error = (expression)) return error;
@@ -436,10 +440,7 @@ void DebugInstructionPass(ValidationState_t& _,
   }
 }
 
-// TODO(umar): Check MemoryModel is in module
-// TODO(umar): Check OpVariable storage class is not function in module section
-// TODO(umar): Make sure function declarations appear before function
-// definitions
+// TODO(umar): Check linkage capabilities for function declarations
 // TODO(umar): Better error messages
 // NOTE: This function does not handle CFG related validation
 // Performs logical layout validation. See Section 2.4
@@ -448,22 +449,38 @@ spv_result_t ModuleLayoutPass(ValidationState_t& _,
   if (_.is_enabled(SPV_VALIDATE_LAYOUT_BIT)) {
     SpvOp opcode = inst->opcode;
 
-    if (libspirv::ModuleLayoutSection::kModule == _.getLayoutStage()) {
+    if (_.getLayoutStage() < kLayoutFunctionDeclarations) {
       // Module scoped instructions are processed by determining if the opcode
       // is part of the current stage. If it is not then the next stage is
       // checked.
       while (_.isOpcodeInCurrentLayoutStage(opcode) == false) {
-        // TODO(umar): Check if the MemoryModel instruction has executed
         _.progressToNextLayoutStageOrder();
-        if (_.getLayoutStage() == libspirv::ModuleLayoutSection::kFunction) {
+
+        if (_.getLayoutStage() == kLayoutMemoryModel &&
+            opcode != SpvOpMemoryModel) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT)
+                 << spvOpcodeString(opcode)
+                 << " cannot appear before the memory model instruction";
+        }
+
+        if (_.getLayoutStage() == kLayoutFunctionDeclarations) {
           // All module stages have been processed. Recursivly call
-          // ModuleLayoutPass
-          // to process the next section of the module
+          // ModuleLayoutPass to process the next section of the module
           return ModuleLayoutPass(_, inst);
         }
       }
+
+      if (opcode == SpvOpVariable) {
+        const uint32_t* storage_class = inst->words + inst->operands[2].offset;
+        if (*storage_class == SpvStorageClassFunction) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT)
+                 << "Variables cannot have a function[7] storage class "
+                    "outside of a function";
+        }
+      }
     } else {
-      // Validate the function layout.
+      // This ensures no module instructions are called during function
+      // declarations
       switch (opcode) {
         case SpvOpCapability:
         case SpvOpExtension:
@@ -520,24 +537,117 @@ spv_result_t ModuleLayoutPass(ValidationState_t& _,
               inst->words + inst->operands[2].offset;
           if (*storage_class != SpvStorageClassFunction)
             return _.diag(SPV_ERROR_INVALID_LAYOUT)
-                   << "All OpVariable instructions in a function must have a "
-                      "storage class of Function[7]";
+                   << "All Variable instructions in a function must have a "
+                      "storage class of function[7]";
+        } break;
+        default:
           break;
+      }
+      if (_.getLayoutStage() == kLayoutFunctionDeclarations) {
+        switch (opcode) {
+          case SpvOpFunction:
+            if (_.in_function_body()) {
+              return _.diag(SPV_ERROR_INVALID_LAYOUT)
+                     << "Cannot declare a function in a function body";
+            }
+            spvCheckReturn(_.get_functions().RegisterFunction(
+                inst->result_id, inst->type_id,
+                inst->words[inst->operands[2].offset],
+                inst->words[inst->operands[3].offset]));
+            break;
+          case SpvOpFunctionParameter:
+            if (_.in_function_body() == false) {
+              return _.diag(SPV_ERROR_INVALID_LAYOUT) << "Function parameter "
+                                                         "instructions must be "
+                                                         "in a function body";
+            }
+            spvCheckReturn(_.get_functions().RegisterFunctionParameter(
+                inst->result_id, inst->type_id));
+            break;
+          case SpvOpLine:  // ??
+            break;
+          case SpvOpLabel:
+            if (_.in_function_body() == false) {
+              return _.diag(SPV_ERROR_INVALID_LAYOUT)
+                     << "Label instructions must be in a function body";
+            }
+            _.progressToNextLayoutStageOrder();
+            spvCheckReturn(_.get_functions().RegisterSetFunctionDeclType(
+                FunctionDecl::kFunctionDeclDefinition));
+            break;
+          case SpvOpFunctionEnd:
+            assert(_.get_functions().get_block_count() ==
+                       0  // NOTE: This should not happen
+                   &&
+                   "Function contains blocks in function declaration section.");
+            if (_.in_function_body() == false) {
+              return _.diag(SPV_ERROR_INVALID_LAYOUT)
+                     << "Function end instructions must be in a function body";
+            }
+            spvCheckReturn(_.get_functions().RegisterSetFunctionDeclType(
+                FunctionDecl::kFunctionDeclDeclaration));
+            spvCheckReturn(_.get_functions().RegisterFunctionEnd());
+            break;
+          default:
+            return _.diag(SPV_ERROR_INVALID_LAYOUT)
+                   << "A function must begin with a label";
+            break;
         }
-        default:
-          return SPV_SUCCESS;
       }
+      // NOTE: Function definitions are handled by the CFGPass
     }
   }
   return SPV_SUCCESS;
 }
 
-// Shame
-#define CHECK_RESULT(EXPRESSION)                \
-  do{                                           \
-    spv_result_t ret = EXPRESSION;              \
-    if(ret) return ret;                         \
-} while(false);
+// TODO(umar): Support for merge instructions
+// TODO(umar): Structured control flow checks
+spv_result_t CfgPass(ValidationState_t& _,
+                     const spv_parsed_instruction_t* inst) {
+  if (_.getLayoutStage() == kLayoutFunctionDefinitions) {
+    SpvOp opcode = inst->opcode;
+    switch (opcode) {
+      case SpvOpFunction:
+        spvCheckReturn(_.get_functions().RegisterFunction(
+            inst->result_id, inst->type_id,
+            inst->words[inst->operands[2].offset],
+            inst->words[inst->operands[3].offset]));
+        spvCheckReturn(_.get_functions().RegisterSetFunctionDeclType(
+            FunctionDecl::kFunctionDeclDefinition));
+        break;
+      case SpvOpFunctionParameter:
+        spvCheckReturn(_.get_functions().RegisterFunctionParameter(
+            inst->result_id, inst->type_id));
+        break;
+      case SpvOpFunctionEnd:
+        if (_.get_functions().get_block_count() == 0)
+          return _.diag(SPV_ERROR_INVALID_LAYOUT) << "Function declarations "
+                                                     "must appear before "
+                                                     "function definitions.";
+        spvCheckReturn(_.get_functions().RegisterFunctionEnd());
+        break;
+      case SpvOpLabel:
+        spvCheckReturn(_.get_functions().RegisterBlock(inst->result_id));
+        break;
+      case SpvOpBranch:
+      case SpvOpBranchConditional:
+      case SpvOpSwitch:
+      case SpvOpKill:
+      case SpvOpReturn:
+      case SpvOpReturnValue:
+      case SpvOpUnreachable:
+        spvCheckReturn(_.get_functions().RegisterBlockEnd());
+        break;
+      default:
+        if (_.in_block() == false) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT) << spvOpcodeString(opcode)
+                                                  << " must appear in a block";
+        }
+        break;
+    }
+  }
+  return SPV_SUCCESS;
+}
 
 spv_result_t ProcessInstructions(void* user_data,
                                  const spv_parsed_instruction_t* inst) {
@@ -549,16 +659,16 @@ spv_result_t ProcessInstructions(void* user_data,
 
   DebugInstructionPass(_, inst);
 
-  // TODO(umar): Perform CFG pass
   // TODO(umar): Perform data rules pass
   // TODO(umar): Perform instruction validation pass
-  CHECK_RESULT(ModuleLayoutPass(_, inst))
-  CHECK_RESULT(SsaPass(_, can_have_forward_declared_ids, inst))
+  spvCheckReturn(ModuleLayoutPass(_, inst));
+  spvCheckReturn(CfgPass(_, inst));
+  spvCheckReturn(SsaPass(_, can_have_forward_declared_ids, inst));
 
   return SPV_SUCCESS;
 }
 
-} // anonymous namespace
+}  // anonymous namespace
 
 spv_result_t spvValidate(const spv_const_context context,
                          const spv_const_binary binary, const uint32_t options,
@@ -581,12 +691,9 @@ spv_result_t spvValidate(const spv_const_context context,
   // NOTE: Parse the module and perform inline validation checks. These
   // checks do not require the the knowledge of the whole module.
   ValidationState_t vstate(pDiagnostic, options);
-  auto err = spvBinaryParse(context, &vstate, binary->code, binary->wordCount,
-                            setHeader, ProcessInstructions, pDiagnostic);
-
-  if (err) {
-    return err;
-  }
+  spvCheckReturn(spvBinaryParse(context, &vstate, binary->code,
+                                binary->wordCount, setHeader,
+                                ProcessInstructions, pDiagnostic));
 
   // TODO(umar): Add validation checks which require the parsing of the entire
   // module. Use the information from the processInstructions pass to make
index 51655fe..1f130c5 100644 (file)
@@ -28,6 +28,7 @@
 #include "validate_types.h"
 
 #include <algorithm>
+#include <cassert>
 #include <map>
 #include <string>
 #include <unordered_set>
@@ -115,16 +116,15 @@ const vector<vector<SpvOp>>& GetModuleOrder() {
 
 namespace libspirv {
 
-ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic,
-                                     uint32_t options)
-    : diagnostic_(diagnostic),
+ValidationState_t::ValidationState_t(spv_diagnostic* diag, uint32_t options)
+    : diagnostic_(diag),
       instruction_counter_(0),
       defined_ids_{},
       unresolved_forward_ids_{},
       validation_flags_(options),
       operand_names_{},
-      module_layout_order_stage_(0),
-      current_layout_stage_(ModuleLayoutSection::kModule) {}
+      current_layout_stage_(kLayoutCapabilities),
+      module_functions_(*this) {}
 
 spv_result_t ValidationState_t::defineId(uint32_t id) {
   if (defined_ids_.find(id) == end(defined_ids_)) {
@@ -186,23 +186,134 @@ ModuleLayoutSection ValidationState_t::getLayoutStage() const {
 }
 
 void ValidationState_t::progressToNextLayoutStageOrder() {
-  module_layout_order_stage_ +=
-      module_layout_order_stage_ < GetModuleOrder().size();
-  if (module_layout_order_stage_ >= GetModuleOrder().size()) {
-    current_layout_stage_ = libspirv::ModuleLayoutSection::kFunction;
+  if (current_layout_stage_ <= GetModuleOrder().size()) {
+    current_layout_stage_ =
+        static_cast<ModuleLayoutSection>(current_layout_stage_ + 1);
   }
 }
 
 bool ValidationState_t::isOpcodeInCurrentLayoutStage(SpvOp op) {
-  const vector<SpvOp>& currentStage =
-      GetModuleOrder()[module_layout_order_stage_];
+  const vector<SpvOp>& currentStage = GetModuleOrder()[current_layout_stage_];
   return end(currentStage) != find(begin(currentStage), end(currentStage), op);
 }
 
-libspirv::DiagnosticStream ValidationState_t::diag(
-    spv_result_t error_code) const {
+DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const {
   return libspirv::DiagnosticStream(
       {0, 0, static_cast<size_t>(instruction_counter_)}, diagnostic_,
       error_code);
 }
+
+Functions& ValidationState_t::get_functions() { return module_functions_; }
+
+bool ValidationState_t::in_function_body() const {
+  return module_functions_.in_function_body();
+}
+
+bool ValidationState_t::in_block() const {
+  return module_functions_.in_block();
+}
+
+Functions::Functions(ValidationState_t& module)
+    : module_(module), in_function_(false), in_block_(false) {}
+
+bool Functions::in_function_body() const { return in_function_; }
+
+bool Functions::in_block() const { return in_block_; }
+
+spv_result_t Functions::RegisterFunction(uint32_t id, uint32_t ret_type_id,
+                                         uint32_t function_control,
+                                         uint32_t function_type_id) {
+  assert(in_function_ == false &&
+         "Function instructions can not be declared in a function");
+  in_function_ = true;
+  id_.emplace_back(id);
+  type_id_.emplace_back(function_type_id);
+  declaration_type_.emplace_back(FunctionDecl::kFunctionDeclUnknown);
+  block_ids_.emplace_back();
+  variable_ids_.emplace_back();
+  parameter_ids_.emplace_back();
+
+  // TODO(umar): validate function type and type_id
+  (void)ret_type_id;
+  (void)function_control;
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t Functions::RegisterFunctionParameter(uint32_t id,
+                                                  uint32_t type_id) {
+  assert(in_function_ == true &&
+         "Function parameter instructions cannot be declared outside of a function");
+  if (in_block()) {
+    return module_.diag(SPV_ERROR_INVALID_LAYOUT)
+           << "Function parameters cannot be called in blocks";
+  }
+  if (block_ids_.back().size() != 0) {
+    return module_.diag(SPV_ERROR_INVALID_LAYOUT)
+           << "Function parameters must only appear immediatly after the "
+              "function definition";
+  }
+  // TODO(umar): Validate function parameter type order and count
+  // TODO(umar): Use these variables to validate parameter type
+  (void)id;
+  (void)type_id;
+  return SPV_SUCCESS;
+}
+
+spv_result_t Functions::RegisterSetFunctionDeclType(FunctionDecl type) {
+  assert(in_function_ == true && "Function can not be declared inside of another function");
+  if (declaration_type_.size() <= 1 || type == *(end(declaration_type_) - 2) ||
+      type == FunctionDecl::kFunctionDeclDeclaration) {
+    declaration_type_.back() = type;
+  } else if (type == FunctionDecl::kFunctionDeclDeclaration) {
+    return module_.diag(SPV_ERROR_INVALID_LAYOUT)
+           << "Function declartions must appear before function definitions";
+  } else {
+    declaration_type_.back() = type;
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t Functions::RegisterBlock(uint32_t id) {
+  assert(in_function_ == true && "Labels can only exsist in functions");
+  if (module_.getLayoutStage() ==
+      ModuleLayoutSection::kLayoutFunctionDeclarations) {
+    return module_.diag(SPV_ERROR_INVALID_LAYOUT)
+           << "Function declartions must appear before function definitions";
+  }
+  if (declaration_type_.back() != FunctionDecl::kFunctionDeclDefinition) {
+    // NOTE: This should not happen. We should know that this function is a
+    // definition at this point.
+    return module_.diag(SPV_ERROR_INTERNAL)
+           << "Function declaration type should have already been defined";
+  }
+
+  block_ids_.back().push_back(id);
+  in_block_ = true;
+  return SPV_SUCCESS;
+}
+
+spv_result_t Functions::RegisterFunctionEnd() {
+  assert(in_function_ == true &&
+         "Function end can only be called in functions");
+  if (in_block()) {
+    return module_.diag(SPV_ERROR_INVALID_LAYOUT)
+           << "Function end cannot be called in blocks";
+  }
+  in_function_ = false;
+  return SPV_SUCCESS;
+}
+
+spv_result_t Functions::RegisterBlockEnd() {
+  assert(in_block_ == true &&
+         "Branch instruction can only be called in a block");
+  in_block_ = false;
+  return SPV_SUCCESS;
+}
+
+size_t Functions::get_block_count() {
+  assert(in_function_ == true &&
+         "Branch instruction can only be called in a block");
+  return block_ids_.back().size();
+}
 }
index e57c58e..d11badd 100644 (file)
 
 namespace libspirv {
 
-// This enum represents the sections of a SPIRV module. The MODULE section
-// contains instructions who's scope spans the entire module. The FUNCTION
-// section includes SPIRV function and function definitions
-enum class ModuleLayoutSection {
-  kModule,    // < Module scope instructions are executed
-  kFunction,  // < Function scope instructions are executed
+// This enum represents the sections of a SPIRV module. See section 2.4
+// of the SPIRV spec for additional details of the order. The enumerant values
+// are in the same order as the vector returned by GetModuleOrder
+enum ModuleLayoutSection {
+  kLayoutCapabilities,          // < Section 2.4 #1
+  kLayoutExtensions,            // < Section 2.4 #2
+  kLayoutExtInstImport,         // < Section 2.4 #3
+  kLayoutMemoryModel,           // < Section 2.4 #4
+  kLayoutEntryPoint,            // < Section 2.4 #5
+  kLayoutExecutionMode,         // < Section 2.4 #6
+  kLayoutDebug1,                // < Section 2.4 #7 > 1
+  kLayoutDebug2,                // < Section 2.4 #7 > 2
+  kLayoutAnnotations,           // < Section 2.4 #8
+  kLayoutTypes,                 // < Section 2.4 #9
+  kLayoutFunctionDeclarations,  // < Section 2.4 #10
+  kLayoutFunctionDefinitions    // < Section 2.4 #11
+};
+
+enum class FunctionDecl {
+  kFunctionDeclUnknown,      // < Unknown function declaration
+  kFunctionDeclDeclaration,  // < Function declaration
+  kFunctionDeclDefinition    // < Function definition
+};
+
+class ValidationState_t;
+
+// This class manages all function declaration and definitions in a module. It
+// handles the state and id information while parsing a function in the SPIR-V
+// binary.
+//
+// NOTE: This class is designed to be a Structure of Arrays. Therefore each
+// member variable is a vector whose elements represent the values for the
+// corresponding function in a SPIR-V module. Variables that are not vector
+// types are used to manage the state while parsing the function.
+class Functions {
+ public:
+  Functions(ValidationState_t& module);
+
+  // Registers the function in the module. Subsequent instructions will be
+  // called against this function
+  spv_result_t RegisterFunction(uint32_t id, uint32_t ret_type_id,
+                                uint32_t function_control,
+                                uint32_t function_type_id);
+
+  // Registers a function parameter in the current function
+  spv_result_t RegisterFunctionParameter(uint32_t id, uint32_t type_id);
+
+  // Register a function end instruction
+  spv_result_t RegisterFunctionEnd();
+
+  // Sets the declaration type of the current function
+  spv_result_t RegisterSetFunctionDeclType(FunctionDecl type);
+
+  // Registers a block in the current function. Subsequent block instructions
+  // will target this block
+  // @param id The ID of the label of the block
+  spv_result_t RegisterBlock(uint32_t id);
+
+  // Registers a variable in the current block
+  spv_result_t RegisterBlockVariable(uint32_t type_id, uint32_t id,
+                                     SpvStorageClass storage, uint32_t init_id);
+
+  spv_result_t RegisterBlockLoopMerge(uint32_t merge_id, uint32_t continue_id,
+                                      SpvLoopControlMask control);
+
+  spv_result_t RegisterBlockSelectionMerge(uint32_t merge_id,
+                                           SpvSelectionControlMask control);
+
+  // Registers the end of the block
+  spv_result_t RegisterBlockEnd();
+
+  // Returns the number of blocks in the current function being parsed
+  size_t get_block_count();
+
+  // Retuns true if called after a function instruction but before the
+  // function end instruction
+  bool in_function_body() const;
+
+  // Returns true if called after a label instruction but before a branch
+  // instruction
+  bool in_block() const;
+
+  libspirv::DiagnosticStream diag(spv_result_t error_code) const;
+
+ private:
+  // Parent module
+  ValidationState_t& module_;
+
+  // Funciton IDs in a module
+  std::vector<uint32_t> id_;
+
+  // OpTypeFunction IDs of each of the id_ functions
+  std::vector<uint32_t> type_id_;
+
+  // The type of declaration of each function
+  std::vector<FunctionDecl> declaration_type_;
+
+  // TODO(umar): Probably needs better abstractions
+  // The beginning of the block of functions
+  std::vector<std::vector<uint32_t>> block_ids_;
+
+  // The variable IDs of the functions
+  std::vector<std::vector<uint32_t>> variable_ids_;
+
+  // The function parameter ids of the functions
+  std::vector<std::vector<uint32_t>> parameter_ids_;
+
+  // NOTE: See correspoding getter functions
+  bool in_function_;
+  bool in_block_;
 };
 
 class ValidationState_t {
@@ -94,6 +198,17 @@ class ValidationState_t {
 
   libspirv::DiagnosticStream diag(spv_result_t error_code) const;
 
+  // Returns the function states
+  Functions& get_functions();
+
+  // Retuns true if the called after a function instruction but before the
+  // function end instruction
+  bool in_function_body() const;
+
+  // Returns true if called after a label instruction but before a branch
+  // instruction
+  bool in_block() const;
+
  private:
   spv_diagnostic* diagnostic_;
   // Tracks the number of instructions evaluated by the validator
@@ -110,12 +225,12 @@ class ValidationState_t {
 
   std::map<uint32_t, std::string> operand_names_;
 
-  // The stage which is being processed by the validation. Partially based on
-  // Section 2.4. Logical Layout of a Module
-  uint32_t module_layout_order_stage_;
-
   // The section of the code being processed
   ModuleLayoutSection current_layout_stage_;
+
+  Functions module_functions_;
+
+  std::vector<SpvCapability> module_capabilities_;
 };
 }
 
index e902dd3..39f05d1 100644 (file)
@@ -27,6 +27,7 @@
 // Validation tests for Logical Layout
 
 #include "gmock/gmock.h"
+#include "source/diagnostic.h"
 #include "UnitSPIRV.h"
 #include "ValidateFixtures.h"
 
@@ -36,6 +37,7 @@
 #include <utility>
 
 using std::function;
+using std::ostream;
 using std::ostream_iterator;
 using std::pair;
 using std::stringstream;
@@ -46,30 +48,37 @@ using std::vector;
 
 using ::testing::HasSubstr;
 
-using pred_type = function<bool(int)>;
+using libspirv::spvResultToString;
+
+using pred_type = function<spv_result_t(int)>;
 using ValidateLayout =
     spvtest::ValidateBase<tuple<int, tuple<string, pred_type, pred_type>>,
                           SPV_VALIDATE_LAYOUT_BIT>;
-
 namespace {
 
 // returns true if order is equal to VAL
-template <int VAL>
-bool Equals(int order) {
-  return order == VAL;
+template <int VAL, spv_result_t RET = SPV_ERROR_INVALID_LAYOUT>
+spv_result_t Equals(int order) {
+  return order == VAL ? SPV_SUCCESS : RET;
 }
 
 // returns true if order is between MIN and MAX(inclusive)
-template <int MIN, int MAX>
+template <int MIN, int MAX, spv_result_t RET = SPV_ERROR_INVALID_LAYOUT>
 struct Range {
-  bool operator()(int order) { return order >= MIN && order <= MAX; }
+  Range(bool inverse = false) : inverse_(inverse) {}
+  spv_result_t operator()(int order) {
+    return (inverse_ ^ (order >= MIN && order <= MAX)) ? SPV_SUCCESS : RET;
+  }
+
+ private:
+  bool inverse_;
 };
 
 template <typename... T>
-bool RangeSet(int order) {
-  for (bool val : {T()(order)...})
-    if (!val) return val;
-  return false;
+spv_result_t InvalidSet(int order) {
+  for (spv_result_t val : {T(true)(order)...})
+    if (val != SPV_SUCCESS) return val;
+  return SPV_SUCCESS;
 }
 
 // SPIRV source used to test the logical layout
@@ -92,41 +101,42 @@ const vector<string>& getInstructions() {
     "OpMemberDecorate %struct 1 RowMajor",
     "%dgrp   = OpDecorationGroup",
     "OpGroupDecorate %dgrp %mat33 %mat44",
-    "%intt    OpTypeInt 32 1",
-    "%floatt  OpTypeFloat 32",
-    "%voidt   OpTypeVoid",
-    "%boolt   OpTypeBool",
-    "%vec4    OpTypeVector %intt 4",
-    "%vec3    OpTypeVector %intt 3",
-    "%mat33   OpTypeMatrix %vec3 3",
-    "%mat44   OpTypeMatrix %vec4 4",
-    "%struct  OpTypeStruct %intt %mat33",
-    "%vfunct = OpTypeFunction %voidt",
-    "%viifunct =  OpTypeFunction %voidt %intt %intt",
-    "%one      =  OpConstant %intt 1",
+    "%intt     = OpTypeInt 32 1",
+    "%floatt   = OpTypeFloat 32",
+    "%voidt    = OpTypeVoid",
+    "%boolt    = OpTypeBool",
+    "%vec4     = OpTypeVector %intt 4",
+    "%vec3     = OpTypeVector %intt 3",
+    "%mat33    = OpTypeMatrix %vec3 3",
+    "%mat44    = OpTypeMatrix %vec4 4",
+    "%struct   = OpTypeStruct %intt %mat33",
+    "%vfunct   = OpTypeFunction %voidt",
+    "%viifunct = OpTypeFunction %voidt %intt %intt",
+    "%one      = OpConstant %intt 1",
     // TODO(umar): OpConstant fails because the type is not defined
     // TODO(umar): OpGroupMemberDecorate
     "OpLine %str 3 4",
-    "%func   = OpFunction %voidt None %vfunct",
+    "%func     = OpFunction %voidt None %vfunct",
     "OpFunctionEnd",
-    "%func2   = OpFunction %voidt None %viifunct",
-    "%funcp1 = OpFunctionParameter %intt",
-    "%funcp2 = OpFunctionParameter %intt",
-    "%fLabel = OpLabel",
-    "          OpNop",
-    "OpReturn",
+    "%func2    = OpFunction %voidt None %viifunct",
+    "%funcp1   = OpFunctionParameter %intt",
+    "%funcp2   = OpFunctionParameter %intt",
+    "%fLabel   = OpLabel",
+    "            OpNop",
+    "            OpReturn",
     "OpFunctionEnd"
   };
   return instructions;
 }
 
-pred_type All = Range<0, 1000>();
+static const int kRangeEnd = 1000;
+pred_type All = Range<0, kRangeEnd>();
 
 INSTANTIATE_TEST_CASE_P(InstructionsOrder,
     ValidateLayout,
     ::testing::Combine(::testing::Range((int)0, (int)getInstructions().size()),
     //                                   | Instruction              | Line(s) valid     | Lines to compile
-    ::testing::Values( make_tuple( string("OpCapability")           , Equals<0>         , All)
+    ::testing::Values( make_tuple(string("OpCapability")            , Equals<0>         , All)
                      , make_tuple(string("OpExtension")             , Equals<1>         , All)
                      , make_tuple(string("OpExtInstImport")         , Equals<2>         , All)
                      , make_tuple(string("OpMemoryModel")           , Equals<3>         , All)
@@ -149,9 +159,11 @@ INSTANTIATE_TEST_CASE_P(InstructionsOrder,
                      , make_tuple(string("OpTypeVector %intt 4")    , Range<16, 28>()   , All)
                      , make_tuple(string("OpTypeMatrix %vec4 4")    , Range<16, 28>()   , All)
                      , make_tuple(string("OpTypeStruct")            , Range<16, 28>()   , All)
-                     , make_tuple(string("%vfunct = OpTypeFunction"), Range<16, 28>()   , All)
-                     , make_tuple(string("OpConstant")              , Range<19, 28>()   , static_cast<pred_type>(Range<19, 100>()))
-                   //, make_tuple(string("OpLabel")                 , RangeSet<Range<29,31>, Range<35, 36>, >   , All)
+                     , make_tuple(string("%vfunct   = OpTypeFunction"), Range<16, 28>()   , All)
+                     , make_tuple(string("OpConstant")              , Range<19, 28>()   , static_cast<pred_type>(Range<19, kRangeEnd>()))
+                     , make_tuple(string("OpLabel")                 , Equals<34>        , All)
+                     , make_tuple(string("OpNop")                   , Equals<35>        , All)
+                     , make_tuple(string("OpReturn")                , Equals<36>        , All)
     )));
 // clang-format on
 
@@ -188,7 +200,7 @@ TEST_P(ValidateLayout, Layout) {
   tie(instruction, pred, test_pred) = testCase;
 
   // Skip test which break the code generation
-  if (!test_pred(order)) return;
+  if (test_pred(order)) return;
 
   vector<string> code = GenerateCode(instruction, order);
 
@@ -197,16 +209,18 @@ TEST_P(ValidateLayout, Layout) {
 
   // printf("code: \n%s\n", ss.str().c_str());
   CompileSuccessfully(ss.str());
-  if (pred(order)) {
-    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions())
-        << "Order: " << order << "\nInstruction: " << instruction;
-  } else {
-    ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions())
-        << "Order: " << order << "\nInstruction: " << instruction;
-  }
+  spv_result_t result;
+  // clang-format off
+  ASSERT_EQ(pred(order), result = ValidateInstructions())
+    << "Actual: "        << spvResultToString(result)
+    << "\nExpected: "    << spvResultToString(pred(order))
+    << "\nOrder: "       << order
+    << "\nInstruction: " << instruction
+    << "\nCode: \n"      << ss.str();
+  // clang-format on
 }
 
-TEST_F(ValidateLayout, DISABLED_MemoryModelMissing) {
+TEST_F(ValidateLayout, MemoryModelMissing) {
   string str = R"(
     OpCapability Matrix
     OpExtension "TestExtension"
@@ -219,6 +233,114 @@ TEST_F(ValidateLayout, DISABLED_MemoryModelMissing) {
   ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions());
 }
 
+TEST_F(ValidateLayout, VariableFunctionStorageGood) {
+  char str[] = R"(
+          OpMemoryModel Logical GLSL450
+          OpDecorate %var Restrict
+%intt   = OpTypeInt 32 1
+%voidt  = OpTypeVoid
+%vfunct = OpTypeFunction %voidt
+%ptrt   = OpTypePointer Function %intt
+%func   = OpFunction %voidt None %vfunct
+%funcl  = OpLabel
+%var    = OpVariable %ptrt Function
+          OpReturn
+          OpFunctionEnd
+)";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateLayout, VariableFunctionStorageBad) {
+  char str[] = R"(
+          OpMemoryModel Logical GLSL450
+          OpDecorate %var Restrict
+%intt   = OpTypeInt 32 1
+%voidt  = OpTypeVoid
+%vfunct = OpTypeFunction %voidt
+%ptrt   = OpTypePointer Function %intt
+%var    = OpVariable %ptrt Function     ; Invalid storage class for OpVariable
+%func   = OpFunction %voidt None %vfunct
+%funcl  = OpLabel
+          OpNop
+          OpReturn
+          OpFunctionEnd
+)";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions());
+}
+
+TEST_F(ValidateLayout, FunctionDefinitionBeforeDeclarationBad) {
+  char str[] = R"(
+           OpMemoryModel Logical GLSL450
+           OpDecorate %var Restrict
+%intt    = OpTypeInt 32 1
+%voidt   = OpTypeVoid
+%vfunct  = OpTypeFunction %voidt
+%vifunct = OpTypeFunction %voidt %intt
+%ptrt    = OpTypePointer Function %intt
+%func    = OpFunction %voidt None %vfunct
+%funcl   = OpLabel
+           OpNop
+           OpReturn
+           OpFunctionEnd
+%func2   = OpFunction %voidt None %vifunct ; must appear before definition
+%func2p  = OpFunctionParameter %intt
+           OpFunctionEnd
+)";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions());
+}
+
+// TODO(umar): Passes but gives incorrect error message. Should be fixed after
+// type checking
+TEST_F(ValidateLayout, LabelBeforeFunctionParameterBad) {
+  char str[] = R"(
+           OpMemoryModel Logical GLSL450
+           OpDecorate %var Restrict
+%intt    = OpTypeInt 32 1
+%voidt   = OpTypeVoid
+%vfunct  = OpTypeFunction %voidt
+%vifunct = OpTypeFunction %voidt %intt
+%ptrt    = OpTypePointer Function %intt
+%func    = OpFunction %voidt None %vifunct
+%funcl   = OpLabel                    ; Label appears before function parameter
+%func2p  = OpFunctionParameter %intt
+           OpNop
+           OpReturn
+           OpFunctionEnd
+)";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions());
+}
+
+TEST_F(ValidateLayout, FuncParameterNotImmediatlyAfterFuncBad) {
+  char str[] = R"(
+           OpMemoryModel Logical GLSL450
+           OpDecorate %var Restrict
+%intt    = OpTypeInt 32 1
+%voidt   = OpTypeVoid
+%vfunct  = OpTypeFunction %voidt
+%vifunct = OpTypeFunction %voidt %intt
+%ptrt    = OpTypePointer Function %intt
+%func    = OpFunction %voidt None %vifunct
+%funcl   = OpLabel
+           OpNop
+           OpBranch %next
+%func2p  = OpFunctionParameter %intt        ;FunctionParameter appears in a function but not immediatly afterwards
+%next    = OpLabel
+           OpNop
+           OpReturn
+           OpFunctionEnd
+)";
+
+  CompileSuccessfully(str);
+  ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions());
+}
+
 // TODO(umar): Test optional instructions
-// TODO(umar): Test logical layout of functions
 }
index 5ca35df..1eaf148 100644 (file)
@@ -332,6 +332,7 @@ TEST_F(Validate, FunctionCallGood) {
 %5    =  OpFunction %1 None %4
 %6    =  OpLabel
 %7    =  OpFunctionCall %1 %9 %four %five
+         OpReturn
          OpFunctionEnd
 )";
   CompileSuccessfully(str);
@@ -351,6 +352,7 @@ TEST_F(Validate, ForwardFunctionCallGood) {
 %5    =  OpFunction %1 None %4
 %6    =  OpLabel
 %7    =  OpFunctionCall %1 %9 %four %five
+         OpReturn
          OpFunctionEnd
 %9    =  OpFunction %1 None %8
 %10   =  OpFunctionParameter %2
@@ -770,6 +772,7 @@ TEST_P(Validate, ForwardGetKernelGood) {
   string str = kHeader + kBasicTypes + kKernelTypesAndConstants +
                R"(
             %main    = OpFunction %voidt None %vfunct
+            %mainl   = OpLabel
                 )"
             + kKernelSetup + " %numsg = "
             + instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign"
index 3d38083..45735a2 100644 (file)
@@ -61,9 +61,10 @@ void ValidateBase<T, OPTIONS>::TearDown() {
 template <typename T, uint32_t OPTIONS>
 void ValidateBase<T, OPTIONS>::CompileSuccessfully(std::string code) {
   spv_diagnostic diagnostic = nullptr;
-  EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context_, code.c_str(), code.size(),
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context_, code.c_str(), code.size(),
                                          &binary_, &diagnostic))
-      << "SPIR-V could not be compiled into binary:" << code;
+      << "ERROR: " << diagnostic->error
+      << "\nSPIR-V could not be compiled into binary:\n" << code;
 }
 
 template <typename T, uint32_t OPTIONS>
@@ -77,12 +78,17 @@ std::string ValidateBase<T, OPTIONS>::getDiagnosticString() {
   return std::string(diagnostic_->error);
 }
 
+template <typename T, uint32_t OPTIONS>
+spv_position_t ValidateBase<T, OPTIONS>::getErrorPosition() {
+  return diagnostic_->position;
+}
+
 template class spvtest::ValidateBase<std::pair<std::string, bool>,
                                      SPV_VALIDATE_SSA_BIT |
                                          SPV_VALIDATE_LAYOUT_BIT>;
 template class spvtest::ValidateBase<bool, SPV_VALIDATE_SSA_BIT>;
 template class spvtest::ValidateBase<
-    std::tuple<int, std::tuple<std::string, std::function<bool(int)>,
-                               std::function<bool(int)>>>,
+    std::tuple<int, std::tuple<std::string, std::function<spv_result_t(int)>,
+                               std::function<spv_result_t(int)>>>,
     SPV_VALIDATE_LAYOUT_BIT>;
 }
index 21dc483..c58b068 100644 (file)
@@ -52,6 +52,7 @@ class ValidateBase : public ::testing::Test,
   spv_result_t ValidateInstructions();
 
   std::string getDiagnosticString();
+  spv_position_t getErrorPosition();
 
   spv_context context_;
   spv_binary binary_;