Adding validation code for PtrAccessChain.
authorEhsan Nasiri <ehsann@google.com>
Thu, 22 Dec 2016 19:31:21 +0000 (14:31 -0500)
committerEhsan Nasiri <ehsann@google.com>
Tue, 3 Jan 2017 21:36:35 +0000 (16:36 -0500)
Validation for OpPtrAccessChain is similar to OpAccessChain with the
following difference: OpPtrAccessChain takes an extra argument (word 4)
which is the Element <id> argument.

Validation for OpInBoundsPtrAccessChain is also similar to OpPtrAccessChain.

Also added tests for all access chain instructions:
Modified the existing parameterized tests to accommodate OpPtrAccessChain and
OpInBoundsPtrAccessChain.

Also fixed a typo in previous commits.

source/validate_id.cpp
test/val/val_id_test.cpp

index c1523e0..1ab7327 100644 (file)
@@ -1251,7 +1251,7 @@ bool idUsage::isValid<SpvOpAccessChain>(const spv_instruction_t* inst,
     return false;
   }
   if (num_indexes <= 0) {
-    DIAG(resultTypeIndex) << "No Indexes were passes to " << instr_name << ".";
+    DIAG(resultTypeIndex) << "No Indexes were passed to " << instr_name << ".";
     return false;
   }
   // Indexes walk the type hierarchy to the desired depth, potentially down to
@@ -1344,6 +1344,28 @@ bool idUsage::isValid<SpvOpInBoundsAccessChain>(
   return isValid<SpvOpAccessChain>(inst, opcodeEntry);
 }
 
+template <>
+bool idUsage::isValid<SpvOpPtrAccessChain>(const spv_instruction_t* inst,
+                                           const spv_opcode_desc opcodeEntry) {
+  // OpPtrAccessChain's validation rules are similar to OpAccessChain, with one
+  // difference: word 4 must be id of an integer (Element <id>).
+  // The grammar guarantees that there are at least 5 words in the instruction
+  // (i.e. if there are fewer than 5 words, the SPIR-V code will not compile.)
+  int elem_index = 4;
+  // We can remove the Element <id> from the instruction words, and simply call
+  // the validation code of OpAccessChain.
+  spv_instruction_t new_inst = *inst;
+  new_inst.words.erase(new_inst.words.begin() + elem_index);
+  return isValid<SpvOpAccessChain>(&new_inst, opcodeEntry);
+}
+
+template <>
+bool idUsage::isValid<SpvOpInBoundsPtrAccessChain>(
+    const spv_instruction_t* inst, const spv_opcode_desc opcodeEntry) {
+  // Has the same validation rules as OpPtrAccessChain
+  return isValid<SpvOpPtrAccessChain>(inst, opcodeEntry);
+}
+
 #if 0
 template <>
 bool idUsage::isValid<SpvOpArrayLength>(const spv_instruction_t *inst,
@@ -2572,6 +2594,8 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
     CASE(OpCopyMemorySized)
     CASE(OpAccessChain)
     CASE(OpInBoundsAccessChain)
+    CASE(OpPtrAccessChain)
+    CASE(OpInBoundsPtrAccessChain)
     TODO(OpArrayLength)
     TODO(OpGenericPtrMemSemantics)
     CASE(OpFunction)
index 203b81d..bc26b33 100644 (file)
@@ -1895,12 +1895,19 @@ string opAccessChainSpirvSetup = R"(
 // OpInBoundsPtrAccessChain
 using AccessChainInstructionTest = spvtest::ValidateBase<std::string>;
 
+// Determines whether the access chain instruction requires the 'element id'
+// argument.
+bool AccessChainRequiresElemId(const std::string& instr) {
+  return (instr == "OpPtrAccessChain" || instr == "OpInBoundsPtrAccessChain");
+}
+
 // Valid: Access a float in a matrix using an access chain instruction.
 TEST_P(AccessChainInstructionTest, AccessChainGood) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup +
                  "%float_entry = " + instr +
-                 R"( %_ptr_Private_float %my_matrix %int_0 %int_1
+                 R"( %_ptr_Private_float %my_matrix )" + elem + R"(%int_0 %int_1
               OpReturn
               OpFunctionEnd
           )";
@@ -1911,9 +1918,10 @@ TEST_P(AccessChainInstructionTest, AccessChainGood) {
 // Invalid. The result type of an access chain instruction must be a pointer.
 TEST_P(AccessChainInstructionTest, AccessChainResultTypeBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %float_entry = )" +
-                 instr + R"( %float %my_matrix %int_0 %int_1
+                 instr + R"( %float %my_matrix )" + elem + R"(%int_0 %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -1929,9 +1937,10 @@ OpFunctionEnd
 // Invalid. The base type of an access chain instruction must be a pointer.
 TEST_P(AccessChainInstructionTest, AccessChainBaseTypeVoidBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %float_entry = )" +
-                 instr + R"( %_ptr_Private_float %void %int_0 %int_1
+                 instr + " %_ptr_Private_float %void " + elem + R"(%int_0 %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -1946,9 +1955,11 @@ OpFunctionEnd
 // Invalid. The base type of an access chain instruction must be a pointer.
 TEST_P(AccessChainInstructionTest, AccessChainBaseTypeNonPtrVariableBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Private_float %_ptr_Private_float %int_0 %int_1
+                 R"( %_ptr_Private_float %_ptr_Private_float )" + elem +
+                 R"(%int_0 %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -1964,9 +1975,11 @@ OpFunctionEnd
 TEST_P(AccessChainInstructionTest,
        AccessChainResultAndBaseStorageClassDoesntMatchBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Function_float %my_matrix %int_0 %int_1
+                 R"( %_ptr_Function_float %my_matrix )" + elem +
+                 R"(%int_0 %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -1983,9 +1996,10 @@ OpFunctionEnd
 TEST_P(AccessChainInstructionTest,
        AccessChainBasePtrNotPointingToCompositeBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Private_float %my_float_var %int_0
+                 R"( %_ptr_Private_float %my_float_var )" + elem + R"(%int_0
 OpReturn
 OpFunctionEnd
   )";
@@ -2000,13 +2014,14 @@ OpFunctionEnd
 // Invalid. No Indexes passed to the access chain instruction.
 TEST_P(AccessChainInstructionTest, AccessChainMissingIndexesBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Private_float %my_float_var
+                 R"( %_ptr_Private_float %my_float_var )" + elem + R"(
 OpReturn
 OpFunctionEnd
   )";
-  const std::string expected_err = "No Indexes were passes to " + instr;
+  const std::string expected_err = "No Indexes were passed to " + instr;
   CompileSuccessfully(spirv);
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err));
@@ -2014,6 +2029,8 @@ OpFunctionEnd
 
 // Valid: 255 indexes passed to the access chain instruction. Limit is 255.
 TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) {
+  const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : "";
   int depth = 255;
   std::string header = kGLSL450MemoryModel + opAccessChainSpirvSetup;
   header.erase(header.find("%func"));
@@ -2038,7 +2055,7 @@ TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) {
   )";
 
   // AccessChain with 'n' indexes (n = depth)
-  spirv << "%entry = " << GetParam() << " %_ptr_Uniform_float %deep_var";
+  spirv << "%entry = " << instr << " %_ptr_Uniform_float %deep_var" << elem;
   for (int i = 0; i < depth; ++i) {
     spirv << " %int_0";
   }
@@ -2055,9 +2072,10 @@ TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) {
 // Invalid: 256 indexes passed to the access chain instruction. Limit is 255.
 TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : "";
   std::ostringstream spirv;
   spirv << kGLSL450MemoryModel << opAccessChainSpirvSetup;
-  spirv << "%entry = " << instr << " %_ptr_Private_float %my_matrix";
+  spirv << "%entry = " << instr << " %_ptr_Private_float %my_matrix" << elem;
   for (int i = 0; i < 256; ++i) {
     spirv << " %int_0";
   }
@@ -2076,9 +2094,10 @@ TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesBad) {
 // integer).
 TEST_P(AccessChainInstructionTest, AccessChainUndefinedIndexBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Private_float %my_matrix %float %int_1
+                 R"( %_ptr_Private_float %my_matrix )" + elem + R"(%float %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -2093,8 +2112,10 @@ OpFunctionEnd
 // OpConstant.
 TEST_P(AccessChainInstructionTest, AccessChainStructIndexNotConstantBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
-%f = )" + instr + R"( %_ptr_Uniform_float %blockName_var %int_0 %spec_int %int_2
+%f = )" + instr + R"( %_ptr_Uniform_float %blockName_var )" +
+                 elem + R"(%int_0 %spec_int %int_2
 OpReturn
 OpFunctionEnd
   )";
@@ -2110,9 +2131,11 @@ OpFunctionEnd
 TEST_P(AccessChainInstructionTest,
        AccessChainStructResultTypeDoesntMatchIndexedTypeBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Uniform_float %blockName_var %int_0 %int_1 %int_2
+                 R"( %_ptr_Uniform_float %blockName_var )" + elem +
+                 R"(%int_0 %int_1 %int_2
 OpReturn
 OpFunctionEnd
   )";
@@ -2128,9 +2151,11 @@ OpFunctionEnd
 // Invalid: Reach non-composite type (bool) when unused indexes remain.
 TEST_P(AccessChainInstructionTest, AccessChainStructTooManyIndexesBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Uniform_float %blockName_var %int_0 %int_2 %int_2
+                 R"( %_ptr_Uniform_float %blockName_var )" + elem +
+                 R"(%int_0 %int_2 %int_2
 OpReturn
 OpFunctionEnd
   )";
@@ -2145,9 +2170,11 @@ OpFunctionEnd
 // Invalid: Trying to find index 3 of the struct that has only 3 members.
 TEST_P(AccessChainInstructionTest, AccessChainStructIndexOutOfBoundBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Uniform_float %blockName_var %int_3 %int_2 %int_2
+                 R"( %_ptr_Uniform_float %blockName_var )" + elem +
+                 R"(%int_3 %int_2 %int_2
 OpReturn
 OpFunctionEnd
   )";
@@ -2169,33 +2196,36 @@ TEST_P(AccessChainInstructionTest, AccessChainIndexIntoAllTypesGood) {
   // 2 will select the column (which is a vector) within the matrix at index 2
   // 0 will select the element at the index 0 of the vector. (which is a float).
   const std::string instr = GetParam();
-  string spirv =
-      kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
-%ss = )" +
-      instr + R"( %_ptr_Uniform_struct_s %blockName_var %int_0
-%sa = )" +
-      instr + R"( %_ptr_Uniform_array5_mat4x3 %blockName_var %int_0 %int_3
-%sm = )" +
-      instr + R"( %_ptr_Uniform_mat4x3 %blockName_var %int_0 %int_3 %int_1
-%sc = )" +
-      instr +
-      R"( %_ptr_Uniform_v3float %blockName_var %int_0 %int_3 %int_1 %int_2
-%entry = )" +
-      instr +
-      R"( %_ptr_Uniform_float %blockName_var %int_0 %int_3 %int_1 %int_2 %int_0
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
+  ostringstream spirv;
+  spirv << kGLSL450MemoryModel << opAccessChainSpirvSetup << std::endl;
+  spirv << "%ss = " << instr << " %_ptr_Uniform_struct_s %blockName_var "
+        << elem << "%int_0" << std::endl;
+  spirv << "%sa = " << instr << " %_ptr_Uniform_array5_mat4x3 %blockName_var "
+        << elem << "%int_0 %int_3" << std::endl;
+  spirv << "%sm = " << instr << " %_ptr_Uniform_mat4x3 %blockName_var " << elem
+        << "%int_0 %int_3 %int_1" << std::endl;
+  spirv << "%sc = " << instr << " %_ptr_Uniform_v3float %blockName_var " << elem
+        << "%int_0 %int_3 %int_1 %int_2" << std::endl;
+  spirv << "%entry = " << instr << " %_ptr_Uniform_float %blockName_var "
+        << elem << "%int_0 %int_3 %int_1 %int_2 %int_0" << std::endl;
+  spirv << R"(
 OpReturn
 OpFunctionEnd
   )";
-  CompileSuccessfully(spirv);
+  CompileSuccessfully(spirv.str());
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
 // Valid: Access an element of OpTypeRuntimeArray.
 TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) {
+  const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %runtime_arr_entry = )" +
-                 GetParam() +
-                 R"( %_ptr_Uniform_float %blockName_var %int_2 %int_0
+                 instr +
+                 R"( %_ptr_Uniform_float %blockName_var )" + elem +
+                 R"(%int_2 %int_0
 OpReturn
 OpFunctionEnd
   )";
@@ -2206,10 +2236,12 @@ OpFunctionEnd
 // Invalid: Unused index when accessing OpTypeRuntimeArray.
 TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %runtime_arr_entry = )" +
                  instr +
-                 R"( %_ptr_Uniform_float %blockName_var %int_2 %int_0 %int_1
+                 R"( %_ptr_Uniform_float %blockName_var )" + elem +
+                 R"(%int_2 %int_0 %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -2225,9 +2257,11 @@ OpFunctionEnd
 // finished.
 TEST_P(AccessChainInstructionTest, AccessChainMatrixMoreArgsThanNeededBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Private_float %my_matrix %int_0 %int_1 %int_0
+                 R"( %_ptr_Private_float %my_matrix )" + elem +
+                 R"(%int_0 %int_1 %int_0
 OpReturn
 OpFunctionEnd
   )";
@@ -2243,9 +2277,11 @@ OpFunctionEnd
 TEST_P(AccessChainInstructionTest,
        AccessChainResultTypeDoesntMatchIndexedTypeBad) {
   const std::string instr = GetParam();
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
   string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
 %entry = )" + instr +
-                 R"( %_ptr_Private_mat4x3 %my_matrix %int_0 %int_1
+                 R"( %_ptr_Private_mat4x3 %my_matrix )" + elem +
+                 R"(%int_0 %int_1
 OpReturn
 OpFunctionEnd
   )";
@@ -2259,10 +2295,10 @@ OpFunctionEnd
 }
 
 // Run tests for Access Chain Instructions.
-INSTANTIATE_TEST_CASE_P(CheckAccessChainInstructions,
-                        AccessChainInstructionTest,
-                        ::testing::Values("OpAccessChain",
-                                          "OpInBoundsAccessChain"));
+INSTANTIATE_TEST_CASE_P(
+    CheckAccessChainInstructions, AccessChainInstructionTest,
+    ::testing::Values("OpAccessChain", "OpInBoundsAccessChain",
+                      "OpPtrAccessChain", "OpInBoundsPtrAccessChain"));
 
 // TODO: OpArrayLength
 // TODO: OpImagePointer