radv: Use LDS for closest-hit hit attributes
author Friedrich Vock <friedrich.vock@gmx.de>
Tue, 21 Feb 2023 11:42:53 +0000 (12:42 +0100)
committer Marge Bot <emma+marge@anholt.net>
Sun, 5 Mar 2023 21:53:34 +0000 (21:53 +0000)
Q2RTX: 23.1ms -> 22.9ms

shader-db:
Totals from 19 (0.69% of 2764) affected shaders:

MaxWaves: 197 -> 208 (+5.58%)
Instrs: 87702 -> 87817 (+0.13%); split: -0.03%, +0.16%
CodeSize: 474320 -> 475128 (+0.17%)
VGPRs: 1840 -> 1728 (-6.09%)
Latency: 2771599 -> 2773173 (+0.06%); split: -0.13%, +0.18%
InvThroughput: 561281 -> 533010 (-5.04%); split: -5.16%, +0.12%
VClause: 2782 -> 2788 (+0.22%); split: -0.18%, +0.40%
Copies: 12115 -> 12136 (+0.17%); split: -0.45%, +0.63%
Branches: 4116 -> 4122 (+0.15%)
PreVGPRs: 1665 -> 1638 (-1.62%); split: -1.92%, +0.30%
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21635>

src/amd/vulkan/radv_rt_shader.c

index d8fc762..cada611 100644 (file)
@@ -814,19 +814,15 @@ lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workg
          nir_instr_remove(instr);
       }
    }
-
-   if (hit_attribs) {
-      nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
-
-      nir_lower_global_vars_to_local(shader);
-      nir_lower_vars_to_ssa(shader);
-   }
 }
 
 static void
 insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
                uint32_t call_idx_base, uint32_t call_idx)
 {
+   uint32_t workgroup_size = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
+                             b->shader->info.workgroup_size[2];
+
    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
 
    nir_opt_dead_cf(shader);
@@ -841,6 +837,10 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
    NIR_PASS(_, shader, nir_lower_returns);
    NIR_PASS(_, shader, nir_opt_dce);
 
+   /* The traversal shader has a call_idx of 1 */
+   if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || call_idx == 1)
+      NIR_PASS_V(shader, lower_hit_attribs, NULL, workgroup_size);
+
    reserve_stack_size(vars, shader->scratch_size);
 
    nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
@@ -1393,6 +1393,13 @@ build_traversal_shader(struct radv_device *device,
       device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
    struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes, key);
 
+   /* Register storage for hit attributes */
+   nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
+
+   for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
+      hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader),
+                                                 glsl_uint_type(), "ahit_attrib");
+
    nir_variable *barycentrics = nir_variable_create(
       b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
    barycentrics->data.driver_location = 0;
@@ -1494,9 +1501,15 @@ build_traversal_shader(struct radv_device *device,
 
    radv_build_ray_traversal(device, &b, &args);
 
+   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
+   lower_hit_attrib_derefs(b.shader);
+   lower_hit_attribs(b.shader, hit_attribs, device->physical_device->rt_wave_size);
+
    /* Initialize follow-up shader. */
    nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
    {
+      for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
+         nir_store_hit_attrib_amd(&b, nir_load_var(&b, hit_attribs[i]), .base = i);
       nir_execute_closest_hit_amd(
          &b, nir_load_var(&b, vars.idx), nir_load_var(&b, vars.tmax),
          nir_load_var(&b, vars.primitive_id), nir_load_var(&b, vars.instance_addr),
@@ -1518,8 +1531,6 @@ build_traversal_shader(struct radv_device *device,
    NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
    NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
 
-   lower_hit_attrib_derefs(b.shader);
-
    return b.shader;
 }
 
@@ -1575,6 +1586,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    b.shader->info.internal = false;
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
+   b.shader->info.shared_size = device->physical_device->rt_wave_size * RADV_MAX_HIT_ATTRIB_SIZE;
 
    struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes, key);
    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, SBT_GENERAL_IDX);
@@ -1591,11 +1603,6 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    };
    nir_store_var(&b, vars.launch_size, nir_vec(&b, xyz, 3), 0x7);
 
-   nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
-   for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
-      hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader),
-                                                 glsl_uint_type(), "attribute");
-
    nir_loop *loop = nir_push_loop(&b);
 
    nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
@@ -1606,8 +1613,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
 
    /* Insert traversal shader */
    nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, handles, key);
-   assert(b.shader->info.shared_size == 0);
-   b.shader->info.shared_size = traversal->info.shared_size;
+   b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size);
    assert(b.shader->info.shared_size <= 32768);
    insert_rt_case(&b, traversal, &vars, idx, 0, 1);
 
@@ -1667,7 +1673,5 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
    nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
 
-   lower_hit_attribs(b.shader, hit_attribs, device->physical_device->rt_wave_size);
-
    return b.shader;
 }