shader_checker: first pass at typechecks
authorChris Forbes <chrisf@ijw.co.nz>
Thu, 9 Apr 2015 23:41:20 +0000 (11:41 +1200)
committerChris Forbes <chrisf@ijw.co.nz>
Thu, 16 Apr 2015 22:13:59 +0000 (10:13 +1200)
layers/shader_checker.cpp

index a15a0ee..8b49660 100644 (file)
 static std::unordered_map<void *, VkLayerDispatchTable *> tableMap;
 
 
+
+static void
+build_type_def_index(std::vector<unsigned> const &words, std::unordered_map<unsigned, unsigned> &type_def_index)
+{
+    unsigned int const *code = (unsigned int const *)&words[0];
+    size_t size = words.size();
+
+    unsigned word = 5;
+    while (word < size) {
+        unsigned opcode = code[word] & 0x0ffffu;
+        unsigned oplen = (code[word] & 0xffff0000u) >> 16;
+
+        switch (opcode) {
+        case spv::OpTypeVoid:
+        case spv::OpTypeBool:
+        case spv::OpTypeInt:
+        case spv::OpTypeFloat:
+        case spv::OpTypeVector:
+        case spv::OpTypeMatrix:
+        case spv::OpTypeSampler:
+        case spv::OpTypeFilter:
+        case spv::OpTypeArray:
+        case spv::OpTypeRuntimeArray:
+        case spv::OpTypeStruct:
+        case spv::OpTypeOpaque:
+        case spv::OpTypePointer:
+        case spv::OpTypeFunction:
+        case spv::OpTypeEvent:
+        case spv::OpTypeDeviceEvent:
+        case spv::OpTypeReserveId:
+        case spv::OpTypeQueue:
+        case spv::OpTypePipe:
+            type_def_index[code[word+1]] = word;
+            break;
+
+        default:
+            /* We only care about type definitions */
+            break;
+        }
+
+        word += oplen;
+    }
+}
+
 struct shader_source {
+    /* the spirv image itself */
     std::vector<uint32_t> words;
+    /* a mapping of <id> to the first word of its def. this is useful because walking type
+     * trees requires jumping all over the instruction stream.
+     */
+    std::unordered_map<unsigned, unsigned> type_def_index;
 
     shader_source(VkShaderCreateInfo const *pCreateInfo) :
         words((uint32_t *)pCreateInfo->pCode, (uint32_t *)pCreateInfo->pCode + pCreateInfo->codeSize / sizeof(uint32_t)) {
+
+        build_type_def_index(words, type_def_index);
     }
 };
 
@@ -150,6 +201,149 @@ VK_LAYER_EXPORT VkResult VKAPI vkGetGlobalExtensionInfo(
 }
 
 
+static char const *
+storage_class_name(unsigned sc)
+{
+    switch (sc) {
+    case spv::StorageInput: return "input";
+    case spv::StorageOutput: return "output";
+    case spv::StorageConstantUniform: return "const uniform";
+    case spv::StorageUniform: return "uniform";
+    case spv::StorageWorkgroupLocal: return "workgroup local";
+    case spv::StorageWorkgroupGlobal: return "workgroup global";
+    case spv::StoragePrivateGlobal: return "private global";
+    case spv::StorageFunction: return "function";
+    case spv::StorageGeneric: return "generic";
+    case spv::StoragePrivate: return "private";
+    case spv::StorageAtomicCounter: return "atomic counter";
+    default: return "unknown";
+    }
+}
+
+
+/* returns ptr to null terminator */
+static char *
+describe_type(char *dst, shader_source const *src, unsigned type)
+{
+    auto type_def_it = src->type_def_index.find(type);
+
+    if (type_def_it == src->type_def_index.end()) {
+        return dst + sprintf(dst, "undef");
+    }
+
+    unsigned int const *code = (unsigned int const *)&src->words[type_def_it->second];
+    unsigned opcode = code[0] & 0x0ffffu;
+    switch (opcode) {
+        case spv::OpTypeBool:
+            return dst + sprintf(dst, "bool");
+        case spv::OpTypeInt:
+            return dst + sprintf(dst, "%cint%d", code[3] ? 's' : 'u', code[2]);
+        case spv::OpTypeFloat:
+            return dst + sprintf(dst, "float%d", code[2]);
+        case spv::OpTypeVector:
+            dst += sprintf(dst, "vec%d of ", code[3]);
+            return describe_type(dst, src, code[2]);
+        case spv::OpTypeMatrix:
+            dst += sprintf(dst, "mat%d of ", code[3]);
+            return describe_type(dst, src, code[2]);
+        case spv::OpTypeArray:
+            dst += sprintf(dst, "arr[%d] of ", code[3]);
+            return describe_type(dst, src, code[2]);
+        case spv::OpTypePointer:
+            dst += sprintf(dst, "ptr to %s ", storage_class_name(code[2]));
+            return describe_type(dst, src, code[3]);
+        case spv::OpTypeStruct:
+            {
+                unsigned oplen = code[0] >> 16;
+                dst += sprintf(dst, "struct of (");
+                for (int i = 2; i < oplen; i++) {
+                    dst = describe_type(dst, src, code[i]);
+                    dst += sprintf(dst, i == oplen-1 ? ")" : ", ");
+                }
+                return dst;
+            }
+        default:
+            return dst + sprintf(dst, "oddtype");
+    }
+}
+
+
+static bool
+types_match(shader_source const *a, shader_source const *b, unsigned a_type, unsigned b_type)
+{
+    auto a_type_def_it = a->type_def_index.find(a_type);
+    auto b_type_def_it = b->type_def_index.find(b_type);
+
+    if (a_type_def_it == a->type_def_index.end()) {
+        printf("ERR: can't find def for type %d in producing shader %p; SPIRV probably invalid.\n",
+                a_type, a);
+        return false;
+    }
+
+    if (b_type_def_it == b->type_def_index.end()) {
+        printf("ERR: can't find def for type %d in consuming shader %p; SPIRV probably invalid.\n",
+                b_type, b);
+        return false;
+    }
+
+    /* walk two type trees together, and complain about differences */
+    unsigned int const *a_code = (unsigned int const *)&a->words[a_type_def_it->second];
+    unsigned int const *b_code = (unsigned int const *)&b->words[b_type_def_it->second];
+
+    unsigned a_opcode = a_code[0] & 0x0ffffu;
+    unsigned b_opcode = b_code[0] & 0x0ffffu;
+
+    if (a_opcode != b_opcode) {
+        printf("  - FAIL: type def opcodes differ: %d vs %d\n", a_opcode, b_opcode);
+        return false;
+    }
+
+    switch (a_opcode) {
+        case spv::OpTypeBool:
+            return true;
+        case spv::OpTypeInt:
+            /* match on width, signedness */
+            return a_code[2] == b_code[2] && a_code[3] == b_code[3];
+        case spv::OpTypeFloat:
+            /* match on width */
+            return a_code[2] == b_code[2];
+        case spv::OpTypeVector:
+        case spv::OpTypeMatrix:
+        case spv::OpTypeArray:
+            /* match on element type, count. these all have the same layout */
+            return types_match(a, b, a_code[2], b_code[2]) && a_code[3] == b_code[3];
+        case spv::OpTypeStruct:
+            /* match on all element types */
+            {
+                unsigned a_len = a_code[0] >> 16;
+                unsigned b_len = b_code[0] >> 16;
+
+                if (a_len != b_len) {
+                    return false;   /* structs cannot match if member counts differ */
+                }
+
+                for (int i = 2; i < a_len; i++) {
+                    if (!types_match(a, b, a_code[i], b_code[i])) {
+                        return false;
+                    }
+                }
+
+                return true;
+            }
+        case spv::OpTypePointer:
+            /* match on pointee type. storage class is expected to differ */
+            return types_match(a, b, a_code[3], b_code[3]);
+
+        default:
+            /* remaining types are CLisms, or may not appear in the interfaces we
+             * are interested in. Just claim no match.
+             */
+            return false;
+
+    }
+}
+
+
 static int
 value_or_default(std::unordered_map<unsigned, unsigned> const &map, unsigned id, int def)
 {
@@ -283,9 +477,18 @@ validate_interface_between_stages(shader_source const *producer, char const *pro
             b_it++;
         }
         else {
-            printf("  OK: match on location %d\n",
-                   a_it->first);
-            /* TODO: typecheck */
+            if (types_match(producer, consumer, a_it->second.type_id, b_it->second.type_id)) {
+                printf("  OK: match on location %d\n", a_it->first);
+            }
+            else {
+                char producer_type[1024];
+                char consumer_type[1024];
+                describe_type(producer_type, producer, a_it->second.type_id);
+                describe_type(consumer_type, consumer, b_it->second.type_id);
+
+                printf("  ERR: type mismatch on location %d: '%s' vs '%s'\n", a_it->first,
+                       producer_type, consumer_type);
+            }
             a_it++;
             b_it++;
         }