From 3c7634f7d27c76eee0f11a3c1c2541a648bfdc6f Mon Sep 17 00:00:00 2001 From: Boris Brezillon Date: Tue, 2 Mar 2021 13:08:05 +0100 Subject: [PATCH] pan/bi: Extend the bi_builder to support type variants correctly Some opcodes come with both type and size variants. Right now, only the size is taken into account. Extend the builder to provide wrappers that take a nir_type in addition to the bitsize. While at it, fix wrappers taking a compare operator to use the proper .{i,s,u} variant based on the comparison (equal and non-equal should use .i, other comparisons should use .{u,s}). Signed-off-by: Boris Brezillon Reviewed-by: Alyssa Rosenzweig Part-of: --- src/panfrost/bifrost/bi_builder.h.py | 74 ++++++++++++++++++++++++++++++---- src/panfrost/bifrost/bifrost_compile.c | 14 +++---- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/src/panfrost/bifrost/bi_builder.h.py b/src/panfrost/bifrost/bi_builder.h.py index 4791a1a..903ef4e 100644 --- a/src/panfrost/bifrost/bi_builder.h.py +++ b/src/panfrost/bifrost/bi_builder.h.py @@ -28,6 +28,31 @@ TEMPLATE = """ #include "compiler.h" <% +def nirtypes(opcode): + split = opcode.split('.', 1) + if len(split) < 2: + split = opcode.split('_') + + if len(split) <= 1: + return None + + assert len(split) > 1 + + type = split[1] + if type[0] == 'v': + type = type[2:] + + if type[0] == 'f': + return ['nir_type_float'] + elif type[0] == 's': + return ['nir_type_int'] + elif type[0] == 'u': + return ['nir_type_uint'] + elif type[0] == 'i': + return ['nir_type_uint', 'nir_type_int'] + else: + return None + def typesize(opcode): if opcode[-3:] == '128': return 128 @@ -41,6 +66,29 @@ def typesize(opcode): except: return None +def condition(opcode, typecheck, sizecheck): + cond = '' + if typecheck == True: + cond += '(' + types = nirtypes(opcode) + assert types != None + for T in types: + cond += "{}type == {}".format(' || ' if cond[-1] != '(' else '', T) + cond += ')' + + if sizecheck == True: + cond += "{}bitsize == {}".format(' && ' if cond != '' else '', typesize(opcode)) + + cmpf_mods = ops[opcode]["modifiers"]["cmpf"] if "cmpf" in ops[opcode]["modifiers"] else None + if "cmpf" in ops[opcode]["modifiers"]: + cond += "{}(".format(' && ' if cond != '' else '') + for cmpf in ops[opcode]["modifiers"]["cmpf"]: + if cmpf != 'reserved': + cond += "{}cmpf == BI_CMPF_{}".format(' || ' if cond[-1] != '(' else '', cmpf.upper()) + cond += ')' + + return 'true' if cond == '' else cond + def to_suffix(op): return "_to" if op["dests"] > 0 else "" @@ -81,23 +129,34 @@ bi_index bi_${opcode.replace('.', '_').lower()}(${signature(ops[opcode], modifie <% common_op = opcode.split('.')[0] variants = [a for a in ops.keys() if a.split('.')[0] == common_op] - signatures = [signature(ops[op], modifiers, sized=True, no_dests=True) for op in variants] + signatures = [signature(ops[op], modifiers, no_dests=True) for op in variants] homogenous = all([sig == signatures[0] for sig in signatures]) + types = [nirtypes(x) for x in variants] + typeful = False + for t in types: + if t != types[0]: + typeful = True + sizes = [typesize(x) for x in variants] + sized = False + for size in sizes: + if size != sizes[0]: + sized = True + last = opcode == variants[-1] %> % if homogenous and len(variants) > 1 and last: % for (suffix, temp, dests, ret) in (('_to', False, 1, 'instr *'), ('', True, 0, 'index')): % if not temp or ops[opcode]["dests"] > 0: static inline -bi_${ret} bi_${common_op.replace('.', '_').lower()}${suffix if ops[opcode]['dests'] > 0 else ''}(${signature(ops[opcode], modifiers, sized=True, no_dests=not dests)}) +bi_${ret} bi_${common_op.replace('.', '_').lower()}${suffix if ops[opcode]['dests'] > 0 else ''}(${signature(ops[opcode], modifiers, typeful=typeful, sized=sized, no_dests=not dests)}) { -% for i, (variant, size) in enumerate(zip(variants, sizes)): - ${"else " if i > 0 else ""} if (bitsize == ${size}) +% for i, variant in enumerate(variants): + ${"{}if ({})".format("else " if i > 0 else "", condition(variant, typeful, sized))} return (bi_${variant.replace('.', '_').lower()}${to_suffix(ops[opcode])}(${arguments(ops[opcode], temp_dest = temp)}))${"->dest[0]" if temp else ""}; % endfor else - unreachable("Invalid bitsize for ${common_op}"); + unreachable("Invalid parameters for ${common_op}"); } %endif @@ -122,10 +181,11 @@ def should_skip(mod): def modifier_signature(op): return sorted([m for m in op["modifiers"].keys() if not should_skip(m)]) -def signature(op, modifiers, sized = False, no_dests = False): +def signature(op, modifiers, typeful = False, sized = False, no_dests = False): return ", ".join( ["bi_builder *b"] + - (["unsigned bitsize"] if sized else []) + + (["nir_alu_type type"] if typeful == True else []) + + (["unsigned bitsize"] if sized == True else []) + ["bi_index dest{}".format(i) for i in range(0 if no_dests else op["dests"])] + ["bi_index src{}".format(i) for i in range(src_count(op))] + ["{} {}".format( diff --git a/src/panfrost/bifrost/bifrost_compile.c b/src/panfrost/bifrost/bifrost_compile.c index 2201522..727dd304 100644 --- a/src/panfrost/bifrost/bifrost_compile.c +++ b/src/panfrost/bifrost/bifrost_compile.c @@ -1670,7 +1670,7 @@ bi_emit_alu(bi_builder *b, nir_alu_instr *instr) if (sz == 8) bi_mux_v4i8_to(b, dst, s2, s1, s0, BI_MUX_INT_ZERO); else - bi_csel_to(b, sz, dst, s0, bi_zero(), s1, s2, BI_CMPF_NE); + bi_csel_to(b, nir_type_float, sz, dst, s0, bi_zero(), s1, s2, BI_CMPF_NE); break; case nir_op_ishl: @@ -1890,27 +1890,27 @@ bi_emit_alu(bi_builder *b, nir_alu_instr *instr) break; case nir_op_iadd: - bi_iadd_to(b, sz, dst, s0, s1, false); + bi_iadd_to(b, nir_type_int, sz, dst, s0, s1, false); break; case nir_op_iadd_sat: - bi_iadd_to(b, sz, dst, s0, s1, true); + bi_iadd_to(b, nir_type_int, sz, dst, s0, s1, true); break; case nir_op_ihadd: - bi_hadd_to(b, sz, dst, s0, s1, BI_ROUND_RTN); + bi_hadd_to(b, nir_type_int, sz, dst, s0, s1, BI_ROUND_RTN); break; case nir_op_irhadd: - bi_hadd_to(b, sz, dst, s0, s1, BI_ROUND_RTP); + bi_hadd_to(b, nir_type_int, sz, dst, s0, s1, BI_ROUND_RTP); break; case nir_op_isub: - bi_isub_to(b, sz, dst, s0, s1, false); + bi_isub_to(b, nir_type_int, sz, dst, s0, s1, false); break; case nir_op_isub_sat: - bi_isub_to(b, sz, dst, s0, s1, true); + bi_isub_to(b, nir_type_int, sz, dst, s0, s1, true); break; case nir_op_imul: -- 2.7.4