nir: lower ballot_bit_count_exclusive/inclusive to mbcnt_amd
authorGeorg Lehmann <dadschoorse@gmail.com>
Mon, 1 May 2023 17:04:03 +0000 (19:04 +0200)
committerMarge Bot <emma+marge@anholt.net>
Wed, 3 May 2023 10:39:20 +0000 (10:39 +0000)
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22783>

src/compiler/nir/nir.h
src/compiler/nir/nir_lower_subgroups.c

index b802287..b8c1ee7 100644 (file)
@@ -5137,6 +5137,7 @@ typedef struct nir_lower_subgroups_options {
    bool lower_elect:1;
    bool lower_read_invocation_to_cond:1;
    bool lower_rotate_to_shuffle:1;
+   bool lower_ballot_bit_count_to_mbcnt_amd:1;
 } nir_lower_subgroups_options;
 
 bool nir_lower_subgroups(nir_shader *shader,
index 75a8ecd..9ace7f2 100644 (file)
@@ -774,6 +774,20 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 
    case nir_intrinsic_ballot_bit_count_exclusive:
    case nir_intrinsic_ballot_bit_count_inclusive: {
+      assert(intrin->src[0].is_ssa);
+      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
+                                                 options);
+      if (options->lower_ballot_bit_count_to_mbcnt_amd) {
+         nir_ssa_def *acc;
+         if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_exclusive) {
+            acc = nir_imm_int(b, 0);
+         } else {
+            acc = nir_iand_imm(b, nir_u2u32(b, int_val), 0x1);
+            int_val = nir_ushr_imm(b, int_val, 1);
+         }
+         return nir_mbcnt_amd(b, int_val, acc);
+      }
+
       nir_ssa_def *mask;
       if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
          mask = nir_inot(b, build_subgroup_gt_mask(b, options));
@@ -781,10 +795,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          mask = nir_inot(b, build_subgroup_ge_mask(b, options));
       }
 
-      assert(intrin->src[0].is_ssa);
-      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
-                                                 options);
-
       return vec_bit_count(b, nir_iand(b, int_val, mask));
    }