vulkan: Properly filter by api in enum_to_str
authorFaith Ekstrand <faith.ekstrand@collabora.com>
Thu, 9 Feb 2023 00:37:04 +0000 (18:37 -0600)
committerMarge Bot <emma+marge@anholt.net>
Fri, 17 Feb 2023 03:42:34 +0000 (03:42 +0000)
This switches us to using get_all_required() for figuring out which
enum types we care about and then carefully filtering every value as
needed.  We also add a number field to Extension so we keep all the
extension XML parsing in one place.

Acked-By: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21225>

src/vulkan/util/gen_enum_to_str.py
src/vulkan/util/vk_extensions.py

index 4bdb842..a8e2197 100644 (file)
@@ -28,6 +28,7 @@ import textwrap
 import xml.etree.ElementTree as et
 
 from mako.template import Template
+from vk_extensions import Extension, filter_api, get_all_required
 
 COPYRIGHT = textwrap.dedent(u"""\
     * Copyright © 2017 Intel Corporation
@@ -370,30 +371,59 @@ def parse_xml(enum_factory, ext_factory, struct_factory, bitmask_factory,
     """
 
     xml = et.parse(filename)
+    api = 'vulkan'
+
+    required_types = get_all_required(xml, 'type', api)
 
     for enum_type in xml.findall('./enums[@type="enum"]'):
-        enum = enum_factory(enum_type.attrib['name'])
+        if not filter_api(enum_type, api):
+            continue
+
+        type_name = enum_type.attrib['name']
+        if not type_name in required_types:
+            continue
+
+        enum = enum_factory(type_name)
         for value in enum_type.findall('./enum'):
-            enum.add_value_from_xml(value)
+            if filter_api(value, api):
+                enum.add_value_from_xml(value)
 
     # For bitmask we only add the Enum selected for convenience.
     for enum_type in xml.findall('./enums[@type="bitmask"]'):
+        if not filter_api(enum_type, api):
+            continue
+
+        type_name = enum_type.attrib['name']
+        if not type_name in required_types:
+            continue
+
         bitwidth = int(enum_type.attrib.get('bitwidth', 32))
-        enum = bitmask_factory(enum_type.attrib['name'], bitwidth=bitwidth)
+        enum = bitmask_factory(type_name, bitwidth=bitwidth)
         for value in enum_type.findall('./enum'):
-            enum.add_value_from_xml(value)
+            if filter_api(value, api):
+                enum.add_value_from_xml(value)
 
-    for value in xml.findall('./feature/require/enum[@extends]'):
-        extends = value.attrib['extends']
-        enum = enum_factory.get(extends)
-        if enum is not None:
-            enum.add_value_from_xml(value)
-        enum = bitmask_factory.get(extends)
-        if enum is not None:
-            enum.add_value_from_xml(value)
+    for feature in xml.findall('./feature'):
+        if not api in feature.attrib['api'].split(','):
+            continue
+
+        for value in feature.findall('./require/enum[@extends]'):
+            extends = value.attrib['extends']
+            enum = enum_factory.get(extends)
+            if enum is not None:
+                enum.add_value_from_xml(value)
+            enum = bitmask_factory.get(extends)
+            if enum is not None:
+                enum.add_value_from_xml(value)
 
     for struct_type in xml.findall('./types/type[@category="struct"]'):
+        if not filter_api(struct_type, api):
+            continue
+
         name = struct_type.attrib['name']
+        if name not in required_types:
+            continue
+
         stype = struct_get_stype(struct_type)
         if stype is not None:
             struct_factory(name, stype=stype)
@@ -404,26 +434,31 @@ def parse_xml(enum_factory, ext_factory, struct_factory, bitmask_factory,
         define = platform.attrib['protect']
         platform_define[name] = define
 
-    for ext_elem in xml.findall('./extensions/extension[@supported="vulkan"]'):
-        define = None
-        if "platform" in ext_elem.attrib:
-            define = platform_define[ext_elem.attrib['platform']]
-        extension = ext_factory(ext_elem.attrib['name'],
-                                number=int(ext_elem.attrib['number']),
-                                define=define)
+    for ext_elem in xml.findall('./extensions/extension'):
+        ext = Extension.from_xml(ext_elem)
+        if api not in ext.supported:
+            continue
 
-        for value in ext_elem.findall('./require/enum[@extends]'):
-            extends = value.attrib['extends']
-            enum = enum_factory.get(extends)
-            if enum is not None:
-                enum.add_value_from_xml(value, extension)
-            enum = bitmask_factory.get(extends)
-            if enum is not None:
-                enum.add_value_from_xml(value, extension)
-        for t in ext_elem.findall('./require/type'):
-            struct = struct_factory.get(t.attrib['name'])
-            if struct is not None:
-                struct.extension = extension
+        define = platform_define.get(ext.platform, None)
+        extension = ext_factory(ext.name, number=ext.number, define=define)
+
+        for req_elem in ext_elem.findall('./require'):
+            if not filter_api(req_elem, api):
+                continue
+
+            for value in req_elem.findall('./enum[@extends]'):
+                extends = value.attrib['extends']
+                enum = enum_factory.get(extends)
+                if enum is not None:
+                    enum.add_value_from_xml(value, extension)
+                enum = bitmask_factory.get(extends)
+                if enum is not None:
+                    enum.add_value_from_xml(value, extension)
+
+            for t in req_elem.findall('./type'):
+                struct = struct_factory.get(t.attrib['name'])
+                if struct is not None:
+                    struct.extension = extension
 
         if define:
             for value in ext_elem.findall('./require/type[@name]'):
@@ -431,12 +466,20 @@ def parse_xml(enum_factory, ext_factory, struct_factory, bitmask_factory,
                 if enum is not None:
                     enum.set_guard(define)
 
+    obj_type_enum = enum_factory.get("VkObjectType")
     obj_types = obj_type_factory("VkObjectType")
     for object_type in xml.findall('./types/type[@category="handle"]'):
         for object_name in object_type.findall('./name'):
             # Convert to int to avoid undefined enums
             enum = object_type.attrib['objtypeenum']
-            enum_val = enum_factory.get("VkObjectType").name_to_value[enum]
+
+            # Annoyingly, object types are hard to filter by API so just
+            # look for whether or not we can find the enum name in the
+            # VkObjectType enum.
+            if enum not in obj_type_enum.name_to_value:
+                continue
+
+            enum_val = obj_type_enum.name_to_value[enum]
             obj_types.enum_to_name[enum_val] = object_name.text
 
 
index 07e0763..262c6f8 100644 (file)
@@ -12,22 +12,24 @@ def get_api_list(s):
     return apis
 
 class Extension:
-    def __init__(self, name, ext_version):
+    def __init__(self, name, number, ext_version):
         self.name = name
         self.type = None
+        self.number = number
         self.platform = None
         self.ext_version = int(ext_version)
         self.supported = []
 
     def from_xml(ext_elem):
         name = ext_elem.attrib['name']
+        number = int(ext_elem.attrib['number'])
         supported = get_api_list(ext_elem.attrib['supported'])
         if name == 'VK_ANDROID_native_buffer':
             assert not supported
             supported = ['vulkan']
 
         if not supported:
-            return Extension(name, 0)
+            return Extension(name, number, 0)
 
         version = None
         for enum_elem in ext_elem.findall('.require/enum'):
@@ -38,7 +40,7 @@ class Extension:
                     version = int(enum_elem.attrib['value'])
 
         assert version is not None
-        ext = Extension(name, version)
+        ext = Extension(name, number, version)
         ext.type = ext_elem.attrib['type']
         ext.platform = ext_elem.attrib.get('platform', None)
         ext.supported = supported