nir/algebraic: collapse ALU opcodes sourcing NaN
authorMarek Olšák <marek.olsak@amd.com>
Sat, 8 Jul 2023 21:09:15 +0000 (17:09 -0400)
committerMarek Olšák <marek.olsak@amd.com>
Sat, 19 Aug 2023 18:18:52 +0000 (14:18 -0400)
Undef will be replaced by NaN whenever it leads to elimination of FP
instructions. This implements the elimination part.

Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24792>

src/compiler/nir/nir_opt_algebraic.py
src/compiler/nir/nir_search_helpers.h

index a430aa1..030ac6e 100644 (file)
@@ -35,6 +35,7 @@ b = 'b'
 c = 'c'
 d = 'd'
 e = 'e'
+NAN = math.nan
 
 signed_zero_inf_nan_preserve_16 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 16)'
 signed_zero_inf_nan_preserve_32 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 32)'
@@ -2757,6 +2758,48 @@ for s in range(0, 31):
          'options->avoid_ternary_with_two_constants'),
     ])
 
+# NaN propagation: Binary opcodes. If any operand is NaN, replace it with NaN.
+# (unary opcodes with NaN are evaluated by nir_opt_constant_folding, not here)
+for op in ['fadd', 'fdiv', 'fmod', 'fmul', 'fpow', 'frem', 'fsub']:
+    optimizations += [((op, '#a(is_nan)', b), NAN)]
+    optimizations += [((op, a, '#b(is_nan)'), NAN)] # some opcodes are not commutative
+
+# NaN propagation: Trinary opcodes. If any operand is NaN, replace it with NaN.
+for op in ['ffma', 'flrp']:
+    optimizations += [((op, '#a(is_nan)', b, c), NAN)]
+    optimizations += [((op, a, '#b(is_nan)', c), NAN)] # some opcodes are not commutative
+    optimizations += [((op, a, b, '#c(is_nan)'), NAN)]
+
+# NaN propagation: FP min/max. Pick the non-NaN operand.
+for op in ['fmin', 'fmax']:
+    optimizations += [((op, '#a(is_nan)', b), b)] # commutative
+
+# NaN propagation: ldexp is NaN if the first operand is NaN.
+optimizations += [(('ldexp', '#a(is_nan)', b), NAN)]
+
+# NaN propagation: Dot opcodes. If any component is NaN, replace it with NaN.
+for op in ['fdot2', 'fdot3', 'fdot4', 'fdot5', 'fdot8', 'fdot16']:
+    optimizations += [((op, '#a(is_any_comp_nan)', b), NAN)] # commutative
+
+# NaN propagation: FP comparison opcodes except !=. Replace it with false.
+for op in ['feq', 'fge', 'flt']:
+    optimizations += [((op, '#a(is_nan)', b), False)]
+    optimizations += [((op, a, '#b(is_nan)'), False)] # some opcodes are not commutative
+
+# NaN propagation: FP comparison opcodes using !=. Replace it with true.
+# Operator != is the only opcode where a comparison with NaN returns true.
+for op in ['fneu']:
+    optimizations += [((op, '#a(is_nan)', b), True)] # commutative
+
+# NaN propagation: FP comparison opcodes except != returning FP 0 or 1.
+for op in ['seq', 'sge', 'slt']:
+    optimizations += [((op, '#a(is_nan)', b), 0.0)]
+    optimizations += [((op, a, '#b(is_nan)'), 0.0)] # some opcodes are not commutative
+
+# NaN propagation: FP comparison opcodes using != returning FP 0 or 1.
+# Operator != is the only opcode where a comparison with NaN returns true.
+optimizations += [(('sne', '#a(is_nan)', b), 1.0)] # commutative
+
 # This section contains optimizations to propagate downsizing conversions of
 # constructed vectors into vectors of downsized components. Whether this is
 # useful depends on the SIMD semantics of the backend. On a true SIMD machine,
index 32c1f88..809c257 100644 (file)
@@ -111,6 +111,38 @@ is_bitcount2(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
    return true;
 }
 
+static inline bool
+is_nan(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
+       unsigned src, unsigned num_components, const uint8_t *swizzle)
+{
+   /* only constant srcs: */
+   if (!nir_src_is_const(instr->src[src].src))
+      return false;
+
+   for (unsigned i = 0; i < num_components; i++) {
+      if (!isnan(nir_src_comp_as_float(instr->src[src].src, swizzle[i])))
+         return false;
+   }
+
+   return true;
+}
+
+static inline bool
+is_any_comp_nan(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
+                unsigned src, unsigned num_components, const uint8_t *swizzle)
+{
+   /* only constant srcs: */
+   if (!nir_src_is_const(instr->src[src].src))
+      return false;
+
+   for (unsigned i = 0; i < num_components; i++) {
+      if (isnan(nir_src_comp_as_float(instr->src[src].src, swizzle[i])))
+         return true;
+   }
+
+   return false;
+}
+
 #define MULTIPLE(test)                                                         \
    static inline bool                                                          \
       is_unsigned_multiple_of_##test(UNUSED struct hash_table *ht,             \