layers: Make passing a non-SPIRV shader to CreateShaderModule an error
authorChris Forbes <chrisf@ijw.co.nz>
Thu, 17 Sep 2015 23:40:23 +0000 (11:40 +1200)
committerChris Forbes <chrisf@ijw.co.nz>
Sun, 20 Sep 2015 00:38:45 +0000 (12:38 +1200)
Various real drivers are not tolerant of the wrapped-GLSL-blob hack that
was used early in Vulkan development. Rather than simply bypassing most
of the validation when one of these blobs is seen, make it an error.

Structurally, move the validation out of the shader_module ctor in
preparation for the callback being able to signal whether to bail out.

V2: Adjustment for descriptor validation landing

Signed-off-by: Chris Forbes <chrisf@ijw.co.nz>
Reviewed-by: Tobin Ehlis <tobin@lunarg.com>
layers/shader_checker.cpp

index b4d1a59c9a4b9f6d858c3c9e052fbea7e1073991..5b2076fbedc45c8cee5b5c0a739511387c7faf0b 100644 (file)
@@ -180,6 +180,17 @@ build_type_def_index(std::vector<unsigned> const &words, std::unordered_map<unsi
     }
 }
 
+
+bool
+shader_is_spirv(VkShaderModuleCreateInfo const *pCreateInfo)
+{
+    uint32_t *words = (uint32_t *)pCreateInfo->pCode;
+    uint32_t sizeInWords = pCreateInfo->codeSize / sizeof(uint32_t);
+
+    /* Just validate that the header makes sense. */
+    return sizeInWords >= 5 && words[0] == spv::MagicNumber && words[1] == spv::Version;
+}
+
 struct shader_module {
     /* the spirv image itself */
     std::vector<uint32_t> words;
@@ -187,20 +198,10 @@ struct shader_module {
      * trees requires jumping all over the instruction stream.
      */
     std::unordered_map<unsigned, unsigned> type_def_index;
-    bool is_spirv;
 
-    shader_module(VkDevice dev, VkShaderModuleCreateInfo const *pCreateInfo) :
+    shader_module(VkShaderModuleCreateInfo const *pCreateInfo) :
         words((uint32_t *)pCreateInfo->pCode, (uint32_t *)pCreateInfo->pCode + pCreateInfo->codeSize / sizeof(uint32_t)),
-        type_def_index(),
-        is_spirv(true) {
-
-        if (words.size() < 5 || words[0] != spv::MagicNumber || words[1] != spv::Version) {
-            log_msg(mdd(dev), VK_DBG_REPORT_WARN_BIT, VK_OBJECT_TYPE_DEVICE, /* dev */ 0, 0, SHADER_CHECKER_NON_SPIRV_SHADER, "SC",
-                    "Shader is not SPIR-V, most checks will not be possible");
-            is_spirv = false;
-            return;
-        }
-
+        type_def_index() {
 
         build_type_def_index(words, type_def_index);
     }
@@ -627,11 +628,19 @@ VK_LAYER_EXPORT VkResult VKAPI vkCreateShaderModule(
         const VkShaderModuleCreateInfo *pCreateInfo,
         VkShaderModule *pShaderModule)
 {
+    /* Protect the driver from non-SPIRV shaders */
+    if (!shader_is_spirv(pCreateInfo)) {
+        log_msg(mdd(device), VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_DEVICE,
+                /* dev */ 0, 0, SHADER_CHECKER_NON_SPIRV_SHADER, "SC",
+                "Shader is not SPIR-V");
+        return VK_ERROR_VALIDATION_FAILED;
+    }
+
     VkResult res = get_dispatch_table(shader_checker_device_table_map, device)->CreateShaderModule(device, pCreateInfo, pShaderModule);
 
     if (res == VK_SUCCESS) {
         loader_platform_thread_lock_mutex(&globalLock);
-        shader_module_map[pShaderModule->handle] = new shader_module(device, pCreateInfo);
+        shader_module_map[pShaderModule->handle] = new shader_module(pCreateInfo);
         loader_platform_thread_unlock_mutex(&globalLock);
     }
     return res;
@@ -1037,29 +1046,26 @@ validate_graphics_pipeline(VkDevice dev, VkGraphicsPipelineCreateInfo const *pCr
                 shaders[pStage->stage] = shader->module;
 
                 /* validate descriptor set layout against what the spirv module actually uses */
-                if (shader->module->is_spirv) {
-                    std::map<std::pair<unsigned, unsigned>, interface_var> descriptor_uses;
-                    collect_interface_by_descriptor_slot(dev, shader->module, spv::StorageClassUniform,
-                            descriptor_uses);
-
-                    auto layout = pCreateInfo->layout.handle ?
-                        pipeline_layout_map[pCreateInfo->layout.handle] : nullptr;
-
-                    for (auto it = descriptor_uses.begin(); it != descriptor_uses.end(); it++) {
-
-                        /* find the matching binding */
-                        auto binding = find_descriptor_binding(layout, it->first);
-
-                        if (binding == nullptr) {
-                            char type_name[1024];
-                            describe_type(type_name, shader->module, it->second.type_id);
-                            log_msg(mdd(dev), VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_DEVICE, /*dev*/0, 0,
-                                    SHADER_CHECKER_MISSING_DESCRIPTOR, "SC",
-                                    "Shader uses descriptor slot %u.%u (used as type `%s`) but not declared in pipeline layout",
-                                    it->first.first, it->first.second, type_name);
-                            pass = false;
-                        }
-
+                std::map<std::pair<unsigned, unsigned>, interface_var> descriptor_uses;
+                collect_interface_by_descriptor_slot(dev, shader->module, spv::StorageClassUniform,
+                        descriptor_uses);
+
+                auto layout = pCreateInfo->layout.handle ?
+                    pipeline_layout_map[pCreateInfo->layout.handle] : nullptr;
+
+                for (auto it = descriptor_uses.begin(); it != descriptor_uses.end(); it++) {
+
+                    /* find the matching binding */
+                    auto binding = find_descriptor_binding(layout, it->first);
+
+                    if (binding == nullptr) {
+                        char type_name[1024];
+                        describe_type(type_name, shader->module, it->second.type_id);
+                        log_msg(mdd(dev), VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_DEVICE, /*dev*/0, 0,
+                                SHADER_CHECKER_MISSING_DESCRIPTOR, "SC",
+                                "Shader uses descriptor slot %u.%u (used as type `%s`) but not declared in pipeline layout",
+                                it->first.first, it->first.second, type_name);
+                        pass = false;
                     }
                 }
             }
@@ -1075,7 +1081,7 @@ validate_graphics_pipeline(VkDevice dev, VkGraphicsPipelineCreateInfo const *pCr
         pass = validate_vi_consistency(dev, vi) && pass;
     }
 
-    if (shaders[VK_SHADER_STAGE_VERTEX] && shaders[VK_SHADER_STAGE_VERTEX]->is_spirv) {
+    if (shaders[VK_SHADER_STAGE_VERTEX]) {
         pass = validate_vi_against_vs_inputs(dev, vi, shaders[VK_SHADER_STAGE_VERTEX]) && pass;
     }
 
@@ -1091,18 +1097,16 @@ validate_graphics_pipeline(VkDevice dev, VkGraphicsPipelineCreateInfo const *pCr
     for (; producer != VK_SHADER_STAGE_FRAGMENT && consumer <= VK_SHADER_STAGE_FRAGMENT; consumer++) {
         assert(shaders[producer]);
         if (shaders[consumer]) {
-            if (shaders[producer]->is_spirv && shaders[consumer]->is_spirv) {
-                pass = validate_interface_between_stages(dev,
-                                                         shaders[producer], shader_stage_attribs[producer].name,
-                                                         shaders[consumer], shader_stage_attribs[consumer].name,
-                                                         shader_stage_attribs[consumer].arrayed_input) && pass;
-            }
+            pass = validate_interface_between_stages(dev,
+                                                     shaders[producer], shader_stage_attribs[producer].name,
+                                                     shaders[consumer], shader_stage_attribs[consumer].name,
+                                                     shader_stage_attribs[consumer].arrayed_input) && pass;
 
             producer = consumer;
         }
     }
 
-    if (shaders[VK_SHADER_STAGE_FRAGMENT] && shaders[VK_SHADER_STAGE_FRAGMENT]->is_spirv && rp) {
+    if (shaders[VK_SHADER_STAGE_FRAGMENT] && rp) {
         pass = validate_fs_outputs_against_render_pass(dev, shaders[VK_SHADER_STAGE_FRAGMENT], rp, pCreateInfo->subpass) && pass;
     }