radv: handle RT stages in radv_nir_shader_info_pass()
authorDaniel Schürmann <daniel@schuermann.dev>
Fri, 13 May 2022 14:10:01 +0000 (16:10 +0200)
committerMarge Bot <emma+marge@anholt.net>
Thu, 16 Mar 2023 01:40:29 +0000 (01:40 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21780>

src/amd/vulkan/radv_shader_info.c

index 4d4d20f..b95c069 100644 (file)
@@ -342,12 +342,14 @@ radv_get_wave_size(struct radv_device *device,  gl_shader_stage stage,
 {
    if (stage == MESA_SHADER_GEOMETRY && !info->is_ngg)
       return 64;
-   else if (stage == MESA_SHADER_COMPUTE) {
+   else if (stage == MESA_SHADER_COMPUTE)
       return info->cs.subgroup_size;
-   else if (stage == MESA_SHADER_FRAGMENT)
+   else if (stage == MESA_SHADER_FRAGMENT)
       return device->physical_device->ps_wave_size;
    else if (stage == MESA_SHADER_TASK)
       return device->physical_device->cs_wave_size;
+   else if (gl_shader_stage_is_rt(stage))
+      return device->physical_device->rt_wave_size;
    else
       return device->physical_device->ge_wave_size;
 }
@@ -358,6 +360,8 @@ radv_get_ballot_bit_size(struct radv_device *device, gl_shader_stage stage,
 {
    if (stage == MESA_SHADER_COMPUTE && info->cs.subgroup_size)
       return info->cs.subgroup_size;
+   else if (gl_shader_stage_is_rt(stage))
+      return device->physical_device->rt_wave_size;
    return 64;
 }
 
@@ -632,6 +636,17 @@ gather_shader_info_fs(const nir_shader *nir, const struct radv_pipeline_key *pip
 }
 
 static void
+gather_shader_info_rt(const nir_shader *nir, struct radv_shader_info *info)
+{
+   // TODO: inline push_constants again
+   info->loads_dynamic_offsets = true;
+   info->loads_push_constants = true;
+   info->can_inline_all_push_constants = false;
+   info->inline_push_constant_mask = 0;
+   info->desc_set_used_mask = -1u;
+}
+
+static void
 gather_shader_info_cs(struct radv_device *device, const nir_shader *nir,
                       const struct radv_pipeline_key *pipeline_key, struct radv_shader_info *info)
 {
@@ -832,6 +847,8 @@ radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *n
       gather_shader_info_mesh(nir, info);
       break;
    default:
+      if (gl_shader_stage_is_rt(nir->info.stage))
+         gather_shader_info_rt(nir, info);
       break;
    }