Make IsTypeNullable a transitive check.
authorDejan Mircevski <deki@google.com>
Wed, 3 Feb 2016 22:16:27 +0000 (17:16 -0500)
committerDejan Mircevski <deki@google.com>
Thu, 4 Feb 2016 20:34:47 +0000 (15:34 -0500)
source/opcode.cpp
source/opcode.h
source/validate_id.cpp
test/ValidateID.cpp

index 993001b..ab627ce 100644 (file)
@@ -489,22 +489,6 @@ int32_t spvOpcodeIsPointer(const SpvOp opcode) {
   }
 }
 
-int32_t spvOpcodeIsBasicTypeNullable(SpvOp opcode) {
-  switch (opcode) {
-    case SpvOpTypeBool:
-    case SpvOpTypeInt:
-    case SpvOpTypeFloat:
-    case SpvOpTypePointer:
-    case SpvOpTypeEvent:
-    case SpvOpTypeDeviceEvent:
-    case SpvOpTypeReserveId:
-    case SpvOpTypeQueue:
-      return true;
-    default:
-      return false;
-  }
-}
-
 int32_t spvInstructionIsInBasicBlock(const spv_instruction_t* pFirstInst,
                                      const spv_instruction_t* pInst) {
   while (pFirstInst != pInst) {
index ad0f94d..8fe9e1f 100644 (file)
@@ -92,10 +92,6 @@ int32_t spvOpcodeAreTypesEqual(const spv_instruction_t* type_inst0,
 // non-zero otherwise.
 int32_t spvOpcodeIsPointer(const SpvOp opcode);
 
-// Determines if the scalar type opcode is nullable. Returns zero if false,
-// non-zero otherwise.
-int32_t spvOpcodeIsBasicTypeNullable(SpvOp opcode);
-
 // Determines if an instruction is in a basic block. The first_inst parameter
 // specifies the first instruction in the stream, while the inst parameter
 // specifies the current instruction. Returns zero if false, non-zero otherwise.
index 160808c..f5118a8 100644 (file)
@@ -609,64 +609,54 @@ bool idUsage::isValid<SpvOpConstantSampler>(const spv_instruction_t* inst,
   return true;
 }
 
+// True if instruction defines a type that can have a null value, as defined by
+// the SPIR-V spec.  Tracks composite-type components through usedefs to check
+// nullability transitively.
+bool IsTypeNullable(const std::vector<uint32_t>& instruction,
+                    const UseDefTracker& usedefs) {
+  SpvOp opcode;
+  uint16_t word_count;
+  spvOpcodeSplit(instruction[0], &word_count, &opcode);
+  switch (opcode) {
+    case SpvOpTypeBool:
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+    case SpvOpTypePointer:
+    case SpvOpTypeEvent:
+    case SpvOpTypeDeviceEvent:
+    case SpvOpTypeReserveId:
+    case SpvOpTypeQueue:
+      return true;
+    case SpvOpTypeArray:
+    case SpvOpTypeMatrix:
+    case SpvOpTypeVector: {
+      auto base_type = usedefs.FindDef(instruction[2]);
+      return base_type.first && IsTypeNullable(base_type.second.words, usedefs);
+    }
+    case SpvOpTypeStruct: {
+      for (size_t elementIndex = 2; elementIndex < instruction.size();
+           ++elementIndex) {
+        auto element = usedefs.FindDef(instruction[elementIndex]);
+        if (!element.first || !IsTypeNullable(element.second.words, usedefs))
+          return false;
+      }
+      return true;
+    }
+    default:
+      return false;
+  }
+}
+
 template <>
 bool idUsage::isValid<SpvOpConstantNull>(const spv_instruction_t* inst,
                                          const spv_opcode_desc) {
   auto resultTypeIndex = 1;
   auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]);
-  if (!resultType.first) return false;
-  switch (resultType.second.opcode) {
-    default: {
-      spvCheck(!spvOpcodeIsBasicTypeNullable(resultType.second.opcode),
-               DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '"
-                                     << inst->words[resultTypeIndex]
-                                     << "' cannot be null.";
-               return false);
-    } break;
-    case SpvOpTypeVector: {
-      auto type = usedefs_.FindDef(resultType.second.words[2]);
-      assert(type.first);
-      spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode),
-               DIAG(resultTypeIndex)
-                   << "OpConstantNull Result Type <id> '"
-                   << inst->words[resultTypeIndex]
-                   << "'s vector component type cannot be null.";
-               return false);
-    } break;
-    case SpvOpTypeArray: {
-      auto type = usedefs_.FindDef(resultType.second.words[2]);
-      assert(type.first);
-      spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode),
-               DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '"
-                                     << inst->words[resultTypeIndex]
-                                     << "'s array element type cannot be null.";
-               return false);
-    } break;
-    case SpvOpTypeMatrix: {
-      auto columnType = usedefs_.FindDef(resultType.second.words[2]);
-      assert(columnType.first);
-      auto type = usedefs_.FindDef(columnType.second.words[2]);
-      assert(type.first);
-      spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode),
-               DIAG(resultTypeIndex)
-                   << "OpConstantNull Result Type <id> '"
-                   << inst->words[resultTypeIndex]
-                   << "'s matrix component type cna not be null.";
-               return false);
-    } break;
-    case SpvOpTypeStruct: {
-      for (size_t elementIndex = 2;
-           elementIndex < resultType.second.words.size(); ++elementIndex) {
-        auto element = usedefs_.FindDef(resultType.second.words[elementIndex]);
-        assert(element.first);
-        spvCheck(!spvOpcodeIsBasicTypeNullable(element.second.opcode),
-                 DIAG(resultTypeIndex)
-                     << "OpConstantNull Result Type <id> '"
-                     << inst->words[resultTypeIndex]
-                     << "'s struct element type cannot be null.";
-                 return false);
-      }
-    } break;
+  if (!resultType.first || !IsTypeNullable(resultType.second.words, usedefs_)) {
+    DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '"
+                          << inst->words[resultTypeIndex]
+                          << "' cannot have a null value.";
+    return false;
   }
   return true;
 }
index 90778cc..d2d6a26 100644 (file)
@@ -612,15 +612,21 @@ TEST_F(ValidateID, OpConstantNullGood) {
 %22 = OpConstantNull %21
 %23 = OpTypeStruct %3 %5 %1
 %24 = OpConstantNull %23
+%26 = OpTypeArray %17 %25
+%27 = OpConstantNull %26
+%28 = OpTypeStruct %7 %26 %26 %1
+%29 = OpConstantNull %28
 )";
   CHECK(spirv, SPV_SUCCESS);
 }
+
 TEST_F(ValidateID, OpConstantNullBasicBad) {
   const char* spirv = R"(
 %1 = OpTypeVoid
 %2 = OpConstantNull %1)";
   CHECK(spirv, SPV_ERROR_INVALID_ID);
 }
+
 TEST_F(ValidateID, OpConstantNullArrayBad) {
   const char* spirv = R"(
 %2 = OpTypeInt 32 0
@@ -630,6 +636,7 @@ TEST_F(ValidateID, OpConstantNullArrayBad) {
 %6 = OpConstantNull %5)";
   CHECK(spirv, SPV_ERROR_INVALID_ID);
 }
+
 TEST_F(ValidateID, OpConstantNullStructBad) {
   const char* spirv = R"(
 %2 = OpTypeSampler
@@ -638,6 +645,14 @@ TEST_F(ValidateID, OpConstantNullStructBad) {
   CHECK(spirv, SPV_ERROR_INVALID_ID);
 }
 
+TEST_F(ValidateID, OpConstantNullRuntimeArrayBad) {
+  const char* spirv = R"(
+%bool = OpTypeBool
+%array = OpTypeRuntimeArray %bool
+%null = OpConstantNull %array)";
+  CHECK(spirv, SPV_ERROR_INVALID_ID);
+}
+
 TEST_F(ValidateID, OpSpecConstantTrueGood) {
   const char* spirv = R"(
 %1 = OpTypeBool