shader_checker: add validation of interface between vs and fs
authorChris Forbes <chrisf@ijw.co.nz>
Tue, 7 Apr 2015 22:19:16 +0000 (10:19 +1200)
committerChris Forbes <chrisf@ijw.co.nz>
Thu, 16 Apr 2015 22:13:59 +0000 (10:13 +1200)
layers/shader_checker.cpp

index 6d88b23..ad375df 100644 (file)
@@ -26,6 +26,7 @@
 #include <assert.h>
 #include <map>
 #include <unordered_map>
+#include <map>
 #include <vector>
 #include "loader_platform.h"
 #include "vk_dispatch_table_helper.h"
@@ -250,6 +251,50 @@ VK_LAYER_EXPORT VkResult VKAPI vkCreateShader(VkDevice device, const VkShaderCre
 }
 
 
+static void
+validate_interface_between_stages(shader_source const *producer, char const *producer_name,
+                                  shader_source const *consumer, char const *consumer_name)
+{
+    std::map<uint32_t, interface_var> outputs;
+    std::map<uint32_t, interface_var> inputs;
+
+    std::map<uint32_t, interface_var> builtin_outputs;
+    std::map<uint32_t, interface_var> builtin_inputs;
+
+    printf("Begin validate_interface_between_stages %s -> %s\n",
+           producer_name, consumer_name);
+
+    collect_interface_by_location(producer, spv::StorageOutput, outputs, builtin_outputs);
+    collect_interface_by_location(consumer, spv::StorageInput, inputs, builtin_inputs);
+
+    auto a_it = outputs.begin();
+    auto b_it = inputs.begin();
+
+    /* maps sorted by key (location); walk them together to find mismatches */
+    while (a_it != outputs.end() || b_it != inputs.end()) {
+        if (b_it == inputs.end() || a_it->first < b_it->first) {
+            printf("  WARN: %s writes to output location %d which is not consumed by %s\n",
+                   producer_name, a_it->first, consumer_name);
+            a_it++;
+        }
+        else if (a_it == outputs.end() || a_it->first > b_it->first) {
+            printf("  ERR: %s consumes input location %d which is not written by %s\n",
+                   consumer_name, b_it->first, producer_name);
+            b_it++;
+        }
+        else {
+            printf("  OK: match on location %d\n",
+                   a_it->first);
+            /* TODO: typecheck */
+            a_it++;
+            b_it++;
+        }
+    }
+
+    printf("End validate_interface_between_stages\n");
+}
+
+
 VK_LAYER_EXPORT VkResult VKAPI vkCreateGraphicsPipeline(VkDevice device,
                                                              const VkGraphicsPipelineCreateInfo *pCreateInfo,
                                                              VkPipeline *pPipeline)
@@ -288,6 +333,11 @@ VK_LAYER_EXPORT VkResult VKAPI vkCreateGraphicsPipeline(VkDevice device,
 
     printf("Pipeline: vi=%p vs=%p fs=%p cb=%p\n", vi, vs_source, fs_source, cb);
 
+    if (vs_source && fs_source) {
+        validate_interface_between_stages(vs_source, "vertex shader",
+                                          fs_source, "fragment shader");
+    }
+
     VkLayerDispatchTable *pTable = tableMap[(VkBaseLayerObject *)device];
     VkResult res = pTable->CreateGraphicsPipeline(device, pCreateInfo, pPipeline);
     return res;