pan/bi: Extend the bi_builder to support type variants correctly
authorBoris Brezillon <boris.brezillon@collabora.com>
Tue, 2 Mar 2021 12:08:05 +0000 (13:08 +0100)
committerMarge Bot <eric+marge@anholt.net>
Thu, 11 Mar 2021 14:30:19 +0000 (14:30 +0000)
Some opcodes come with both type and size variants. Right now, only the
size is taken into account. Extend the builder to provide wrappers that
take a nir_type in addition to the bitsize.

While at it, fix wrappers taking a compare operator to use the proper
.{i,s,u} variant based on the comparison (equal and non-equal should
use .i, other comparisons should use .{u,s}).

Signed-off-by: Boris Brezillon <boris.brezillon@collabora.com>
Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/9520>

src/panfrost/bifrost/bi_builder.h.py
src/panfrost/bifrost/bifrost_compile.c

index 4791a1a..903ef4e 100644 (file)
@@ -28,6 +28,31 @@ TEMPLATE = """
 #include "compiler.h"
 
 <%
+def nirtypes(opcode):
+    split = opcode.split('.', 1)
+    if len(split) < 2:
+        split = opcode.split('_')
+
+    if len(split) <= 1:
+        return None
+
+    assert len(split) > 1
+
+    type = split[1]
+    if type[0] == 'v':
+        type = type[2:]
+
+    if type[0] == 'f':
+        return ['nir_type_float']
+    elif type[0] == 's':
+        return ['nir_type_int']
+    elif type[0] == 'u':
+        return ['nir_type_uint']
+    elif type[0] == 'i':
+        return ['nir_type_uint', 'nir_type_int']
+    else:
+        return None
+
 def typesize(opcode):
     if opcode[-3:] == '128':
         return 128
@@ -41,6 +66,29 @@ def typesize(opcode):
         except:
             return None
 
+def condition(opcode, typecheck, sizecheck):
+    cond = ''
+    if typecheck == True:
+        cond += '('
+        types = nirtypes(opcode)
+        assert types != None
+        for T in types:
+            cond += "{}type == {}".format(' || ' if cond[-1] != '(' else '', T)
+        cond += ')'
+
+    if sizecheck == True:
+        cond += "{}bitsize == {}".format(' && ' if cond != '' else '', typesize(opcode))
+
+    cmpf_mods = ops[opcode]["modifiers"]["cmpf"] if "cmpf" in ops[opcode]["modifiers"] else None
+    if "cmpf" in ops[opcode]["modifiers"]:
+        cond += "{}(".format(' && ' if cond != '' else '')
+        for cmpf in ops[opcode]["modifiers"]["cmpf"]:
+            if cmpf != 'reserved':
+                cond += "{}cmpf == BI_CMPF_{}".format(' || ' if cond[-1] != '(' else '', cmpf.upper())
+        cond += ')'
+
+    return 'true' if cond == '' else cond
+
 def to_suffix(op):
     return "_to" if op["dests"] > 0 else ""
 
@@ -81,23 +129,34 @@ bi_index bi_${opcode.replace('.', '_').lower()}(${signature(ops[opcode], modifie
 <%
     common_op = opcode.split('.')[0]
     variants = [a for a in ops.keys() if a.split('.')[0] == common_op]
-    signatures = [signature(ops[op], modifiers, sized=True, no_dests=True) for op in variants]
+    signatures = [signature(ops[op], modifiers, no_dests=True) for op in variants]
     homogenous = all([sig == signatures[0] for sig in signatures])
+    types = [nirtypes(x) for x in variants]
+    typeful = False
+    for t in types:
+        if t != types[0]:
+            typeful = True
+
     sizes = [typesize(x) for x in variants]
+    sized = False
+    for size in sizes:
+        if size != sizes[0]:
+            sized = True
+
     last = opcode == variants[-1]
 %>
 % if homogenous and len(variants) > 1 and last:
 % for (suffix, temp, dests, ret) in (('_to', False, 1, 'instr *'), ('', True, 0, 'index')):
 % if not temp or ops[opcode]["dests"] > 0:
 static inline
-bi_${ret} bi_${common_op.replace('.', '_').lower()}${suffix if ops[opcode]['dests'] > 0 else ''}(${signature(ops[opcode], modifiers, sized=True, no_dests=not dests)})
+bi_${ret} bi_${common_op.replace('.', '_').lower()}${suffix if ops[opcode]['dests'] > 0 else ''}(${signature(ops[opcode], modifiers, typeful=typeful, sized=sized, no_dests=not dests)})
 {
-% for i, (variant, size) in enumerate(zip(variants, sizes)):
-    ${"else " if i > 0 else ""} if (bitsize == ${size})
+% for i, variant in enumerate(variants):
+    ${"{}if ({})".format("else " if i > 0 else "", condition(variant, typeful, sized))}
         return (bi_${variant.replace('.', '_').lower()}${to_suffix(ops[opcode])}(${arguments(ops[opcode], temp_dest = temp)}))${"->dest[0]" if temp else ""};
 % endfor
     else
-        unreachable("Invalid bitsize for ${common_op}");
+        unreachable("Invalid parameters for ${common_op}");
 }
 
 %endif
@@ -122,10 +181,11 @@ def should_skip(mod):
 def modifier_signature(op):
     return sorted([m for m in op["modifiers"].keys() if not should_skip(m)])
 
-def signature(op, modifiers, sized = False, no_dests = False):
+def signature(op, modifiers, typeful = False, sized = False, no_dests = False):
     return ", ".join(
         ["bi_builder *b"] +
-        (["unsigned bitsize"] if sized else []) +
+        (["nir_alu_type type"] if typeful == True else []) +
+        (["unsigned bitsize"] if sized == True else []) +
         ["bi_index dest{}".format(i) for i in range(0 if no_dests else op["dests"])] +
         ["bi_index src{}".format(i) for i in range(src_count(op))] +
         ["{} {}".format(
index 2201522..727dd30 100644 (file)
@@ -1670,7 +1670,7 @@ bi_emit_alu(bi_builder *b, nir_alu_instr *instr)
                 if (sz == 8)
                         bi_mux_v4i8_to(b, dst, s2, s1, s0, BI_MUX_INT_ZERO);
                 else
-                        bi_csel_to(b, sz, dst, s0, bi_zero(), s1, s2, BI_CMPF_NE);
+                        bi_csel_to(b, nir_type_float, sz, dst, s0, bi_zero(), s1, s2, BI_CMPF_NE);
                 break;
 
         case nir_op_ishl:
@@ -1890,27 +1890,27 @@ bi_emit_alu(bi_builder *b, nir_alu_instr *instr)
                 break;
 
         case nir_op_iadd:
-                bi_iadd_to(b, sz, dst, s0, s1, false);
+                bi_iadd_to(b, nir_type_int, sz, dst, s0, s1, false);
                 break;
 
         case nir_op_iadd_sat:
-                bi_iadd_to(b, sz, dst, s0, s1, true);
+                bi_iadd_to(b, nir_type_int, sz, dst, s0, s1, true);
                 break;
 
         case nir_op_ihadd:
-                bi_hadd_to(b, sz, dst, s0, s1, BI_ROUND_RTN);
+                bi_hadd_to(b, nir_type_int, sz, dst, s0, s1, BI_ROUND_RTN);
                 break;
 
         case nir_op_irhadd:
-                bi_hadd_to(b, sz, dst, s0, s1, BI_ROUND_RTP);
+                bi_hadd_to(b, nir_type_int, sz, dst, s0, s1, BI_ROUND_RTP);
                 break;
 
         case nir_op_isub:
-                bi_isub_to(b, sz, dst, s0, s1, false);
+                bi_isub_to(b, nir_type_int, sz, dst, s0, s1, false);
                 break;
 
         case nir_op_isub_sat:
-                bi_isub_to(b, sz, dst, s0, s1, true);
+                bi_isub_to(b, nir_type_int, sz, dst, s0, s1, true);
                 break;
 
         case nir_op_imul: