nir/opt_intrinsic: optimize quad vote
authorRhys Perry <pendingchaos02@gmail.com>
Tue, 13 Jun 2023 13:07:53 +0000 (14:07 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 27 Jun 2023 18:53:50 +0000 (18:53 +0000)
Optimizes a quadAll()/quadAny() pattern created by dxil-spirv:
https://github.com/HansKristian-Work/dxil-spirv/commit/7adc87d4deaba8078bcdef8dfbebdda0165cd7bc

dxil-spirv can't use clustered reductions because they are not guaranteed
to include helper invocations.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23621>

src/compiler/nir/nir.h
src/compiler/nir/nir_opt_intrinsics.c

index fd113ad..10e4731 100644 (file)
@@ -3658,6 +3658,12 @@ typedef struct nir_shader_compiler_options {
     */
    bool optimize_sample_mask_in;
 
+   /**
+    * Optimize boolean reductions of quad broadcasts. This should only be enabled if
+    * nir_intrinsic_reduce supports INCLUDE_HELPERS.
+    */
+   bool optimize_quad_vote_to_reduce;
+
    bool lower_cs_local_index_to_id;
    bool lower_cs_local_id_to_index;
 
index e4ed2b7..fdd502f 100644 (file)
@@ -93,8 +93,130 @@ try_opt_bcsel_of_shuffle(nir_builder *b, nir_alu_instr *alu,
 }
 
 static bool
+src_is_quad_broadcast(nir_block *block, nir_src src, nir_intrinsic_instr **intrin)
+{
+   nir_intrinsic_instr *broadcast = nir_src_as_intrinsic(src);
+   if (broadcast == NULL || broadcast->instr.block != block)
+      return false;
+
+   switch (broadcast->intrinsic) {
+   case nir_intrinsic_quad_broadcast:
+      if (!nir_src_is_const(broadcast->src[1]))
+         return false;
+      FALLTHROUGH;
+   case nir_intrinsic_quad_swap_horizontal:
+   case nir_intrinsic_quad_swap_vertical:
+   case nir_intrinsic_quad_swap_diagonal:
+   case nir_intrinsic_quad_swizzle_amd:
+      *intrin = broadcast;
+      return true;
+   default:
+      return false;
+   }
+}
+
+static bool
+src_is_alu(nir_op op, nir_src src, nir_src srcs[2])
+{
+   nir_alu_instr *alu = nir_src_as_alu_instr(src);
+   if (alu == NULL || alu->op != op)
+      return false;
+
+   if (!nir_alu_src_is_trivial_ssa(alu, 0) || !nir_alu_src_is_trivial_ssa(alu, 1))
+      return false;
+
+   srcs[0] = alu->src[0].src;
+   srcs[1] = alu->src[1].src;
+
+   return true;
+}
+
+static nir_ssa_def *
+try_opt_quad_vote(nir_builder *b, nir_alu_instr *alu, bool block_has_discard)
+{
+   if (block_has_discard)
+      return NULL;
+
+   if (!nir_alu_src_is_trivial_ssa(alu, 0) || !nir_alu_src_is_trivial_ssa(alu, 1))
+      return NULL;
+
+   nir_intrinsic_instr *quad_broadcasts[4];
+   nir_src srcs[2][2];
+   bool found = false;
+
+   /* Match (broadcast0 op broadcast1) op (broadcast2 op broadcast3). */
+   found = src_is_alu(alu->op, alu->src[0].src, srcs[0]) &&
+           src_is_alu(alu->op, alu->src[1].src, srcs[1]) &&
+           src_is_quad_broadcast(alu->instr.block, srcs[0][0], &quad_broadcasts[0]) &&
+           src_is_quad_broadcast(alu->instr.block, srcs[0][1], &quad_broadcasts[1]) &&
+           src_is_quad_broadcast(alu->instr.block, srcs[1][0], &quad_broadcasts[2]) &&
+           src_is_quad_broadcast(alu->instr.block, srcs[1][1], &quad_broadcasts[3]);
+
+   /* Match ((broadcast2 op broadcast3) op broadcast1) op broadcast0). */
+   if (!found) {
+      if ((src_is_alu(alu->op, alu->src[0].src, srcs[0]) &&
+           src_is_quad_broadcast(alu->instr.block, alu->src[1].src, &quad_broadcasts[0])) ||
+          (src_is_alu(alu->op, alu->src[1].src, srcs[0]) &&
+           src_is_quad_broadcast(alu->instr.block, alu->src[0].src, &quad_broadcasts[0]))) {
+         /* ((broadcast2 || broadcast3) || broadcast1) */
+         if ((src_is_alu(alu->op, srcs[0][0], srcs[1]) &&
+              src_is_quad_broadcast(alu->instr.block, srcs[0][1], &quad_broadcasts[1])) ||
+             (src_is_alu(alu->op, srcs[0][1], srcs[1]) &&
+              src_is_quad_broadcast(alu->instr.block, srcs[0][0], &quad_broadcasts[1]))) {
+            /* (broadcast2 || broadcast3) */
+            found = src_is_quad_broadcast(alu->instr.block, srcs[1][0], &quad_broadcasts[2]) &&
+                    src_is_quad_broadcast(alu->instr.block, srcs[1][1], &quad_broadcasts[3]);
+         }
+      }
+   }
+
+   if (!found)
+      return NULL;
+
+   /* Check if each lane in a quad reduces all lanes in the quad, and if all broadcasts read the
+    * same data.
+    */
+   uint16_t lanes_read = 0;
+   for (unsigned i = 0; i < 4; i++) {
+      if (!nir_srcs_equal(quad_broadcasts[i]->src[0], quad_broadcasts[0]->src[0]))
+         return NULL;
+
+      for (unsigned j = 0; j < 4; j++) {
+         unsigned lane;
+         switch (quad_broadcasts[i]->intrinsic) {
+         case nir_intrinsic_quad_broadcast:
+            lane = nir_src_as_uint(quad_broadcasts[i]->src[1]) & 0x3;
+            break;
+         case nir_intrinsic_quad_swap_horizontal:
+            lane = j ^ 1;
+            break;
+         case nir_intrinsic_quad_swap_vertical:
+            lane = j ^ 2;
+            break;
+         case nir_intrinsic_quad_swap_diagonal:
+            lane = 3 - j;
+            break;
+         case nir_intrinsic_quad_swizzle_amd:
+            lane = (nir_intrinsic_swizzle_mask(quad_broadcasts[i]) >> (j * 2)) & 0x3;
+            break;
+         default:
+            unreachable();
+         }
+         lanes_read |= (1 << lane) << (j * 4);
+      }
+   }
+
+   if (lanes_read != 0xffff)
+      return NULL;
+
+   /* Create reduction. */
+   return nir_reduce(b, quad_broadcasts[0]->src[0].ssa, .reduction_op = alu->op, .cluster_size = 4,
+                     .include_helpers = true);
+}
+
+static bool
 opt_intrinsics_alu(nir_builder *b, nir_alu_instr *alu,
-                   bool block_has_discard)
+                   bool block_has_discard, const struct nir_shader_compiler_options *options)
 {
    nir_ssa_def *replacement = NULL;
 
@@ -102,7 +224,11 @@ opt_intrinsics_alu(nir_builder *b, nir_alu_instr *alu,
    case nir_op_bcsel:
       replacement = try_opt_bcsel_of_shuffle(b, alu, block_has_discard);
       break;
-
+   case nir_op_iand:
+   case nir_op_ior:
+      if (nir_dest_bit_size(alu->dest.dest) == 1 && options->optimize_quad_vote_to_reduce)
+         replacement = try_opt_quad_vote(b, alu, block_has_discard);
+      break;
    default:
       break;
    }
@@ -181,7 +307,7 @@ opt_intrinsics_impl(nir_function_impl *impl,
          switch (instr->type) {
          case nir_instr_type_alu:
             if (opt_intrinsics_alu(&b, nir_instr_as_alu(instr),
-                                   block_has_discard))
+                                   block_has_discard, options))
                progress = true;
             break;