agx: Optimize multiplies
authorAlyssa Rosenzweig <alyssa@rosenzweig.io>
Tue, 25 Apr 2023 17:52:32 +0000 (13:52 -0400)
committerAlyssa Rosenzweig <alyssa@rosenzweig.io>
Thu, 11 May 2023 13:23:23 +0000 (09:23 -0400)
We have an imad instruction and our iadd has a small immediate shift on the
second source. Together, these allow expressing lots of integer multiplies more
efficiently. Add some rules to optimize these now that the backend compiler can
ingest the optimized forms.

Half-register changes are from load_const scheduling changing in some vertex
shaders.

   total instructions in shared programs: 1539092 -> 1537949 (-0.07%)
   instructions in affected programs: 167896 -> 166753 (-0.68%)

   total bytes in shared programs: 10543012 -> 10533866 (-0.09%)
   bytes in affected programs: 1218068 -> 1208922 (-0.75%)

   total halfregs in shared programs: 483180 -> 483448 (0.06%)
   halfregs in affected programs: 1942 -> 2210 (13.80%)

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22695>

src/asahi/compiler/agx_compile.c
src/asahi/compiler/agx_nir.h
src/asahi/compiler/agx_nir_algebraic.py

index 04bff8a..6ed2b01 100644 (file)
@@ -2027,6 +2027,21 @@ agx_optimize_nir(nir_shader *nir, unsigned *preamble_size)
    NIR_PASS_V(nir, nir_opt_peephole_select, 64, false, true);
 
    NIR_PASS_V(nir, nir_opt_algebraic_late);
+
+   /* Fuse add/sub/multiplies/shifts after running opt_algebraic_late to fuse
+    * isub but before shifts are lowered.
+    */
+   do {
+      progress = false;
+
+      NIR_PASS(progress, nir, nir_opt_dce);
+      NIR_PASS(progress, nir, nir_opt_cse);
+      NIR_PASS(progress, nir, agx_nir_fuse_algebraic_late);
+   } while (progress);
+
+   /* Do remaining lowering late, since this inserts &s for shifts so we want to
+    * do it after fusing constant shifts. Constant folding will clean up.
+    */
    NIR_PASS_V(nir, agx_nir_lower_algebraic_late);
    NIR_PASS_V(nir, nir_opt_constant_folding);
    NIR_PASS_V(nir, nir_opt_combine_barriers, combine_all_barriers, NULL);
index 6829116..8775f2c 100644 (file)
@@ -10,5 +10,6 @@
 struct nir_shader;
 
 bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
+bool agx_nir_fuse_algebraic_late(struct nir_shader *shader);
 
 #endif
index deef09d..868e286 100644 (file)
@@ -10,6 +10,8 @@ import math
 a = 'a'
 b = 'b'
 c = 'c'
+d = 'd'
+e = 'e'
 
 lower_sm5_shift = []
 
@@ -38,6 +40,60 @@ lower_pack = [
     (('i2f32', ('i2i32', a)), ('i2f32', a)),
 ]
 
+# (x * y) + s = (x * y) + (s << 0)
+def imad(x, y, z):
+    return ('imadshl_agx', x, y, z, 0)
+
+# (x * y) - s = (x * y) - (s << 0)
+def imsub(x, y, z):
+    return ('imsubshl_agx', x, y, z, 0)
+
+# x + (y << s) = (x * 1) + (y << s)
+def iaddshl(x, y, s):
+    return ('imadshl_agx', x, 1, y, s)
+
+# x - (y << s) = (x * 1) - (y << s)
+def isubshl(x, y, s):
+    return ('imsubshl_agx', x, 1, y, s)
+
+fuse_imad = [
+    # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
+    # in compute shader lowering.
+    (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
+              ('imul(is_used_once)', c, d)), e),
+     imad(a, b, imad(c, d, e))),
+
+    # Fuse regular imad
+    (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
+    (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
+]
+
+for s in range(1, 5):
+    fuse_imad += [
+        # Definitions
+        (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
+        (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),
+
+        # ineg(x) is 0 - x
+        (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),
+
+        # Definitions
+        (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
+        (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),
+
+        # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
+        (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),
+
+        # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
+        (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),
+
+        # a - (a << s) = a * (1 - (1 << s)) = -(a * (1 << s) - 1)
+        (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),
+
+        # iadd is SCIB, general shfit is IC (slower)
+        (('ishl', a, s), iaddshl(0, a, s)),
+    ]
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--import-path', required=True)
@@ -52,6 +108,8 @@ def run():
 
     print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                       lower_sm5_shift + lower_pack).render())
+    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
+                                      fuse_imad).render())
 
 
 if __name__ == '__main__':