freedreno/registers: Add c++ magic for register variants
authorRob Clark <robdclark@chromium.org>
Thu, 2 Mar 2023 19:18:05 +0000 (11:18 -0800)
committerMarge Bot <emma+marge@anholt.net>
Mon, 13 Mar 2023 17:31:24 +0000 (17:31 +0000)
For regs with multiple variants, generate a template'ized function to
pack the reg value.  If the template param is known at compile time
(which is the expected usage) this will optimize to the same thing as
the "traditional" reg packing.

Signed-off-by: Rob Clark <robdclark@chromium.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21846>

src/freedreno/registers/gen_header.py

index ef194ca..354f380 100644 (file)
@@ -128,7 +128,15 @@ class Bitset(object):
                else:
                        self.fields = []
 
-       def dump_regpair_builder(self, prefix, array, bit_size, address):
+       # Get address field if there is one in the bitset, else return None:
+       def get_address_field(self):
+               for f in self.fields:
+                       if f.type in [ "address", "waddress" ]:
+                               return f
+               return None
+
+       def dump_regpair_builder(self, reg):
+               prefix = reg.full_name
                print("#ifndef NDEBUG")
                known_mask = 0
                for f in self.fields:
@@ -141,8 +149,8 @@ class Bitset(object):
                print("#endif\n")
 
                print("    return (struct fd_reg_pair) {")
-               if array:
-                       print("        .reg = REG_%s(i)," % prefix)
+               if reg.array:
+                       print("        .reg = REG_%s(__i)," % prefix)
                else:
                        print("        .reg = REG_%s," % prefix)
 
@@ -154,10 +162,11 @@ class Bitset(object):
                                type, val = f.ctype("fields.%s" % field_name(prefix, f))
                                print("            (%-40s << %2d) |" % (val, f.low))
                value_name = "dword"
-               if bit_size == 64:
+               if reg.bit_size == 64:
                        value_name = "qword"
                print("            fields.unknown | fields.%s," % (value_name,))
 
+               address = self.get_address_field()
                if address:
                        print("        .bo = fields.bo,")
                        print("        .is_address = true,")
@@ -169,12 +178,11 @@ class Bitset(object):
 
                print("    };")
 
-       def dump_pack_struct(self, prefix=None, array=None, bit_size=32):
-
-               if not prefix:
+       def dump_pack_struct(self, reg=None):
+               if not reg:
                        return
-               if prefix == None:
-                       prefix = self.name
+
+               prefix = reg.full_name
 
                print("struct %s {" % prefix)
                for f in self.fields:
@@ -187,7 +195,7 @@ class Bitset(object):
                        type, val = f.ctype("var")
 
                        tab_to("    %s" % type, "%s;" % name)
-               if bit_size == 64:
+               if reg.bit_size == 64:
                        tab_to("    uint64_t", "unknown;")
                        tab_to("    uint64_t", "qword;")
                else:
@@ -195,28 +203,24 @@ class Bitset(object):
                        tab_to("    uint32_t", "dword;")
                print("};\n")
 
-               address = None;
-               for f in self.fields:
-                       if f.type in [ "address", "waddress" ]:
-                               address = f
-               if array:
-                       print("static inline struct fd_reg_pair\npack_%s(uint32_t i, struct %s fields)\n{" %
+               if reg.array:
+                       print("static inline struct fd_reg_pair\npack_%s(uint32_t __i, struct %s fields)\n{" %
                                  (prefix, prefix));
                else:
                        print("static inline struct fd_reg_pair\npack_%s(struct %s fields)\n{" %
                                  (prefix, prefix));
 
-               self.dump_regpair_builder(prefix, array, bit_size, address)
+               self.dump_regpair_builder(reg)
 
                print("\n}\n")
 
-               if address:
+               if self.get_address_field():
                        skip = ", { .reg = 0 }"
                else:
                        skip = ""
 
-               if array:
-                       print("#define %s(i, ...) pack_%s(i, (struct %s) { __VA_ARGS__ })%s\n" %
+               if reg.array:
+                       print("#define %s(__i, ...) pack_%s(__i, (struct %s) { __VA_ARGS__ })%s\n" %
                                  (prefix, prefix, prefix, skip))
                else:
                        print("#define %s(...) pack_%s((struct %s) { __VA_ARGS__ })%s\n" %
@@ -264,6 +268,9 @@ class Array(object):
        def dump_pack_struct(self):
                pass
 
+       def dump_regpair_builder(self):
+               pass
+
 class Reg(object):
        def __init__(self, attrs, domain, array, bit_size):
                self.name = attrs["name"]
@@ -272,11 +279,9 @@ class Reg(object):
                self.offset = int(attrs["offset"], 0)
                self.type = None
                self.bit_size = bit_size
-
-               if self.array:
-                       self.full_name = self.domain + "_" + self.array.name + "_" + self.name
-               else:
-                       self.full_name = self.domain + "_" + self.name
+               if array:
+                       self.name = array.name + "_" + self.name
+               self.full_name = self.domain + "_" + self.name
 
        def dump(self):
                if self.array:
@@ -291,8 +296,11 @@ class Reg(object):
 
        def dump_pack_struct(self):
                if self.bitset.inline:
-                       self.bitset.dump_pack_struct(self.full_name, not self.array == None, self.bit_size)
+                       self.bitset.dump_pack_struct(self)
 
+       def dump_regpair_builder(self):
+               if self.bitset.inline:
+                       self.bitset.dump_regpair_builder(self)
 
 class Parser(object):
        def __init__(self):
@@ -306,6 +314,9 @@ class Parser(object):
                # The varset attribute on the domain specifies the enum which
                # specifies all possible hw variants:
                self.current_varset = None
+               # Regs that have multiple variants.. we only generated the C++
+               # template based struct-packers for these
+               self.variant_regs = {}
                self.bitsets = {}
                self.enums = {}
                self.file = []
@@ -374,6 +385,21 @@ class Parser(object):
 
                return variant
 
+       def add_all_variants(self, reg, attrs):
+               # TODO this should really handle *all* variants, including dealing
+               # with open ended ranges (ie. "A2XX,A4XX-") (we have the varset
+               # enum now to make that possible)
+               variant = self.parse_variants(attrs)
+
+               if reg.name not in self.variant_regs:
+                       self.variant_regs[reg.name] = {}
+               else:
+                       # All variants must be same size:
+                       v = next(iter(self.variant_regs[reg.name]))
+                       assert self.variant_regs[reg.name][v].bit_size == reg.bit_size
+
+               self.variant_regs[reg.name][variant] = reg;
+
        def do_validate(self, schemafile):
                try:
                        from lxml import etree
@@ -443,6 +469,9 @@ class Parser(object):
                if len(self.stack) == 1:
                        self.file.append(self.current_reg)
 
+               if variant is not None:
+                       self.add_all_variants(self.current_reg, attrs)
+
        def start_element(self, name, attrs):
                if name == "import":
                        filename = attrs["file"]
@@ -526,10 +555,81 @@ class Parser(object):
                for e in enums + bitsets + regs:
                        e.dump()
 
+       def dump_reg_variants(self, regname, variants):
+               # Don't bother for things that only have a single variant:
+               if len(variants) == 1:
+                       return
+               print("#ifdef __cplusplus");
+               print("struct __%s {" % regname)
+               # TODO be more clever.. we should probably figure out which
+               # fields have the same type in all variants (in which they
+               # appear) and stuff everything else in a variant specific
+               # sub-structure.
+               seen_fields = []
+               bit_size = 32
+               array = False
+               address = None
+               for variant in variants.keys():
+                       print("    /* %s fields: */" % variant)
+                       reg = variants[variant]
+                       bit_size = reg.bit_size
+                       array = reg.array
+                       for f in reg.bitset.fields:
+                               fld_name = field_name(reg.full_name, f)
+                               if fld_name in seen_fields:
+                                       continue
+                               seen_fields.append(fld_name)
+                               name = fld_name.lower()
+                               if f.type in [ "address", "waddress" ]:
+                                       if address:
+                                               continue
+                                       address = f
+                                       tab_to("    __bo_type", "bo;")
+                                       tab_to("    uint32_t", "bo_offset;")
+                                       continue
+                               type, val = f.ctype("var")
+                               tab_to("    %s" %type, "%s;" %name)
+               print("    /* fallback fields: */")
+               if bit_size == 64:
+                       tab_to("    uint64_t", "unknown;")
+                       tab_to("    uint64_t", "qword;")
+               else:
+                       tab_to("    uint32_t", "unknown;")
+                       tab_to("    uint32_t", "dword;")
+               print("};")
+               # TODO don't hardcode the varset enum name
+               varenum = "chip"
+               print("template <%s %s>" % (varenum, varenum.upper()))
+               print("static inline struct fd_reg_pair")
+               xtra = ""
+               xtravar = ""
+               if array:
+                       xtra = "int __i, "
+                       xtravar = "__i, "
+               print("__%s(%sstruct __%s fields) {" % (regname, xtra, regname))
+               for variant in variants.keys():
+                       print("  if (%s == %s) {" % (varenum.upper(), variant))
+                       reg = variants[variant]
+                       reg.dump_regpair_builder()
+                       print("  } else")
+               print("    assert(!\"invalid variant\");")
+               print("}")
+
+               if bit_size == 64:
+                       skip = ", { .reg = 0 }"
+               else:
+                       skip = ""
+
+               print("#define %s(VARIANT, %s...) __%s<VARIANT>(%s{__VA_ARGS__})%s" % (regname, xtravar, regname, xtravar, skip))
+               print("#endif /* __cplusplus */")
+
        def dump_structs(self):
                for e in self.file:
                        e.dump_pack_struct()
 
+               for regname in self.variant_regs:
+                       self.dump_reg_variants(regname, self.variant_regs[regname])
+
 
 def main():
        p = Parser()