nir: make fdph lowering match fdot
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 20 Jan 2023 11:38:43 +0000 (11:38 +0000)
committerMarge Bot <emma+marge@anholt.net>
Wed, 8 Mar 2023 14:38:26 +0000 (14:38 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20812>

src/compiler/nir/nir_lower_alu_width.c

index 05a4871..a2e519e 100644 (file)
@@ -294,15 +294,29 @@ lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
       nir_ssa_def *src0_vec = nir_ssa_for_alu_src(b, alu, 0);
       nir_ssa_def *src1_vec = nir_ssa_for_alu_src(b, alu, 1);
 
-      nir_ssa_def *sum[4];
-      for (unsigned i = 0; i < 3; i++) {
-         sum[i] = nir_fmul(b, nir_channel(b, src0_vec, i),
-                              nir_channel(b, src1_vec, i));
+      /* Only use reverse order for imprecise fdph, see explanation in lower_fdot. */
+      bool reverse_order = !b->exact;
+      if (will_lower_ffma(b->shader, alu->dest.dest.ssa.bit_size)) {
+         nir_ssa_def *sum[4];
+         for (unsigned i = 0; i < 3; i++) {
+            int dest = reverse_order ? 3 - i : i;
+            sum[dest] = nir_fmul(b, nir_channel(b, src0_vec, i),
+                                    nir_channel(b, src1_vec, i));
+         }
+         sum[reverse_order ? 0 : 3] = nir_channel(b, src1_vec, 3);
+
+         return nir_fadd(b, nir_fadd(b, nir_fadd(b, sum[0], sum[1]), sum[2]), sum[3]);
+      } else if (reverse_order) {
+         nir_ssa_def *sum = nir_channel(b, src1_vec, 3);
+         for (int i = 2; i >= 0; i--)
+            sum = nir_ffma(b, nir_channel(b, src0_vec, i), nir_channel(b, src1_vec, i), sum);
+         return sum;
+      } else {
+         nir_ssa_def *sum = nir_fmul(b, nir_channel(b, src0_vec, 0), nir_channel(b, src1_vec, 0));
+         sum = nir_ffma(b, nir_channel(b, src0_vec, 1), nir_channel(b, src1_vec, 1), sum);
+         sum = nir_ffma(b, nir_channel(b, src0_vec, 2), nir_channel(b, src1_vec, 2), sum);
+         return nir_fadd(b, sum, nir_channel(b, src1_vec, 3));
       }
-      sum[3] = nir_channel(b, src1_vec, 3);
-
-      return nir_fadd(b, nir_fadd(b, sum[0], sum[1]),
-                         nir_fadd(b, sum[2], sum[3]));
    }
 
    case nir_op_pack_64_2x32: {