From 3e08a3f71896c1c15bdcbd42a9ba28ddd0075427 Mon Sep 17 00:00:00 2001 From: Andrey Tuganov Date: Thu, 23 Nov 2017 12:11:15 -0500 Subject: [PATCH] Add validation checks for Execution Model Currently checks that these instructions are called from entry points with Fragment execution model. OpImageImplicit* OpImageQueryLod OpKill --- source/val/function.cpp | 22 +++++++++ source/val/function.h | 34 ++++++++++++++ source/val/validation_state.cpp | 8 ++++ source/val/validation_state.h | 17 +++++-- source/validate_cfg.cpp | 5 +++ source/validate_id.cpp | 31 ++++++++++++- source/validate_image.cpp | 10 +++++ test/val/val_image_test.cpp | 99 +++++++++++++++++++++++++++++++---------- 8 files changed, 199 insertions(+), 27 deletions(-) diff --git a/source/val/function.cpp b/source/val/function.cpp index 91352b2..d7ac741 100644 --- a/source/val/function.cpp +++ b/source/val/function.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -348,4 +349,25 @@ int Function::GetBlockDepth(BasicBlock* bb) { return block_depth_[bb]; } +bool Function::IsCompatibleWithExecutionModel(SpvExecutionModel model, + std::string* reason) const { + bool is_compatible = true; + std::stringstream ss_reason; + + for (const auto& kv : execution_model_limitations_) { + if (kv.first != model) { + if (!reason) + return false; + is_compatible = false; + ss_reason << kv.second << "\n"; + } + } + + if (!is_compatible && reason) { + *reason = ss_reason.str(); + } + + return is_compatible; +} + } /// namespace libspirv diff --git a/source/val/function.h b/source/val/function.h index f7856b3..4d53b04 100644 --- a/source/val/function.h +++ b/source/val/function.h @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include #include @@ -200,6 +202,30 @@ class Function { /// Prints a directed graph of the CFG of the current funciton void PrintBlocks() const; + /// Registers execution model limitation such as "Feature X is only available + /// with Execution Model Y". Only the first message per model type is + /// registered. + void RegisterExecutionModelLimitation(SpvExecutionModel model, + const std::string& message) { + execution_model_limitations_.emplace(model, message); + } + + /// Returns true if the given execution model passes the limitations stored in + /// execution_model_limitations_. Returns false otherwise and fills optional + /// |reason| parameter. + bool IsCompatibleWithExecutionModel(SpvExecutionModel model, + std::string* reason = nullptr) const; + + // Inserts id to the set of functions called from this function. + void AddFunctionCallTarget(uint32_t call_target_id) { + function_call_targets_.insert(call_target_id); + } + + // Returns a set with ids of all functions called from this function. + const std::set function_call_targets() const { + return function_call_targets_; + } + private: // Computes the representation of the augmented CFG. // Populates augmented_successors_map_ and augmented_predecessors_map_. @@ -310,6 +336,14 @@ class Function { /// Stores the control flow nesting depth of a given basic block std::unordered_map block_depth_; + + /// Stores execution model limitations imposed by instructions used within the + /// function. The string contains message explaining why the limitation was + /// imposed. + std::map execution_model_limitations_; + + /// Stores ids of all functions called from this function. + std::set function_call_targets_; }; } /// namespace libspirv diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index cd4dafd..02b3646 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -264,6 +264,13 @@ const Function& ValidationState_t::current_function() const { return module_functions_.back(); } +const Function* ValidationState_t::function(uint32_t id) const { + const auto it = id_to_function_.find(id); + if (it == id_to_function_.end()) + return nullptr; + return it->second; +} + bool ValidationState_t::in_function_body() const { return in_function_; } bool ValidationState_t::in_block() const { @@ -352,6 +359,7 @@ spv_result_t ValidationState_t::RegisterFunction( in_function_ = true; module_functions_.emplace_back(id, ret_type_id, function_control, function_type_id); + id_to_function_.emplace(id, ¤t_function()); // TODO(umar): validate function type and type_id diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 8f261dc..d57eac7 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -140,6 +140,9 @@ class ValidationState_t { Function& current_function(); const Function& current_function() const; + /// Returns function state with the given id, or nullptr if no such function. + const Function* function(uint32_t id) const; + /// Returns true if the called after a function instruction but before the /// function end instruction bool in_function_body() const; @@ -173,6 +176,7 @@ class ValidationState_t { /// Inserts an to the set of functions that are target of OpFunctionCall. void AddFunctionCallTarget(const uint32_t id) { function_call_targets_.insert(id); + current_function().AddFunctionCallTarget(id); } /// Returns whether or not a function is the target of OpFunctionCall. @@ -433,7 +437,9 @@ class ValidationState_t { /// The section of the code being processed ModuleLayoutSection current_layout_section_; - /// A list of functions in the module + /// A list of functions in the module. + /// Pointers to objects in this container are guaranteed to be stable and + /// valid until the end of lifetime of the validation state. std::deque module_functions_; /// Capabilities declared in the module @@ -443,6 +449,8 @@ class ValidationState_t { libspirv::ExtensionSet module_extensions_; /// List of all instructions in the order they appear in the binary + /// Pointers to objects in this container are guaranteed to be stable and + /// valid until the end of lifetime of the validation state. std::deque ordered_instructions_; /// Instructions that can be referenced by Ids @@ -489,9 +497,12 @@ class ValidationState_t { /// NOTE: See correspoding getter functions bool in_function_; - // The state of optional features. These are determined by capabilities - // declared by the module. + /// The state of optional features. These are determined by capabilities + /// declared by the module. Feature features_; + + /// Maps function ids to function stat objects. + std::unordered_map id_to_function_; }; } /// namespace libspirv diff --git a/source/validate_cfg.cpp b/source/validate_cfg.cpp index 45a6c95..d601e12 100644 --- a/source/validate_cfg.cpp +++ b/source/validate_cfg.cpp @@ -416,6 +416,11 @@ spv_result_t CfgPass(ValidationState_t& _, case SpvOpReturnValue: case SpvOpUnreachable: _.current_function().RegisterBlockEnd(vector(), opcode); + if (opcode == SpvOpKill) { + _.current_function().RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpKill requires Fragment execution model"); + } break; default: break; diff --git a/source/validate_id.cpp b/source/validate_id.cpp index 5f0e63f..e936ea6 100644 --- a/source/validate_id.cpp +++ b/source/validate_id.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include "val/validation_state.h" using libspirv::Decoration; +using libspirv::Function; using libspirv::ValidationState_t; using std::function; using std::ignore; @@ -288,7 +290,7 @@ bool idUsage::isValid(const spv_instruction_t* inst, return false; } // don't check kernel function signatures - auto executionModel = inst->words[1]; + const SpvExecutionModel executionModel = SpvExecutionModel(inst->words[1]); if (executionModel != SpvExecutionModelKernel) { // TODO: Check the entry point signature is void main(void), may be subject // to change @@ -300,6 +302,33 @@ bool idUsage::isValid(const spv_instruction_t* inst, return false; } } + + std::stack call_stack; + std::set visited; + call_stack.push(entryPoint->id()); + while (!call_stack.empty()) { + const uint32_t called_func_id = call_stack.top(); + call_stack.pop(); + if (!visited.insert(called_func_id).second) continue; + + const Function* called_func = module_.function(called_func_id); + assert(called_func); + + std::string reason; + if (!called_func->IsCompatibleWithExecutionModel(executionModel, &reason)) { + DIAG(entryPointIndex) + << "OpEntryPoint Entry Point '" << inst->words[entryPointIndex] + << "'s callgraph contains function " << called_func_id + << ", which cannot be used with the current execution model:\n" + << reason; + return false; + } + + for (uint32_t new_call : called_func->function_call_targets()) { + call_stack.push(new_call); + } + } + auto returnType = module_.FindDef(entryPoint->type_id()); if (!returnType || SpvOpTypeVoid != returnType->opcode()) { DIAG(entryPointIndex) << "OpEntryPoint Entry Point '" diff --git a/source/validate_image.cpp b/source/validate_image.cpp index bc6f70a..48179aa 100644 --- a/source/validate_image.cpp +++ b/source/validate_image.cpp @@ -587,6 +587,12 @@ spv_result_t ImagePass(ValidationState_t& _, const SpvOp opcode = static_cast(inst->opcode); const uint32_t result_type = inst->type_id; + if (IsImplicitLod(opcode)) { + _.current_function().RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "ImplicitLod instructions require Fragment execution model"); + } + switch (opcode) { case SpvOpSampledImage: { if (_.GetIdOpcode(result_type) != SpvOpTypeSampledImage) { @@ -1276,6 +1282,10 @@ spv_result_t ImagePass(ValidationState_t& _, } case SpvOpImageQueryLod: { + _.current_function().RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpImageQueryLod requires Fragment execution model"); + if (!_.IsFloatVectorType(result_type)) { return _.diag(SPV_ERROR_INVALID_DATA) << "Expected Result Type to be float vector type: " diff --git a/test/val/val_image_test.cpp b/test/val/val_image_test.cpp index 38aae0f..b19843b 100644 --- a/test/val/val_image_test.cpp +++ b/test/val/val_image_test.cpp @@ -14,6 +14,7 @@ // Tests for unique type declaration rules validator. +#include #include #include "gmock/gmock.h" @@ -29,21 +30,24 @@ using ValidateImage = spvtest::ValidateBase; std::string GenerateShaderCode( const std::string& body, - const std::string& capabilities_and_extensions = "") { - const std::string capabilities = -R"( + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "Fragment") { + std::stringstream ss; + ss << R"( OpCapability Shader OpCapability InputAttachment OpCapability ImageGatherExtended OpCapability MinLod OpCapability Sampled1D OpCapability SampledRect -OpCapability ImageQuery)"; +OpCapability ImageQuery +)"; - const std::string after_extension_before_body = -R"( -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" + ss << capabilities_and_extensions; + ss << "OpMemoryModel Logical GLSL450\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + + ss << R"( %void = OpTypeVoid %func = OpTypeFunction %void %bool = OpTypeBool @@ -199,32 +203,34 @@ OpEntryPoint Fragment %main "main" %uniform_sampler = OpVariable %ptr_sampler UniformConstant %main = OpFunction %void None %func -%main_entry = OpLabel)"; +%main_entry = OpLabel +)"; - const std::string after_body = -R"( + ss << body; + + ss << R"( OpReturn OpFunctionEnd)"; - return capabilities + capabilities_and_extensions + - after_extension_before_body + body + after_body; + return ss.str(); } std::string GenerateKernelCode( const std::string& body, const std::string& capabilities_and_extensions = "") { - const std::string capabilities = -R"( + std::stringstream ss; + ss << R"( OpCapability Addresses OpCapability Kernel OpCapability Linkage OpCapability ImageQuery OpCapability ImageGatherExtended OpCapability InputAttachment -OpCapability SampledRect)"; +OpCapability SampledRect +)"; - const std::string after_extension_before_body = -R"( + ss << capabilities_and_extensions; + ss << R"( OpMemoryModel Physical32 OpenCL %void = OpTypeVoid %func = OpTypeFunction %void @@ -293,15 +299,15 @@ OpMemoryModel Physical32 OpenCL %uniform_sampler = OpVariable %ptr_sampler UniformConstant %main = OpFunction %void None %func -%main_entry = OpLabel)"; +%main_entry = OpLabel +)"; - const std::string after_body = -R"( + ss << body; + ss << R"( OpReturn OpFunctionEnd)"; - return capabilities + capabilities_and_extensions + - after_extension_before_body + body + after_body; + return ss.str(); } TEST_F(ValidateImage, SampledImageSuccess) { @@ -3037,4 +3043,51 @@ TEST_F(ValidateImage, QuerySamplesNotMultisampled) { "Image 'MS' must be 1: ImageQuerySamples")); } +TEST_F(ValidateImage, QueryLodWrongExecutionModel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "OpImageQueryLod requires Fragment execution model")); +} + +TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) { + const std::string body = R"( +%call_ret = OpFunctionCall %void %my_func +OpReturn +OpFunctionEnd +%my_func = OpFunction %void None %func +%my_func_entry = OpLabel +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "OpImageQueryLod requires Fragment execution model")); +} + +TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "ImplicitLod instructions require Fragment execution model")); +} + } // anonymous namespace -- 2.7.4