From 5a80bf2eb0417c27ce942fb6949df6dab41236b9 Mon Sep 17 00:00:00 2001 From: Alyssa Rosenzweig Date: Tue, 25 Apr 2023 13:52:32 -0400 Subject: [PATCH] agx: Optimize multiplies We have an imad instruction and our iadd has a small immediate shift on the second source. Together, these allow expressing lots of integer multiplies more efficiently. Add some rules to optimize these now that the backend compiler can ingest the optimized forms. Half-register changes are from load_const scheduling changing in some vertex shaders. total instructions in shared programs: 1539092 -> 1537949 (-0.07%) instructions in affected programs: 167896 -> 166753 (-0.68%) total bytes in shared programs: 10543012 -> 10533866 (-0.09%) bytes in affected programs: 1218068 -> 1208922 (-0.75%) total halfregs in shared programs: 483180 -> 483448 (0.06%) halfregs in affected programs: 1942 -> 2210 (13.80%) Signed-off-by: Alyssa Rosenzweig Part-of: --- src/asahi/compiler/agx_compile.c | 15 +++++++++ src/asahi/compiler/agx_nir.h | 1 + src/asahi/compiler/agx_nir_algebraic.py | 58 +++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c index 04bff8a..6ed2b01 100644 --- a/src/asahi/compiler/agx_compile.c +++ b/src/asahi/compiler/agx_compile.c @@ -2027,6 +2027,21 @@ agx_optimize_nir(nir_shader *nir, unsigned *preamble_size) NIR_PASS_V(nir, nir_opt_peephole_select, 64, false, true); NIR_PASS_V(nir, nir_opt_algebraic_late); + + /* Fuse add/sub/multiplies/shifts after running opt_algebraic_late to fuse + * isub but before shifts are lowered. + */ + do { + progress = false; + + NIR_PASS(progress, nir, nir_opt_dce); + NIR_PASS(progress, nir, nir_opt_cse); + NIR_PASS(progress, nir, agx_nir_fuse_algebraic_late); + } while (progress); + + /* Do remaining lowering late, since this inserts &s for shifts so we want to + * do it after fusing constant shifts. Constant folding will clean up. + */ NIR_PASS_V(nir, agx_nir_lower_algebraic_late); NIR_PASS_V(nir, nir_opt_constant_folding); NIR_PASS_V(nir, nir_opt_combine_barriers, combine_all_barriers, NULL); diff --git a/src/asahi/compiler/agx_nir.h b/src/asahi/compiler/agx_nir.h index 6829116..8775f2c 100644 --- a/src/asahi/compiler/agx_nir.h +++ b/src/asahi/compiler/agx_nir.h @@ -10,5 +10,6 @@ struct nir_shader; bool agx_nir_lower_algebraic_late(struct nir_shader *shader); +bool agx_nir_fuse_algebraic_late(struct nir_shader *shader); #endif diff --git a/src/asahi/compiler/agx_nir_algebraic.py b/src/asahi/compiler/agx_nir_algebraic.py index deef09d..868e286 100644 --- a/src/asahi/compiler/agx_nir_algebraic.py +++ b/src/asahi/compiler/agx_nir_algebraic.py @@ -10,6 +10,8 @@ import math a = 'a' b = 'b' c = 'c' +d = 'd' +e = 'e' lower_sm5_shift = [] @@ -38,6 +40,60 @@ lower_pack = [ (('i2f32', ('i2i32', a)), ('i2f32', a)), ] +# (x * y) + s = (x * y) + (s << 0) +def imad(x, y, z): + return ('imadshl_agx', x, y, z, 0) + +# (x * y) - s = (x * y) - (s << 0) +def imsub(x, y, z): + return ('imsubshl_agx', x, y, z, 0) + +# x + (y << s) = (x * 1) + (y << s) +def iaddshl(x, y, s): + return ('imadshl_agx', x, 1, y, s) + +# x - (y << s) = (x * 1) - (y << s) +def isubshl(x, y, s): + return ('imsubshl_agx', x, 1, y, s) + +fuse_imad = [ + # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up + # in compute shader lowering. + (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b), + ('imul(is_used_once)', c, d)), e), + imad(a, b, imad(c, d, e))), + + # Fuse regular imad + (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)), + (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)), +] + +for s in range(1, 5): + fuse_imad += [ + # Definitions + (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)), + (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)), + + # ineg(x) is 0 - x + (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)), + + # Definitions + (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)), + (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)), + + # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s)) + (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)), + + # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s)) + (('imul', a, 1 - (1 << s)), isubshl(a, a, s)), + + # a - (a << s) = a * (1 - (1 << s)) = -(a * (1 << s) - 1) + (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)), + + # iadd is SCIB, general shfit is IC (slower) + (('ishl', a, s), iaddshl(0, a, s)), + ] + def main(): parser = argparse.ArgumentParser() parser.add_argument('-p', '--import-path', required=True) @@ -52,6 +108,8 @@ def run(): print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late", lower_sm5_shift + lower_pack).render()) + print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late", + fuse_imad).render()) if __name__ == '__main__': -- 2.7.4