bitcast fixes
authorFlorian Ziesche <florian.ziesche@gmail.com>
Wed, 2 Mar 2016 21:17:54 +0000 (22:17 +0100)
committerDejan Mircevski <deki@google.com>
Fri, 29 Apr 2016 18:55:05 +0000 (14:55 -0400)
 * ValidationState_t and idUsage now store the addressing model and memory model of the SPIR-V module (this is necessary for certain instructions that need different checks depending on if the logical or physical addressing model is used)
 * removed SpvOpPtrAccessChain and SpvOpInBoundsPtrAccessChain from spvOpcodeIsPointer again as these are disallowed in logical addressing mode and only allowed in physical addressing mode (which doesn't use/need spvOpcodeIsPointer in the first place)
 * added SpvOpImageTexelPointer and SpvOpCopyObject to spvOpcodeIsPointer
 * OpLoad/OpStore now only check if the used pointer operand originated from a valid pointer producing opcode in logical addressing mode (as per 2.16.1)
 * moved bitcast pointer tests to the kernel / physical addressing model part (+cleanup)
 * renamed spvOpcodeIsPointer to spvOpcodeReturnsLogicalPointer to clarify this function is only meant to be used with the logical addressing model

source/opcode.cpp
source/opcode.h
source/validate.h
source/validate_id.cpp
source/validate_instruction.cpp
source/validate_types.cpp
test/ValidateID.cpp

index c34d626..fd0dda6 100644 (file)
@@ -226,14 +226,14 @@ int32_t spvOpcodeIsComposite(const SpvOp opcode) {
   }
 }
 
-int32_t spvOpcodeIsPointer(const SpvOp opcode) {
+int32_t spvOpcodeReturnsLogicalPointer(const SpvOp opcode) {
   switch (opcode) {
     case SpvOpVariable:
     case SpvOpAccessChain:
-    case SpvOpPtrAccessChain:
     case SpvOpInBoundsAccessChain:
-    case SpvOpInBoundsPtrAccessChain:
     case SpvOpFunctionParameter:
+    case SpvOpImageTexelPointer:
+    case SpvOpCopyObject:
       return true;
     default:
       return false;
index 737af8c..6609824 100644 (file)
@@ -85,9 +85,9 @@ int32_t spvOpcodeIsConstant(const SpvOp opcode);
 // non-zero otherwise.
 int32_t spvOpcodeIsComposite(const SpvOp opcode);
 
-// Determines if the given opcode results in a pointer. Returns zero if false,
-// non-zero otherwise.
-int32_t spvOpcodeIsPointer(const SpvOp opcode);
+// Determines if the given opcode results in a pointer when using the logical
+// addressing model. Returns zero if false, non-zero otherwise.
+int32_t spvOpcodeReturnsLogicalPointer(const SpvOp opcode);
 
 // Determines if the given opcode generates a type. Returns zero if false,
 // non-zero otherwise.
index 972916f..ae0abb3 100644 (file)
@@ -272,6 +272,18 @@ class ValidationState_t {
   // capabilities==0.
   bool HasAnyOf(spv_capability_mask_t capabilities) const;
 
+  // Sets the addressing model of this module (logical/physical).
+  void setAddressingModel(SpvAddressingModel am);
+
+  // Returns the addressing model of this module, or Logical if uninitialized.
+  SpvAddressingModel getAddressingModel() const;
+
+  // Sets the memory model of this module.
+  void setMemoryModel(SpvMemoryModel mm);
+
+  // Returns the memory model of this module, or Simple if uninitialized.
+  SpvMemoryModel getMemoryModel() const;
+
   AssemblyGrammar& grammar() { return grammar_; }
 
  private:
@@ -298,6 +310,10 @@ class ValidationState_t {
   std::vector<uint32_t> entry_points_;
 
   AssemblyGrammar grammar_;
+
+  SpvAddressingModel addressing_model_;
+  SpvMemoryModel memory_model_;
+
 };
 
 }  // namespace libspirv
index 879af59..7a642cb 100644 (file)
@@ -50,6 +50,8 @@ class idUsage {
           const spv_operand_table operandTableArg,
           const spv_ext_inst_table extInstTableArg,
           const spv_instruction_t* pInsts, const uint64_t instCountArg,
+          const SpvMemoryModel memoryModelArg,
+          const SpvAddressingModel addressingModelArg,
           const UseDefTracker& usedefs,
           const std::vector<uint32_t>& entry_points, spv_position positionArg,
           spv_diagnostic* pDiagnosticArg)
@@ -58,6 +60,8 @@ class idUsage {
         extInstTable(extInstTableArg),
         firstInst(pInsts),
         instCount(instCountArg),
+        memoryModel(memoryModelArg),
+        addressingModel(addressingModelArg),
         position(positionArg),
         pDiagnostic(pDiagnosticArg),
         usedefs_(usedefs),
@@ -74,6 +78,8 @@ class idUsage {
   const spv_ext_inst_table extInstTable;
   const spv_instruction_t* const firstInst;
   const uint64_t instCount;
+  const SpvMemoryModel memoryModel;
+  const SpvAddressingModel addressingModel;
   spv_position position;
   spv_diagnostic* pDiagnostic;
   UseDefTracker usedefs_;
@@ -755,7 +761,9 @@ bool idUsage::isValid<SpvOpLoad>(const spv_instruction_t* inst,
            return false);
   auto pointerIndex = 3;
   auto pointer = usedefs_.FindDef(inst->words[pointerIndex]);
-  if (!pointer.first || !spvOpcodeIsPointer(pointer.second.opcode)) {
+  if (!pointer.first ||
+      (addressingModel == SpvAddressingModelLogical &&
+       !spvOpcodeReturnsLogicalPointer(pointer.second.opcode))) {
     DIAG(pointerIndex) << "OpLoad Pointer <id> '" << inst->words[pointerIndex]
                        << "' is not a pointer.";
     return false;
@@ -783,13 +791,20 @@ bool idUsage::isValid<SpvOpStore>(const spv_instruction_t* inst,
                                   const spv_opcode_desc) {
   auto pointerIndex = 1;
   auto pointer = usedefs_.FindDef(inst->words[pointerIndex]);
-  if (!pointer.first || !spvOpcodeIsPointer(pointer.second.opcode)) {
+  if (!pointer.first ||
+      (addressingModel == SpvAddressingModelLogical &&
+       !spvOpcodeReturnsLogicalPointer(pointer.second.opcode))) {
     DIAG(pointerIndex) << "OpStore Pointer <id> '" << inst->words[pointerIndex]
                        << "' is not a pointer.";
     return false;
   }
   auto pointerType = usedefs_.FindDef(pointer.second.type_id);
-  assert(pointerType.first);
+  if (!pointer.first || pointerType.second.opcode != SpvOpTypePointer) {
+    DIAG(pointerIndex) << "OpStore type for pointer <id> '"
+                       << inst->words[pointerIndex]
+                       << "' is not a pointer type.";
+    return false;
+  }
   auto type = usedefs_.FindDef(pointerType.second.words[3]);
   assert(type.first);
   spvCheck(SpvOpTypeVoid == type.second.opcode, DIAG(pointerIndex)
@@ -2332,6 +2347,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts,
                                        spv_position position,
                                        spv_diagnostic* pDiag) {
   idUsage idUsage(opcodeTable, operandTable, extInstTable, pInsts, instCount,
+                  state.getMemoryModel(), state.getAddressingModel(),
                   state.usedefs(), state.entry_points(), position, pDiag);
   for (uint64_t instIndex = 0; instIndex < instCount; ++instIndex) {
     spvCheck(!idUsage.isValid(&pInsts[instIndex]), return SPV_ERROR_INVALID_ID);
index 8ecb715..26053e3 100644 (file)
@@ -122,6 +122,12 @@ spv_result_t InstructionPass(ValidationState_t& _,
   if (opcode == SpvOpCapability)
     _.registerCapability(
         static_cast<SpvCapability>(inst->words[inst->operands[0].offset]));
+  if (opcode == SpvOpMemoryModel) {
+    _.setAddressingModel(
+        static_cast<SpvAddressingModel>(inst->words[inst->operands[0].offset]));
+    _.setMemoryModel(
+        static_cast<SpvMemoryModel>(inst->words[inst->operands[1].offset]));
+  }
   if (opcode == SpvOpVariable) {
     const auto storage_class =
         static_cast<SpvStorageClass>(inst->words[inst->operands[2].offset]);
index 894e530..e36454a 100644 (file)
@@ -218,7 +218,9 @@ ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic,
       current_layout_section_(kLayoutCapabilities),
       module_functions_(*this),
       module_capabilities_(0u),
-      grammar_(context) {}
+      grammar_(context),
+      addressing_model_(SpvAddressingModelLogical),
+      memory_model_(SpvMemoryModelSimple) {}
 
 spv_result_t ValidationState_t::forwardDeclareId(uint32_t id) {
   unresolved_forward_ids_.insert(id);
@@ -316,6 +318,22 @@ bool ValidationState_t::HasAnyOf(spv_capability_mask_t capabilities) const {
   });
   return found;
 }
+       
+void ValidationState_t::setAddressingModel(SpvAddressingModel am) {
+  addressing_model_ = am;
+}
+
+SpvAddressingModel ValidationState_t::getAddressingModel() const {
+  return addressing_model_;
+}
+
+void ValidationState_t::setMemoryModel(SpvMemoryModel mm) {
+  memory_model_ = mm;
+}
+
+SpvMemoryModel ValidationState_t::getMemoryModel() const {
+  return memory_model_;
+}
 
 Functions::Functions(ValidationState_t& module)
     : module_(module), in_function_(false), in_block_(false) {}
index 4b0da84..6cff0ab 100644 (file)
@@ -1694,6 +1694,76 @@ TEST_F(ValidateID, OpPtrAccessChainGood) {
   CHECK_KERNEL(spirv, SPV_SUCCESS, 64);
 }
 
+TEST_F(ValidateID, OpLoadBitcastPointerGood) {
+  const char* spirv = R"(
+%2  = OpTypeVoid
+%3  = OpTypeInt 32 1
+%4  = OpTypeFloat 32
+%5  = OpTypePointer UniformConstant %3
+%6  = OpTypePointer UniformConstant %4
+%7  = OpVariable %5 UniformConstant
+%8  = OpTypeFunction %2
+%9  = OpFunction %2 None %8
+%10 = OpLabel
+%11 = OpBitcast %6 %7
+%12 = OpLoad %4 %11
+      OpReturn
+      OpFunctionEnd)";
+  CHECK_KERNEL(spirv, SPV_SUCCESS, 64);
+}
+TEST_F(ValidateID, OpLoadBitcastNonPointerBad) {
+  const char* spirv = R"(
+%2  = OpTypeVoid
+%3  = OpTypeInt 32 1
+%4  = OpTypeFloat 32
+%5  = OpTypePointer UniformConstant %3
+%6  = OpTypeFunction %2
+%7  = OpVariable %5 UniformConstant
+%8  = OpFunction %2 None %6
+%9  = OpLabel
+%10 = OpLoad %3 %7
+%11 = OpBitcast %4 %10
+%12 = OpLoad %3 %11
+      OpReturn
+      OpFunctionEnd)";
+  CHECK_KERNEL(spirv, SPV_ERROR_INVALID_ID, 64);
+}
+TEST_F(ValidateID, OpStoreBitcastPointerGood) {
+  const char* spirv = R"(
+%2  = OpTypeVoid
+%3  = OpTypeInt 32 1
+%4  = OpTypeFloat 32
+%5  = OpTypePointer Function %3
+%6  = OpTypePointer Function %4
+%7  = OpTypeFunction %2
+%8  = OpConstant %3 42
+%9  = OpFunction %2 None %7
+%10 = OpLabel
+%11 = OpVariable %6 Function
+%12 = OpBitcast %5 %11
+      OpStore %12 %8
+      OpReturn
+      OpFunctionEnd)";
+  CHECK_KERNEL(spirv, SPV_SUCCESS, 64);
+}
+TEST_F(ValidateID, OpStoreBitcastNonPointerBad) {
+  const char* spirv = R"(
+%2  = OpTypeVoid
+%3  = OpTypeInt 32 1
+%4  = OpTypeFloat 32
+%5  = OpTypePointer Function %4
+%6  = OpTypeFunction %2
+%7  = OpConstant %4 42
+%8  = OpFunction %2 None %6
+%9  = OpLabel
+%10 = OpVariable %5 Function
+%11 = OpBitcast %3 %7
+      OpStore %11 %7
+      OpReturn
+      OpFunctionEnd)";
+  CHECK_KERNEL(spirv, SPV_ERROR_INVALID_ID, 64);
+}
+
 // TODO: OpLifetimeStart
 // TODO: OpLifetimeStop
 // TODO: OpAtomicInit