NIR_PASS_V(nir, nir_opt_peephole_select, 64, false, true);
NIR_PASS_V(nir, nir_opt_algebraic_late);
+
+ /* Fuse add/sub/multiplies/shifts after running opt_algebraic_late to fuse
+ * isub but before shifts are lowered.
+ */
+ do {
+ progress = false;
+
+ NIR_PASS(progress, nir, nir_opt_dce);
+ NIR_PASS(progress, nir, nir_opt_cse);
+ NIR_PASS(progress, nir, agx_nir_fuse_algebraic_late);
+ } while (progress);
+
+ /* Do remaining lowering late, since this inserts &s for shifts so we want to
+ * do it after fusing constant shifts. Constant folding will clean up.
+ */
NIR_PASS_V(nir, agx_nir_lower_algebraic_late);
NIR_PASS_V(nir, nir_opt_constant_folding);
NIR_PASS_V(nir, nir_opt_combine_barriers, combine_all_barriers, NULL);
a = 'a'
b = 'b'
c = 'c'
+d = 'd'
+e = 'e'
lower_sm5_shift = []
(('i2f32', ('i2i32', a)), ('i2f32', a)),
]
+# (x * y) + s = (x * y) + (s << 0)
+def imad(x, y, z):
+ return ('imadshl_agx', x, y, z, 0)
+
+# (x * y) - s = (x * y) - (s << 0)
+def imsub(x, y, z):
+ return ('imsubshl_agx', x, y, z, 0)
+
+# x + (y << s) = (x * 1) + (y << s)
+def iaddshl(x, y, s):
+ return ('imadshl_agx', x, 1, y, s)
+
+# x - (y << s) = (x * 1) - (y << s)
+def isubshl(x, y, s):
+ return ('imsubshl_agx', x, 1, y, s)
+
+fuse_imad = [
+ # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
+ # in compute shader lowering.
+ (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
+ ('imul(is_used_once)', c, d)), e),
+ imad(a, b, imad(c, d, e))),
+
+ # Fuse regular imad
+ (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
+ (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
+]
+
+for s in range(1, 5):
+ fuse_imad += [
+ # Definitions
+ (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
+ (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),
+
+ # ineg(x) is 0 - x
+ (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),
+
+ # Definitions
+ (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
+ (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),
+
+ # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
+ (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),
+
+ # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
+ (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),
+
+ # a - (a << s) = a * (1 - (1 << s)) = -(a * (1 << s) - 1)
+ (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),
+
+ # iadd is SCIB, general shfit is IC (slower)
+ (('ishl', a, s), iaddshl(0, a, s)),
+ ]
+
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--import-path', required=True)
print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
lower_sm5_shift + lower_pack).render())
+ print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
+ fuse_imad).render())
if __name__ == '__main__':