nir/lower_subgroups: Don't do multiple lowerings at once
author Connor Abbott <cwabbott0@gmail.com>
Fri, 1 Feb 2019 10:01:31 +0000 (11:01 +0100)
committer Marge Bot <emma+marge@anholt.net>
Wed, 20 Sep 2023 14:41:18 +0000 (14:41 +0000)
Since the switch to nir_shader_lower_instructions(), lowered instructions are
revisited before the walk proceeds to the next one. This already guarantees
that any further lowering of those instructions happens within the same run
of nir_lower_subgroups(), so the individual lowerings no longer need to chain
into each other.
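
As a toy illustration (hypothetical stand-in types and helpers, not actual
NIR/Mesa code) of why the chained lowerings can go away: because the walk
re-examines whatever a lowering just emitted, "scalarize" and "split 64-bit
into 32-bit" can stay separate rules that fire on successive visits, instead
of one lowering having to invoke the other.

/* Toy model of the revisit behaviour provided by
 * nir_shader_lower_instructions(): replacements are visited before the
 * walk moves on, so they can be lowered again by the same callback. */
#include <stdbool.h>
#include <stdio.h>
#include <string.h>

struct op { char desc[32]; };   /* hypothetical stand-in for a nir_instr */

static bool lower(const struct op *o, struct op *out, int *n_out)
{
   if (strcmp(o->desc, "vec2.64bit") == 0) {   /* rule 1: scalarize only */
      strcpy(out[0].desc, "scalar.64bit");
      strcpy(out[1].desc, "scalar.64bit");
      *n_out = 2;
      return true;
   }
   if (strcmp(o->desc, "scalar.64bit") == 0) { /* rule 2: 32-bit split only */
      strcpy(out[0].desc, "scalar.32bit.pair");
      *n_out = 1;
      return true;
   }
   return false;                               /* already fully lowered */
}

int main(void)
{
   struct op list[16] = { { "vec2.64bit" } };
   int n = 1;

   for (int i = 0; i < n;) {
      struct op repl[4];
      int n_repl = 0;
      if (lower(&list[i], repl, &n_repl)) {
         /* splice the replacements in and revisit them: i is not advanced */
         memmove(&list[i + n_repl], &list[i + 1],
                 (size_t)(n - i - 1) * sizeof(list[0]));
         memcpy(&list[i], repl, (size_t)n_repl * sizeof(repl[0]));
         n += n_repl - 1;
      } else {
         i++;
      }
   }

   for (int i = 0; i < n; i++)
      puts(list[i].desc);   /* both channels end up as "scalar.32bit.pair" */
   return 0;
}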

v2: use nir_shader_lower_instructions() instead of setting the cursor.

Co-authored-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25123>

src/compiler/nir/nir_lower_subgroups.c

index afc31ea..2177634 100644
@@ -104,8 +104,7 @@ uint_to_ballot_type(nir_builder *b, nir_def *value,
 }
 
 static nir_def *
-lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
-                            bool lower_to_32bit)
+lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
 {
    /* This is safe to call on scalar things but it would be silly */
    assert(intrin->def.num_components > 1);
@@ -131,12 +130,8 @@ lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
       chan_intrin->const_index[0] = intrin->const_index[0];
       chan_intrin->const_index[1] = intrin->const_index[1];
 
-      if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
-         reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
-      } else {
-         nir_builder_instr_insert(b, &chan_intrin->instr);
-         reads[i] = &chan_intrin->def;
-      }
+      nir_builder_instr_insert(b, &chan_intrin->instr);
+      reads[i] = &chan_intrin->def;
    }
 
    return nir_vec(b, reads, intrin->num_components);
@@ -195,8 +190,7 @@ lower_vote_eq(nir_builder *b, nir_intrinsic_instr *intrin)
 }
 
 static nir_def *
-lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
-                         const nir_lower_subgroups_options *options)
+lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin)
 {
    unsigned mask = nir_src_as_uint(intrin->src[1]);
 
@@ -211,14 +205,8 @@ lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
    nir_def_init(&swizzle->instr, &swizzle->def,
                 intrin->def.num_components, intrin->def.bit_size);
 
-   if (options->lower_to_scalar && swizzle->num_components > 1) {
-      return lower_subgroup_op_to_scalar(b, swizzle, options->lower_shuffle_to_32bit);
-   } else if (options->lower_shuffle_to_32bit && swizzle->src[0].ssa->bit_size == 64) {
-      return lower_subgroup_op_to_32bit(b, swizzle);
-   } else {
-      nir_builder_instr_insert(b, &swizzle->instr);
-      return &swizzle->def;
-   }
+   nir_builder_instr_insert(b, &swizzle->instr);
+   return &swizzle->def;
 }
 
 /* Lowers "specialized" shuffles to a generic nir_intrinsic_shuffle. */
@@ -230,26 +218,22 @@ lower_to_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
    if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
        options->lower_shuffle_to_swizzle_amd &&
        nir_src_is_const(intrin->src[1])) {
-      nir_def *result =
-         lower_shuffle_to_swizzle(b, intrin, options);
+
+      nir_def *result = lower_shuffle_to_swizzle(b, intrin);
       if (result)
          return result;
    }
 
    nir_def *index = nir_load_subgroup_invocation(b);
-   bool is_shuffle = false;
    switch (intrin->intrinsic) {
    case nir_intrinsic_shuffle_xor:
       index = nir_ixor(b, index, intrin->src[1].ssa);
-      is_shuffle = true;
       break;
    case nir_intrinsic_shuffle_up:
       index = nir_isub(b, index, intrin->src[1].ssa);
-      is_shuffle = true;
       break;
    case nir_intrinsic_shuffle_down:
       index = nir_iadd(b, index, intrin->src[1].ssa);
-      is_shuffle = true;
       break;
    case nir_intrinsic_quad_broadcast:
       index = nir_ior(b, nir_iand_imm(b, index, ~0x3),
@@ -301,15 +285,8 @@ lower_to_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
    nir_def_init(&shuffle->instr, &shuffle->def,
                 intrin->def.num_components, intrin->def.bit_size);
 
-   bool lower_to_32bit = options->lower_shuffle_to_32bit && is_shuffle;
-   if (options->lower_to_scalar && shuffle->num_components > 1) {
-      return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
-   } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
-      return lower_subgroup_op_to_32bit(b, shuffle);
-   } else {
-      nir_builder_instr_insert(b, &shuffle->instr);
-      return &shuffle->def;
-   }
+   nir_builder_instr_insert(b, &shuffle->instr);
+   return &shuffle->def;
 }
 
 static const struct glsl_type *
@@ -587,7 +564,7 @@ lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
       nir_def *qbcst_dst = NULL;
 
       if (options->lower_to_scalar && qbcst->num_components > 1) {
-         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
+         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst);
       } else {
          nir_builder_instr_insert(b, &qbcst->instr);
          qbcst_dst = &qbcst->def;
@@ -644,7 +621,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 
    case nir_intrinsic_read_invocation:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin, false);
+         return lower_subgroup_op_to_scalar(b, intrin);
 
       if (options->lower_read_invocation_to_cond)
          return lower_read_invocation_to_cond(b, intrin);
@@ -653,7 +630,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 
    case nir_intrinsic_read_first_invocation:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin, false);
+         return lower_subgroup_op_to_scalar(b, intrin);
       break;
 
    case nir_intrinsic_load_subgroup_eq_mask:
@@ -799,7 +776,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       if (options->lower_shuffle)
          return lower_shuffle(b, intrin);
       else if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+         return lower_subgroup_op_to_scalar(b, intrin);
       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
          return lower_subgroup_op_to_32bit(b, intrin);
       break;
@@ -809,7 +786,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       if (options->lower_relative_shuffle)
          return lower_to_shuffle(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+         return lower_subgroup_op_to_scalar(b, intrin);
       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
          return lower_subgroup_op_to_32bit(b, intrin);
       break;
@@ -824,7 +801,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
            !nir_src_is_const(intrin->src[1])))
          return lower_dynamic_quad_broadcast(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin, false);
+         return lower_subgroup_op_to_scalar(b, intrin);
       break;
 
    case nir_intrinsic_reduce: {
@@ -836,13 +813,13 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          ret = NIR_LOWER_INSTR_PROGRESS;
       }
       if (options->lower_to_scalar && intrin->num_components > 1)
-         ret = lower_subgroup_op_to_scalar(b, intrin, false);
+         ret = lower_subgroup_op_to_scalar(b, intrin);
       return ret;
    }
    case nir_intrinsic_inclusive_scan:
    case nir_intrinsic_exclusive_scan:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin, false);
+         return lower_subgroup_op_to_scalar(b, intrin);
       break;
 
    case nir_intrinsic_rotate:
@@ -850,7 +827,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          if (options->lower_rotate_to_shuffle)
             return lower_to_shuffle(b, intrin, options);
          else if (options->lower_to_scalar && intrin->num_components > 1)
-            return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+            return lower_subgroup_op_to_scalar(b, intrin);
          else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
             return lower_subgroup_op_to_32bit(b, intrin);
       }