Adding validation code for OpAccessChain.
authorEhsan Nasiri <ehsann@google.com>
Wed, 30 Nov 2016 18:29:12 +0000 (13:29 -0500)
committerDavid Neto <dneto@google.com>
Fri, 2 Dec 2016 18:50:41 +0000 (13:50 -0500)
* Result Type must be an OpTypePointer. Its Type operand must be the
type reached by walking the Base’s type hierarchy down to the last
provided index in Indexes, and its Storage Class operand must be the
same as the Storage Class of Base.

* Base must be a pointer, pointing to the base of a composite object.

* Indexes walk the type hierarchy to the desired depth, potentially down
to scalar granularity. The first index in Indexes will select the
top-level member/element/component/element of the base composite. All
composite constituents use zero-based numbering, as described by their
OpType... instruction. The second index will apply similarly to that
result, and so on. Once any non-composite type is reached, there must
be no remaining (unused) indexes. Each of the Indexes must:
- be a scalar integer type,
- be an OpConstant when indexing into a structure.

* Check for the case where no indexes are passed to OpAccessChain.

Minor improvements based on code review.

source/validate_id.cpp
test/val/val_id_test.cpp

index 9e4d942..94b0753 100644 (file)
@@ -1184,11 +1184,145 @@ bool idUsage::isValid<SpvOpCopyMemorySized>(const spv_instruction_t* inst,
   return true;
 }
 
-#if 0
 template <>
-bool idUsage::isValid<SpvOpAccessChain>(const spv_instruction_t *inst,
-                                        const spv_opcode_desc opcodeEntry) {}
-#endif
+bool idUsage::isValid<SpvOpAccessChain>(const spv_instruction_t* inst,
+                                        const spv_opcode_desc) {
+  // The result type must be OpTypePointer. Result Type is at word 1.
+  auto resultTypeIndex = 1;
+  auto resultTypeInstr = module_.FindDef(inst->words[resultTypeIndex]);
+  if (SpvOpTypePointer != resultTypeInstr->opcode()) {
+    DIAG(resultTypeIndex) << "The Result Type of OpAccessChain <id> '"
+                          << inst->words[2]
+                          << "' must be OpTypePointer. Found Op"
+                          << spvOpcodeString(
+                                 static_cast<SpvOp>(resultTypeInstr->opcode()))
+                          << ".";
+    return false;
+  }
+
+  // Result type is a pointer. Find out what it's pointing to.
+  // This will be used to make sure the indexing results in the same type.
+  // OpTypePointer word 3 is the type being pointed to.
+  auto resultTypePointedTo = module_.FindDef(resultTypeInstr->word(3));
+
+  // Base must be a pointer, pointing to the base of a composite object.
+  auto baseIdIndex = 3;
+  auto baseInstr = module_.FindDef(inst->words[baseIdIndex]);
+  auto baseTypeInstr = module_.FindDef(baseInstr->type_id());
+  if (!baseTypeInstr || SpvOpTypePointer != baseTypeInstr->opcode()) {
+    DIAG(baseIdIndex) << "The Base <id> '" << inst->words[baseIdIndex]
+                      << "' in OpAccessChain instruction must be a pointer.";
+    return false;
+  }
+
+  // The result pointer storage class and base pointer storage class must match.
+  // Word 2 of OpTypePointer is the Storage Class.
+  auto resultTypeStorageClass = resultTypeInstr->word(2);
+  auto baseTypeStorageClass = baseTypeInstr->word(2);
+  if (resultTypeStorageClass != baseTypeStorageClass) {
+    DIAG(resultTypeIndex) << "The result pointer storage class and base "
+                             "pointer storage class in OpAccessChain do not "
+                             "match.";
+    return false;
+  }
+
+  // The type pointed to by OpTypePointer (word 3) must be a composite type.
+  auto typePointedTo = module_.FindDef(baseTypeInstr->word(3));
+
+  // Check Universal Limit (SPIR-V Spec. Section 2.17).
+  // The number of indexes passed to OpAccessChain may not exceed 255
+  // The instruction includes 4 words + N words (for N indexes)
+  const size_t num_indexes = inst->words.size() - 4;
+  const size_t num_indexes_limit = 255;
+  if (num_indexes > num_indexes_limit) {
+    DIAG(resultTypeIndex)
+        << "The number of indexes in OpAccessChain may not exceed "
+        << num_indexes_limit << ". Found " << num_indexes << " indexes.";
+    return false;
+  }
+  if (num_indexes <= 0) {
+    DIAG(resultTypeIndex) << "No Indexes were passes to OpAccessChain.";
+    return false;
+  }
+  // Indexes walk the type hierarchy to the desired depth, potentially down to
+  // scalar granularity. The first index in Indexes will select the top-level
+  // member/element/component/element of the base composite. All composite
+  // constituents use zero-based numbering, as described by their OpType...
+  // instruction. The second index will apply similarly to that result, and so
+  // on. Once any non-composite type is reached, there must be no remaining
+  // (unused) indexes.
+  for (size_t i = 4; i < inst->words.size(); ++i) {
+    const uint32_t cur_word = inst->words[i];
+    // Earlier ID checks ensure that cur_word definition exists.
+    auto cur_word_instr = module_.FindDef(cur_word);
+    // The index must be a scalar integer type (See OpAccessChain in the Spec.)
+    auto indexTypeInstr = module_.FindDef(cur_word_instr->type_id());
+    if (!indexTypeInstr || SpvOpTypeInt != indexTypeInstr->opcode()) {
+      DIAG(i) << "Indexes passed to OpAccessChain must be of type integer.";
+      return false;
+    }
+    switch (typePointedTo->opcode()) {
+      case SpvOpTypeMatrix:
+      case SpvOpTypeVector:
+      case SpvOpTypeArray: {
+        // In OpTypeMatrix, OpTypeArray, and OpTypeVector, word 2 is the
+        // Element Type.
+        typePointedTo = module_.FindDef(typePointedTo->word(2));
+        break;
+      }
+      case SpvOpTypeStruct: {
+        // In case of structures, there is an additional constraint on the
+        // index: the index must be an OpConstant.
+        if (SpvOpConstant != cur_word_instr->opcode()) {
+          DIAG(i) << "The <id> passed to OpAccessChain to index into a "
+                     "structure must be an OpConstant.";
+          return false;
+        }
+        // Get the index value from the OpConstant (word 3 of OpConstant).
+        // OpConstant could be a signed integer. But it's okay to treat it as
+        // unsigned because a negative constant int would never be seen as
+        // correct as a struct offset, since structs can't have more than 2
+        // billion members.
+        const uint32_t cur_index = cur_word_instr->word(3);
+        // The index points to the struct member we want, therefore, the index
+        // should be less than the number of struct members.
+        const uint32_t num_struct_members =
+            static_cast<uint32_t>(typePointedTo->words().size() - 2);
+        if (cur_index >= num_struct_members) {
+          DIAG(i) << "Index is out of bound: OpAccessChain can not find index "
+                  << cur_index << " into the structure <id> '"
+                  << typePointedTo->id() << "'. This structure has "
+                  << num_struct_members << " members. Largest valid index is "
+                  << num_struct_members - 1 << ".";
+          return false;
+        }
+        // Struct members IDs start at word 2 of OpTypeStruct.
+        auto structMemberId = typePointedTo->word(cur_index + 2);
+        typePointedTo = module_.FindDef(structMemberId);
+        break;
+      }
+      default: {
+        // Give an error. reached non-composite type while indexes still remain.
+        DIAG(i) << "OpAccessChain reached non-composite type while indexes "
+                   "still remain to be traversed.";
+        return false;
+      }
+    }
+  }
+  // At this point, we have fully walked down from the base using the indeces.
+  // The type being pointed to should be the same as the result type.
+  if (typePointedTo->id() != resultTypePointedTo->id()) {
+    DIAG(resultTypeIndex)
+        << "OpAccessChain result type (Op"
+        << spvOpcodeString(static_cast<SpvOp>(resultTypePointedTo->opcode()))
+        << ") does not match the type that results from indexing into the base "
+           "<id> (Op"
+        << spvOpcodeString(static_cast<SpvOp>(typePointedTo->opcode())) << ").";
+    return false;
+  }
+
+  return true;
+}
 
 #if 0
 template <>
@@ -2422,7 +2556,7 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
     CASE(OpStore)
     CASE(OpCopyMemory)
     CASE(OpCopyMemorySized)
-    TODO(OpAccessChain)
+    CASE(OpAccessChain)
     TODO(OpInBoundsAccessChain)
     TODO(OpArrayLength)
     TODO(OpGenericPtrMemSemantics)
index 5dafa08..b583f1a 100644 (file)
@@ -1785,7 +1785,297 @@ TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeTypeBad) {
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
 }
 
-// TODO: OpAccessChain
+string opAccessChainSpirvSetup = R"(
+%void = OpTypeVoid
+%void_f  = OpTypeFunction %void
+%int = OpTypeInt 32 0
+%float = OpTypeFloat 32
+%v3float = OpTypeVector %float 3
+%mat4x3 = OpTypeMatrix %v3float 4
+%_ptr_Private_mat4x3 = OpTypePointer Private %mat4x3
+%_ptr_Private_float = OpTypePointer Private %float
+%my_matrix = OpVariable %_ptr_Private_mat4x3 Private
+%my_float_var = OpVariable %_ptr_Private_float Private
+%_ptr_Function_float = OpTypePointer Function %float
+%int_0 = OpConstant %int 0
+%int_1 = OpConstant %int 1
+%int_2 = OpConstant %int 2
+%int_3 = OpConstant %int 3
+%int_5 = OpConstant %int 5
+
+; Let's make the following structures to test OpAccessChain
+;
+; struct S {
+;   bool b;
+;   vec4 v[5];
+;   int i;
+;   mat4x3 m[5];
+; }
+; uniform blockName {
+;   S s;
+;   bool cond;
+; }
+
+%bool = OpTypeBool
+%v4float = OpTypeVector %float 4
+%array5_mat4x3 = OpTypeArray %mat4x3 %int_5
+%array5_vec4 = OpTypeArray %v4float %int_5
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+%_ptr_Function_vec4 = OpTypePointer Function %v4float
+%_ptr_Uniform_vec4 = OpTypePointer Uniform %v4float
+%struct_s = OpTypeStruct %bool %array5_vec4 %int %array5_mat4x3
+%struct_blockName = OpTypeStruct %struct_s %bool
+%_ptr_Uniform_blockName = OpTypePointer Uniform %struct_blockName
+%_ptr_Uniform_struct_s = OpTypePointer Uniform %struct_s
+%_ptr_Uniform_array5_mat4x3 = OpTypePointer Uniform %array5_mat4x3
+%_ptr_Uniform_mat4x3 = OpTypePointer Uniform %mat4x3
+%_ptr_Uniform_v3float = OpTypePointer Uniform %v3float
+%blockName_var = OpVariable %_ptr_Uniform_blockName Uniform
+%spec_int = OpSpecConstant %int 2
+%func = OpFunction %void None %void_f
+%my_label = OpLabel
+)";
+
+// Valid: Access a float in a matrix using OpAccessChain
+TEST_F(ValidateIdWithMessage, OpAccessChainGood) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%float_entry = OpAccessChain %_ptr_Private_float %my_matrix %int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid. The result type of OpAccessChain must be a pointer.
+TEST_F(ValidateIdWithMessage, OpAccessChainResultTypeBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%float_entry = OpAccessChain %float %my_matrix %int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The Result Type of OpAccessChain <id> '35' must be "
+                        "OpTypePointer. Found OpTypeFloat."));
+}
+
+// Invalid. The base type of OpAccessChain must be a pointer.
+TEST_F(ValidateIdWithMessage, OpAccessChainBaseTypeVoidBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%float_entry = OpAccessChain %_ptr_Private_float %void %int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The Base <id> '1' in OpAccessChain instruction must "
+                        "be a pointer."));
+}
+
+// Invalid. The base type of OpAccessChain must be a pointer.
+TEST_F(ValidateIdWithMessage, OpAccessChainBaseTypeNonPtrVariableBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Private_float %_ptr_Private_float %int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The Base <id> '8' in OpAccessChain instruction must "
+                        "be a pointer."));
+}
+
+// Invalid: The storage class of Base and Result do not match.
+TEST_F(ValidateIdWithMessage,
+       OpAccessChainResultAndBaseStorageClassDoesntMatchBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Function_float %my_matrix %int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The result pointer storage class and base pointer "
+                        "storage class in OpAccessChain do not match."));
+}
+
+// Invalid. The base type of OpAccessChain must point to a composite object.
+TEST_F(ValidateIdWithMessage, OpAccessChainBasePtrNotPointingToCompositeBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Private_float %my_float_var %int_0
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpAccessChain reached non-composite type while "
+                        "indexes still remain to be traversed."));
+}
+
+// Invalid. No Indexes passed to OpAccessChain
+TEST_F(ValidateIdWithMessage, OpAccessChainMissingIndexesBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Private_float %my_float_var
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("No Indexes were passes to OpAccessChain."));
+}
+
+// Invalid: 256 indexes passed to OpAccessChain. Limit is 255.
+TEST_F(ValidateIdWithMessage, OpAccessChainTooManyIndecesBad) {
+  std::ostringstream spirv;
+  spirv << kGLSL450MemoryModel << opAccessChainSpirvSetup;
+  spirv << "%entry = OpAccessChain %_ptr_Private_float %my_matrix";
+  for (int i = 0; i < 256; ++i) {
+    spirv << " %int_0";
+  }
+  spirv << R"(
+    OpReturn
+    OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv.str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The number of indexes in OpAccessChain may not exceed "
+                        "255. Found 256 indexes."));
+}
+
+// Invalid: Index passed to OpAccessChain is float (must be integer).
+TEST_F(ValidateIdWithMessage, OpAccessChainUndefinedIndexBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Private_float %my_matrix %float %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Indexes passed to OpAccessChain must be of type integer."));
+}
+
+// Invalid: The OpAccessChain index argument that indexes into a struct must be
+// of type OpConstant.
+TEST_F(ValidateIdWithMessage, OpAccessChainStructIndexNotConstantBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%f = OpAccessChain %_ptr_Uniform_float %blockName_var %int_0 %spec_int %int_2
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The <id> passed to OpAccessChain to index into a "
+                        "structure must be an OpConstant."));
+}
+
+// Invalid: Indexing up to a vec4 granularity, but result type expected float.
+TEST_F(ValidateIdWithMessage,
+       OpAccessChainStructResultTypeDoesntMatchIndexedTypeBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Uniform_float %blockName_var %int_0 %int_1 %int_2
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpAccessChain result type (OpTypeFloat) does not match the type "
+          "that results from indexing into the base <id> (OpTypeVector)."));
+}
+
+// Invalid: Reach non-composite type (bool) when unused indexes remain.
+TEST_F(ValidateIdWithMessage, OpAccessChainStructTooManyIndexesBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Uniform_float %blockName_var %int_0 %int_2 %int_2
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpAccessChain reached non-composite type while "
+                        "indexes still remain to be traversed."));
+}
+
+// Invalid: Trying to find index 2 of the struct that has only 2 members.
+TEST_F(ValidateIdWithMessage, OpAccessChainStructIndexOutOfBoundBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Uniform_float %blockName_var %int_2 %int_2 %int_2
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Index is out of bound: OpAccessChain can not find "
+                        "index 2 into the structure <id> '25'. This structure "
+                        "has 2 members. Largest valid index is 1."));
+}
+
+// Valid: Tests that we can index into Struct, Array, Matrix, and Vector!
+TEST_F(ValidateIdWithMessage, OpAccessChainIndexIntoAllTypesGood) {
+  // indexes that we are passing are: 0, 3, 1, 2, 0
+  // 0 will select the struct_s within the base struct (blockName)
+  // 3 will select the Array that contains 5 matrices
+  // 1 will select the Matrix that is at index 1 of the array
+  // 2 will select the column (which is a vector) within the matrix at index 2
+  // 0 will select the element at the index 0 of the vector. (which is a float).
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%ss = OpAccessChain %_ptr_Uniform_struct_s %blockName_var %int_0
+%sa = OpAccessChain %_ptr_Uniform_array5_mat4x3 %blockName_var %int_0 %int_3
+%sm = OpAccessChain %_ptr_Uniform_mat4x3 %blockName_var %int_0 %int_3 %int_1
+%sc = OpAccessChain %_ptr_Uniform_v3float %blockName_var %int_0 %int_3 %int_1 %int_2
+%entry = OpAccessChain %_ptr_Uniform_float %blockName_var %int_0 %int_3 %int_1 %int_2 %int_0
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: Reached scalar type before arguments to OpAccessChain finished.
+TEST_F(ValidateIdWithMessage, OpAccessChainMatrixMoreArgsThanNeededBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Private_float %my_matrix %int_0 %int_1 %int_0
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpAccessChain reached non-composite type while "
+                        "indexes still remain to be traversed."));
+}
+
+// Invalid: The result type and the type indexed into do not match.
+TEST_F(ValidateIdWithMessage,
+       OpAccessChainResultTypeDoesntMatchIndexedTypeBad) {
+  string spirv = kGLSL450MemoryModel + opAccessChainSpirvSetup + R"(
+%entry = OpAccessChain %_ptr_Private_mat4x3 %my_matrix %int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpAccessChain result type (OpTypeMatrix) does not "
+                        "match the type that results from indexing into the "
+                        "base <id> (OpTypeFloat)."));
+}
+
 // TODO: OpInBoundsAccessChain
 // TODO: OpArrayLength
 // TODO: OpImagePointer