Introducing a new flow for running the Validator.
authorEhsan Nasiri <ehsann@google.com>
Wed, 11 Jan 2017 20:03:53 +0000 (15:03 -0500)
committerDavid Neto <dneto@google.com>
Fri, 13 Jan 2017 21:07:03 +0000 (16:07 -0500)
We are adding a new API which can be called to run the SPIR-V validator,
and retrieve the ValidationState_t object. This is very useful for
unit testing.

I have also added basic unit tests that demonstrate usage of this flow
and ease of use to verify correctness.

source/validate.cpp
source/validate.h
test/val/CMakeLists.txt
test/val/val_fixtures.cpp
test/val/val_fixtures.h
test/val/val_validation_state_test.cpp [new file with mode: 0644]

index 45ae5a4..fe6c4c4 100644 (file)
@@ -182,64 +182,58 @@ spv_result_t spvValidate(const spv_const_context context,
   return spvValidateBinary(context, binary->code, binary->wordCount,
                            pDiagnostic);
 }
-spv_result_t spvValidateBinary(const spv_const_context context,
-                               const uint32_t* words, const size_t num_words,
-                               spv_diagnostic* pDiagnostic) {
-  spv_context_t hijack_context = *context;
 
+spv_result_t ValidateBinaryUsingContextAndValidationState(
+    const spv_context_t& context, const uint32_t* words, const size_t num_words,
+    spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
   spv_const_binary binary = new spv_const_binary_t{words, num_words};
-  if (pDiagnostic) {
-    *pDiagnostic = nullptr;
-    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
-  }
 
   spv_endianness_t endian;
   spv_position_t position = {};
   if (spvBinaryEndianness(binary, &endian)) {
-    return libspirv::DiagnosticStream(position, hijack_context.consumer,
+    return libspirv::DiagnosticStream(position, context.consumer,
                                       SPV_ERROR_INVALID_BINARY)
            << "Invalid SPIR-V magic number.";
   }
 
   spv_header_t header;
   if (spvBinaryHeaderGet(binary, endian, &header)) {
-    return libspirv::DiagnosticStream(position, hijack_context.consumer,
+    return libspirv::DiagnosticStream(position, context.consumer,
                                       SPV_ERROR_INVALID_BINARY)
            << "Invalid SPIR-V header.";
   }
 
   // NOTE: Parse the module and perform inline validation checks. These
   // checks do not require the the knowledge of the whole module.
-  ValidationState_t vstate(&hijack_context);
-  if (auto error = spvBinaryParse(&hijack_context, &vstate, words, num_words,
+  if (auto error = spvBinaryParse(&context, vstate, words, num_words,
                                   setHeader, ProcessInstruction, pDiagnostic))
     return error;
 
-  if (vstate.in_function_body())
-    return vstate.diag(SPV_ERROR_INVALID_LAYOUT)
+  if (vstate->in_function_body())
+    return vstate->diag(SPV_ERROR_INVALID_LAYOUT)
            << "Missing OpFunctionEnd at end of module.";
 
   // TODO(umar): Add validation checks which require the parsing of the entire
   // module. Use the information from the ProcessInstruction pass to make the
   // checks.
-  if (vstate.unresolved_forward_id_count() > 0) {
+  if (vstate->unresolved_forward_id_count() > 0) {
     stringstream ss;
-    vector<uint32_t> ids = vstate.UnresolvedForwardIds();
+    vector<uint32_t> ids = vstate->UnresolvedForwardIds();
 
     transform(begin(ids), end(ids), ostream_iterator<string>(ss, " "),
-              bind(&ValidationState_t::getIdName, std::ref(vstate), _1));
+              bind(&ValidationState_t::getIdName, std::ref(*vstate), _1));
 
     auto id_str = ss.str();
-    return vstate.diag(SPV_ERROR_INVALID_ID)
+    return vstate->diag(SPV_ERROR_INVALID_ID)
            << "The following forward referenced IDs have not been defined:\n"
            << id_str.substr(0, id_str.size() - 1);
   }
 
   // CFG checks are performed after the binary has been parsed
   // and the CFGPass has collected information about the control flow
-  if (auto error = PerformCfgChecks(vstate)) return error;
-  if (auto error = UpdateIdUse(vstate)) return error;
-  if (auto error = CheckIdDefinitionDominateUse(vstate)) return error;
+  if (auto error = PerformCfgChecks(*vstate)) return error;
+  if (auto error = UpdateIdUse(*vstate)) return error;
+  if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error;
 
   // NOTE: Copy each instruction for easier processing
   std::vector<spv_instruction_t> instructions;
@@ -258,7 +252,39 @@ spv_result_t spvValidateBinary(const spv_const_context context,
 
   position.index = SPV_INDEX_INSTRUCTION;
   return spvValidateIDs(instructions.data(), instructions.size(),
-                        hijack_context.opcode_table,
-                        hijack_context.operand_table,
-                        hijack_context.ext_inst_table, vstate, &position);
+                        context.opcode_table,
+                        context.operand_table,
+                        context.ext_inst_table, *vstate, &position);
 }
+
+spv_result_t spvValidateBinary(const spv_const_context context,
+                               const uint32_t* words, const size_t num_words,
+                               spv_diagnostic* pDiagnostic) {
+  spv_context_t hijack_context = *context;
+  if (pDiagnostic) {
+    *pDiagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
+  }
+
+  // Create the ValidationState using the context.
+  ValidationState_t vstate(&hijack_context);
+
+  return ValidateBinaryUsingContextAndValidationState(
+      hijack_context, words, num_words, pDiagnostic, &vstate);
+}
+
+spv_result_t spvtools::ValidateBinaryAndKeepValidationState(
+    const spv_const_context context, const uint32_t* words,
+    const size_t num_words, spv_diagnostic* pDiagnostic,
+    std::unique_ptr<ValidationState_t>* vstate) {
+  spv_context_t hijack_context = *context;
+  if (pDiagnostic) {
+    *pDiagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
+  }
+
+  vstate->reset(new ValidationState_t(&hijack_context));
+  return ValidateBinaryUsingContextAndValidationState(
+      hijack_context, words, num_words, pDiagnostic, vstate->get());
+}
+
index 258b0eb..3237e6c 100644 (file)
@@ -190,4 +190,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions,
                             spv_position position,
                             const spvtools::MessageConsumer& consumer);
 
+namespace spvtools {
+// Performs validation for the SPIRV-V module binary.
+// The main difference between this API and spvValidateBinary is that the
+// "Validation State" is not destroyed upon function return; it lives on and is
+// pointed to by the vstate unique_ptr.
+spv_result_t ValidateBinaryAndKeepValidationState(
+    const spv_const_context context, const uint32_t* words,
+    const size_t num_words, spv_diagnostic* pDiagnostic,
+    std::unique_ptr<libspirv::ValidationState_t>* vstate);
+}  // namespace spvtools
+
 #endif  // LIBSPIRV_VALIDATE_H_
index 5ac1302..1e4becd 100644 (file)
@@ -75,3 +75,9 @@ add_spvtools_unittest(TARGET val_limits
   LIBS ${SPIRV_TOOLS}
 )
 
+add_spvtools_unittest(TARGET val_validation_state
+       SRCS val_validation_state_test.cpp
+       ${VAL_TEST_COMMON_SRCS}
+  LIBS ${SPIRV_TOOLS}
+)
+
index 2db99a8..b017433 100644 (file)
@@ -68,6 +68,14 @@ spv_result_t ValidateBase<T>::ValidateInstructions(spv_target_env env) {
 }
 
 template <typename T>
+spv_result_t ValidateBase<T>::ValidateAndRetrieveValidationState(
+    spv_target_env env) {
+  return spvtools::ValidateBinaryAndKeepValidationState(
+      ScopedContext(env).context, get_const_binary()->code,
+      get_const_binary()->wordCount, &diagnostic_, &vstate_);
+}
+
+template <typename T>
 std::string ValidateBase<T>::getDiagnosticString() {
   return std::string(diagnostic_->error);
 }
index bb4fe18..1d94705 100644 (file)
@@ -18,6 +18,7 @@
 #define LIBSPIRV_TEST_VALIDATE_FIXTURES_H_
 
 #include "unit_spirv.h"
+#include "source/val/validation_state.h"
 
 namespace spvtest {
 
@@ -45,11 +46,17 @@ class ValidateBase : public ::testing::Test,
   // spvValidate function
   spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0);
 
+  // Performs validation. Returns the status and stores validation state into
+  // the vstate_ member.
+  spv_result_t ValidateAndRetrieveValidationState(
+      spv_target_env env = SPV_ENV_UNIVERSAL_1_0);
+
   std::string getDiagnosticString();
   spv_position_t getErrorPosition();
 
   spv_binary binary_;
   spv_diagnostic diagnostic_;
+  std::unique_ptr<libspirv::ValidationState_t> vstate_;
 };
 }
 #endif
diff --git a/test/val/val_validation_state_test.cpp b/test/val/val_validation_state_test.cpp
new file mode 100644 (file)
index 0000000..5eb09f7
--- /dev/null
@@ -0,0 +1,106 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Basic tests for the ValidationState_t datastructure.
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "unit_spirv.h"
+#include "val_fixtures.h"
+
+namespace {
+
+using std::string;
+using ::testing::HasSubstr;
+
+using ValidationStateTest = spvtest::ValidateBase<bool>;
+
+const char header[] =
+    " OpCapability Shader"
+    " OpCapability Linkage"
+    " OpMemoryModel Logical GLSL450 ";
+
+const char kVoidFVoid[] =
+    " %void   = OpTypeVoid"
+    " %void_f = OpTypeFunction %void"
+    " %func   = OpFunction %void None %void_f"
+    " %label  = OpLabel"
+    "           OpReturn"
+    "           OpFunctionEnd ";
+
+// Tests that the instruction count in ValidationState is correct.
+TEST_F(ValidationStateTest, CheckNumInstructions) {
+  string spirv = string(header) + "%int = OpTypeInt 32 0";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
+}
+
+// Tests that the number of global variables in ValidationState is correct.
+TEST_F(ValidationStateTest, CheckNumGlobalVars) {
+  string spirv = string(header) + R"(
+     %int = OpTypeInt 32 0
+%_ptr_int = OpTypePointer Input %int
+   %var_1 = OpVariable %_ptr_int Input
+   %var_2 = OpVariable %_ptr_int Input
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_EQ(unsigned(2), vstate_->num_global_vars());
+}
+
+// Tests that the number of local variables in ValidationState is correct.
+TEST_F(ValidationStateTest, CheckNumLocalVars) {
+  string spirv = string(header) + R"(
+ %int      = OpTypeInt 32 0
+ %_ptr_int = OpTypePointer Function %int
+ %voidt    = OpTypeVoid
+ %funct    = OpTypeFunction %voidt
+ %main     = OpFunction %voidt None %funct
+ %entry    = OpLabel
+ %var_1    = OpVariable %_ptr_int Function
+ %var_2    = OpVariable %_ptr_int Function
+ %var_3    = OpVariable %_ptr_int Function
+ OpReturn
+ OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_EQ(unsigned(3), vstate_->num_local_vars());
+}
+
+// Tests that the "id bound" in ValidationState is correct.
+TEST_F(ValidationStateTest, CheckIdBound) {
+  string spirv = string(header) + R"(
+ %int      = OpTypeInt 32 0
+ %voidt    = OpTypeVoid
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_EQ(unsigned(3), vstate_->getIdBound());
+}
+
+// Tests that the entry_points in ValidationState is correct.
+TEST_F(ValidationStateTest, CheckEntryPoints) {
+  string spirv = string(header) + " OpEntryPoint Vertex %func \"shader\"" +
+                 string(kVoidFVoid);
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+  EXPECT_EQ(size_t(1), vstate_->entry_points().size());
+  EXPECT_EQ(SpvOpFunction,
+            vstate_->FindDef(vstate_->entry_points()[0])->opcode());
+}
+
+}  // anonymous namespace