From 7ef6da7b7f9175da509b4d71a881c0a04e0b701b Mon Sep 17 00:00:00 2001 From: Dejan Mircevski Date: Wed, 3 Feb 2016 17:16:27 -0500 Subject: [PATCH] Make IsTypeNullable a transitive check. --- source/opcode.cpp | 16 --------- source/opcode.h | 4 --- source/validate_id.cpp | 96 ++++++++++++++++++++++---------------------------- test/ValidateID.cpp | 15 ++++++++ 4 files changed, 58 insertions(+), 73 deletions(-) diff --git a/source/opcode.cpp b/source/opcode.cpp index 993001b..ab627ce 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -489,22 +489,6 @@ int32_t spvOpcodeIsPointer(const SpvOp opcode) { } } -int32_t spvOpcodeIsBasicTypeNullable(SpvOp opcode) { - switch (opcode) { - case SpvOpTypeBool: - case SpvOpTypeInt: - case SpvOpTypeFloat: - case SpvOpTypePointer: - case SpvOpTypeEvent: - case SpvOpTypeDeviceEvent: - case SpvOpTypeReserveId: - case SpvOpTypeQueue: - return true; - default: - return false; - } -} - int32_t spvInstructionIsInBasicBlock(const spv_instruction_t* pFirstInst, const spv_instruction_t* pInst) { while (pFirstInst != pInst) { diff --git a/source/opcode.h b/source/opcode.h index ad0f94d..8fe9e1f 100644 --- a/source/opcode.h +++ b/source/opcode.h @@ -92,10 +92,6 @@ int32_t spvOpcodeAreTypesEqual(const spv_instruction_t* type_inst0, // non-zero otherwise. int32_t spvOpcodeIsPointer(const SpvOp opcode); -// Determines if the scalar type opcode is nullable. Returns zero if false, -// non-zero otherwise. -int32_t spvOpcodeIsBasicTypeNullable(SpvOp opcode); - // Determines if an instruction is in a basic block. The first_inst parameter // specifies the first instruction in the stream, while the inst parameter // specifies the current instruction. Returns zero if false, non-zero otherwise. diff --git a/source/validate_id.cpp b/source/validate_id.cpp index 160808c..f5118a8 100644 --- a/source/validate_id.cpp +++ b/source/validate_id.cpp @@ -609,64 +609,54 @@ bool idUsage::isValid(const spv_instruction_t* inst, return true; } +// True if instruction defines a type that can have a null value, as defined by +// the SPIR-V spec. Tracks composite-type components through usedefs to check +// nullability transitively. +bool IsTypeNullable(const std::vector& instruction, + const UseDefTracker& usedefs) { + SpvOp opcode; + uint16_t word_count; + spvOpcodeSplit(instruction[0], &word_count, &opcode); + switch (opcode) { + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypePointer: + case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + return true; + case SpvOpTypeArray: + case SpvOpTypeMatrix: + case SpvOpTypeVector: { + auto base_type = usedefs.FindDef(instruction[2]); + return base_type.first && IsTypeNullable(base_type.second.words, usedefs); + } + case SpvOpTypeStruct: { + for (size_t elementIndex = 2; elementIndex < instruction.size(); + ++elementIndex) { + auto element = usedefs.FindDef(instruction[elementIndex]); + if (!element.first || !IsTypeNullable(element.second.words, usedefs)) + return false; + } + return true; + } + default: + return false; + } +} + template <> bool idUsage::isValid(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); - if (!resultType.first) return false; - switch (resultType.second.opcode) { - default: { - spvCheck(!spvOpcodeIsBasicTypeNullable(resultType.second.opcode), - DIAG(resultTypeIndex) << "OpConstantNull Result Type '" - << inst->words[resultTypeIndex] - << "' cannot be null."; - return false); - } break; - case SpvOpTypeVector: { - auto type = usedefs_.FindDef(resultType.second.words[2]); - assert(type.first); - spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode), - DIAG(resultTypeIndex) - << "OpConstantNull Result Type '" - << inst->words[resultTypeIndex] - << "'s vector component type cannot be null."; - return false); - } break; - case SpvOpTypeArray: { - auto type = usedefs_.FindDef(resultType.second.words[2]); - assert(type.first); - spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode), - DIAG(resultTypeIndex) << "OpConstantNull Result Type '" - << inst->words[resultTypeIndex] - << "'s array element type cannot be null."; - return false); - } break; - case SpvOpTypeMatrix: { - auto columnType = usedefs_.FindDef(resultType.second.words[2]); - assert(columnType.first); - auto type = usedefs_.FindDef(columnType.second.words[2]); - assert(type.first); - spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode), - DIAG(resultTypeIndex) - << "OpConstantNull Result Type '" - << inst->words[resultTypeIndex] - << "'s matrix component type cna not be null."; - return false); - } break; - case SpvOpTypeStruct: { - for (size_t elementIndex = 2; - elementIndex < resultType.second.words.size(); ++elementIndex) { - auto element = usedefs_.FindDef(resultType.second.words[elementIndex]); - assert(element.first); - spvCheck(!spvOpcodeIsBasicTypeNullable(element.second.opcode), - DIAG(resultTypeIndex) - << "OpConstantNull Result Type '" - << inst->words[resultTypeIndex] - << "'s struct element type cannot be null."; - return false); - } - } break; + if (!resultType.first || !IsTypeNullable(resultType.second.words, usedefs_)) { + DIAG(resultTypeIndex) << "OpConstantNull Result Type '" + << inst->words[resultTypeIndex] + << "' cannot have a null value."; + return false; } return true; } diff --git a/test/ValidateID.cpp b/test/ValidateID.cpp index 90778cc..d2d6a26 100644 --- a/test/ValidateID.cpp +++ b/test/ValidateID.cpp @@ -612,15 +612,21 @@ TEST_F(ValidateID, OpConstantNullGood) { %22 = OpConstantNull %21 %23 = OpTypeStruct %3 %5 %1 %24 = OpConstantNull %23 +%26 = OpTypeArray %17 %25 +%27 = OpConstantNull %26 +%28 = OpTypeStruct %7 %26 %26 %1 +%29 = OpConstantNull %28 )"; CHECK(spirv, SPV_SUCCESS); } + TEST_F(ValidateID, OpConstantNullBasicBad) { const char* spirv = R"( %1 = OpTypeVoid %2 = OpConstantNull %1)"; CHECK(spirv, SPV_ERROR_INVALID_ID); } + TEST_F(ValidateID, OpConstantNullArrayBad) { const char* spirv = R"( %2 = OpTypeInt 32 0 @@ -630,6 +636,7 @@ TEST_F(ValidateID, OpConstantNullArrayBad) { %6 = OpConstantNull %5)"; CHECK(spirv, SPV_ERROR_INVALID_ID); } + TEST_F(ValidateID, OpConstantNullStructBad) { const char* spirv = R"( %2 = OpTypeSampler @@ -638,6 +645,14 @@ TEST_F(ValidateID, OpConstantNullStructBad) { CHECK(spirv, SPV_ERROR_INVALID_ID); } +TEST_F(ValidateID, OpConstantNullRuntimeArrayBad) { + const char* spirv = R"( +%bool = OpTypeBool +%array = OpTypeRuntimeArray %bool +%null = OpConstantNull %array)"; + CHECK(spirv, SPV_ERROR_INVALID_ID); +} + TEST_F(ValidateID, OpSpecConstantTrueGood) { const char* spirv = R"( %1 = OpTypeBool -- 2.7.4