nir: improve ms_cross_invocation_output_access with local_invocation_id
author Rhys Perry <pendingchaos02@gmail.com>
Thu, 31 Aug 2023 19:35:25 +0000 (20:35 +0100)
committer Marge Bot <emma+marge@anholt.net>
Tue, 24 Oct 2023 21:36:06 +0000 (21:36 +0000)
Since GFX11, RADV doesn't need to lower local_invocation_id.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25040>

src/compiler/nir/nir_gather_info.c

index d020189..eb0734e 100644 (file)
@@ -36,11 +36,25 @@ src_is_invocation_id(const nir_src *src)
 }
 
 static bool
-src_is_local_invocation_index(const nir_src *src)
+src_is_local_invocation_index(nir_shader *shader, const nir_src *src)
 {
+   assert(shader->info.stage == MESA_SHADER_MESH && !shader->info.workgroup_size_variable);
+
    nir_scalar s = nir_scalar_resolved(src->ssa, 0);
-   return nir_scalar_is_intrinsic(s) &&
-          nir_scalar_intrinsic_op(s) == nir_intrinsic_load_local_invocation_index;
+   if (!nir_scalar_is_intrinsic(s))
+      return false;
+
+   const nir_intrinsic_op op = nir_scalar_intrinsic_op(s);
+   if (op == nir_intrinsic_load_local_invocation_index)
+      return true;
+   if (op != nir_intrinsic_load_local_invocation_id)
+      return false;
+
+   unsigned nz_ids = 0;
+   for (unsigned i = 0; i < 3; i++)
+      nz_ids |= (shader->info.workgroup_size[i] > 1) ? (1u << i) : 0;
+
+   return nz_ids == 0 || (util_bitcount(nz_ids) == 1 && s.comp == ffs(nz_ids) - 1);
 }
 
 static void
@@ -63,7 +77,7 @@ get_deref_info(nir_shader *shader, nir_variable *var, nir_deref_instr *deref,
       if (shader->info.stage == MESA_SHADER_TESS_CTRL)
          *cross_invocation = !src_is_invocation_id(&(*p)->arr.index);
       else if (shader->info.stage == MESA_SHADER_MESH)
-         *cross_invocation = !src_is_local_invocation_index(&(*p)->arr.index);
+         *cross_invocation = !src_is_local_invocation_index(shader, &(*p)->arr.index);
       p++;
    }
 
@@ -549,7 +563,7 @@ gather_intrinsic_info(nir_intrinsic_instr *instr, nir_shader *shader,
       if (shader->info.stage == MESA_SHADER_MESH &&
           (instr->intrinsic == nir_intrinsic_load_per_vertex_output ||
            instr->intrinsic == nir_intrinsic_load_per_primitive_output) &&
-          !src_is_local_invocation_index(nir_get_io_arrayed_index_src(instr)))
+          !src_is_local_invocation_index(shader, nir_get_io_arrayed_index_src(instr)))
          shader->info.mesh.ms_cross_invocation_output_access |= slot_mask;
 
       if (shader->info.stage == MESA_SHADER_FRAGMENT &&
@@ -577,7 +591,7 @@ gather_intrinsic_info(nir_intrinsic_instr *instr, nir_shader *shader,
       if (shader->info.stage == MESA_SHADER_MESH &&
           (instr->intrinsic == nir_intrinsic_store_per_vertex_output ||
            instr->intrinsic == nir_intrinsic_store_per_primitive_output) &&
-          !src_is_local_invocation_index(nir_get_io_arrayed_index_src(instr)))
+          !src_is_local_invocation_index(shader, nir_get_io_arrayed_index_src(instr)))
          shader->info.mesh.ms_cross_invocation_output_access |= slot_mask;
 
       if (shader->info.stage == MESA_SHADER_FRAGMENT &&