radv/rt: Lower hit attributes to registers
authorKonstantin Seurer <konstantin.seurer@gmail.com>
Wed, 9 Nov 2022 20:22:50 +0000 (21:22 +0100)
committerMarge Bot <emma+marge@anholt.net>
Fri, 9 Dec 2022 07:07:10 +0000 (07:07 +0000)
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19866>

src/amd/vulkan/radv_rt_shader.c

index 7e00ced..2397612 100644 (file)
  * performance. */
 #define MAX_STACK_ENTRY_COUNT 16
 
-/* The hit attributes are stored on the stack. This is the offset compared to the current stack
- * pointer of where the hit attrib is stored. */
-const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
-
 static bool
 lower_rt_derefs(nir_shader *shader)
 {
@@ -57,26 +53,16 @@ lower_rt_derefs(nir_shader *shader)
             continue;
 
          nir_deref_instr *deref = nir_instr_as_deref(instr);
-         b.cursor = nir_before_instr(&deref->instr);
-
-         nir_deref_instr *replacement = NULL;
-         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
-            deref->modes = nir_var_function_temp;
-            progress = true;
-
-            if (deref->deref_type == nir_deref_type_var)
-               replacement =
-                  nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
-         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
-            deref->modes = nir_var_function_temp;
-            progress = true;
-
-            if (deref->deref_type == nir_deref_type_var)
-               replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
-                                                  nir_var_function_temp, deref->type, 0);
-         }
+         if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
+            continue;
+
+         deref->modes = nir_var_function_temp;
+         progress = true;
 
-         if (replacement != NULL) {
+         if (deref->deref_type == nir_deref_type_var) {
+            b.cursor = nir_before_instr(&deref->instr);
+            nir_deref_instr *replacement =
+               nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
             nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa);
             nir_instr_remove(&deref->instr);
          }
@@ -337,7 +323,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
 
             switch (intr->intrinsic) {
             case nir_intrinsic_rt_execute_callable: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
                uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
 
                nir_store_var(
@@ -358,7 +344,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                break;
             }
             case nir_intrinsic_rt_trace_ray: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
                uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
 
                nir_store_var(
@@ -395,7 +381,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                break;
             }
             case nir_intrinsic_rt_resume: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
 
                nir_store_var(
                   &b_shader, vars->stack_ptr,
@@ -657,6 +643,109 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
    nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
 }
 
+/* Lower load/store_deref of nir_var_ray_hit_attrib variables to
+ * load/store_hit_attrib_amd intrinsics. Hit attribute storage is viewed as an
+ * array of 32-bit slots; the intrinsic's .base index selects the dword,
+ * derived from the variable's driver_location plus the component's byte
+ * offset. Intended to run via nir_shader_instructions_pass.
+ */
+static bool
+lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   if (intrin->intrinsic != nir_intrinsic_load_deref &&
+       intrin->intrinsic != nir_intrinsic_store_deref)
+      return false;
+
+   nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
+   if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
+      return false;
+
+   /* Struct/array derefs were split into plain variables by
+    * nir_split_struct_vars/nir_split_array_vars before this pass runs. */
+   assert(deref->deref_type == nir_deref_type_var);
+
+   b->cursor = nir_after_instr(instr);
+
+   if (intrin->intrinsic == nir_intrinsic_load_deref) {
+      uint32_t num_components = intrin->dest.ssa.num_components;
+      uint32_t bit_size = intrin->dest.ssa.bit_size;
+
+      nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
+
+      for (uint32_t comp = 0; comp < num_components; comp++) {
+         /* Byte offset of this component within the hit attribute storage. */
+         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
+         uint32_t base = offset / 4;        /* 32-bit slot index */
+         uint32_t comp_offset = offset % 4; /* byte offset within the slot */
+
+         if (bit_size == 64) {
+            /* 64-bit values span two consecutive 32-bit slots. */
+            components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
+                                                      nir_load_hit_attrib_amd(b, .base = base + 1));
+         } else if (bit_size == 32) {
+            components[comp] = nir_load_hit_attrib_amd(b, .base = base);
+         } else if (bit_size == 16) {
+            /* Load the containing dword and select the addressed 16-bit half. */
+            components[comp] = nir_channel(
+               b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
+         } else if (bit_size == 8) {
+            /* Load the containing dword and select the addressed byte. */
+            components[comp] = nir_channel(
+               b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
+         } else {
+            unreachable("Invalid bit_size");
+         }
+      }
+
+      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_vec(b, components, num_components));
+   } else {
+      nir_ssa_def *value = intrin->src[1].ssa;
+      uint32_t num_components = value->num_components;
+      uint32_t bit_size = value->bit_size;
+
+      for (uint32_t comp = 0; comp < num_components; comp++) {
+         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
+         uint32_t base = offset / 4;
+         uint32_t comp_offset = offset % 4;
+
+         nir_ssa_def *component = nir_channel(b, value, comp);
+
+         if (bit_size == 64) {
+            /* Split the 64-bit value across two consecutive 32-bit slots. */
+            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
+            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
+         } else if (bit_size == 32) {
+            nir_store_hit_attrib_amd(b, component, .base = base);
+         } else if (bit_size == 16) {
+            /* Sub-dword store: read-modify-write the containing 32-bit slot,
+             * replacing only the addressed 16-bit half. */
+            nir_ssa_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
+            nir_ssa_def *components[2];
+            for (uint32_t word = 0; word < 2; word++)
+               components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp)
+                                                            : nir_channel(b, prev, word);
+            nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)),
+                                     .base = base);
+         } else if (bit_size == 8) {
+            /* Sub-dword store: read-modify-write, replacing only the
+             * addressed byte of the slot. */
+            nir_ssa_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
+            nir_ssa_def *components[4];
+            for (uint32_t byte = 0; byte < 4; byte++)
+               components[byte] =
+                  (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
+            nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)),
+                                     .base = base);
+         } else {
+            unreachable("Invalid bit_size");
+         }
+      }
+   }
+
+   nir_instr_remove(instr);
+   return true;
+}
+
+/* Run lower_hit_attrib_deref over the whole shader and, on progress, clean
+ * up the now-unused derefs and ray-hit-attribute variables. */
+static bool
+lower_hit_attrib_derefs(nir_shader *shader)
+{
+   bool progress = nir_shader_instructions_pass(
+      shader, lower_hit_attrib_deref, nir_metadata_block_index | nir_metadata_dominance, NULL);
+   if (progress) {
+      nir_remove_dead_derefs(shader);
+      nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
+   }
+
+   return progress;
+}
+
 static void
 insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
                uint32_t call_idx_base, uint32_t call_idx)
@@ -707,11 +796,16 @@ parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo
       nir_rt_return_amd(&b_inner);
    }
 
+   NIR_PASS(_, shader, nir_split_struct_vars, nir_var_ray_hit_attrib);
+   NIR_PASS(_, shader, nir_lower_indirect_derefs, nir_var_ray_hit_attrib, UINT32_MAX);
+   NIR_PASS(_, shader, nir_split_array_vars, nir_var_ray_hit_attrib);
+
    NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
             nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
             glsl_get_natural_size_align_bytes);
 
    NIR_PASS(_, shader, lower_rt_derefs);
+   NIR_PASS(_, shader, lower_hit_attrib_derefs);
 
    NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp,
             nir_address_format_32bit_offset);
@@ -1004,6 +1098,7 @@ struct traversal_data {
    const VkRayTracingPipelineCreateInfoKHR *createInfo;
    struct rt_variables *vars;
    struct rt_traversal_vars *trav_vars;
+   nir_variable *barycentrics;
 };
 
 static void
@@ -1023,10 +1118,8 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
    nir_ssa_def *hit_kind =
       nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
 
-   nir_ssa_def *barycentrics_addr =
-      nir_iadd_imm_nuw(b, nir_load_var(b, data->vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET);
-   nir_ssa_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_addr, .align_mul = 16);
-   nir_store_scratch(b, intersection->barycentrics, barycentrics_addr, .align_mul = 16);
+   nir_ssa_def *prev_barycentrics = nir_load_var(b, data->barycentrics);
+   nir_store_var(b, data->barycentrics, intersection->barycentrics, 0x3);
 
    nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
@@ -1049,7 +1142,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
 
       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
       {
-         nir_store_scratch(b, prev_barycentrics, barycentrics_addr, .align_mul = 16);
+         nir_store_var(b, data->barycentrics, prev_barycentrics, 0x3);
          nir_jump(b, nir_jump_continue);
       }
       nir_pop_if(b, NULL);
@@ -1179,7 +1272,9 @@ build_traversal_shader(struct radv_device *device,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                        struct radv_pipeline_shader_stack_size *stack_sizes)
 {
-   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
+   /* Create the traversal shader as an intersection shader to prevent validation failures due to
+    * invalid variable modes. */
+   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
    b.shader->info.internal = false;
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
@@ -1187,6 +1282,10 @@ build_traversal_shader(struct radv_device *device,
       device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
    struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
 
+   nir_variable *barycentrics = nir_variable_create(
+      b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
+   barycentrics->data.driver_location = 0;
+
    /* initialize trace_ray arguments */
    nir_ssa_def *accel_struct = nir_load_accel_struct_amd(&b);
    nir_store_var(&b, vars.flags, nir_load_ray_flags(&b), 0x1);
@@ -1257,6 +1356,7 @@ build_traversal_shader(struct radv_device *device,
          .createInfo = pCreateInfo,
          .vars = &vars,
          .trav_vars = &trav_vars,
+         .barycentrics = barycentrics,
       };
 
       struct radv_ray_traversal_args args = {
@@ -1308,6 +1408,8 @@ build_traversal_shader(struct radv_device *device,
    NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
    NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
 
+   lower_hit_attrib_derefs(b.shader);
+
    return b.shader;
 }
 
@@ -1410,6 +1512,41 @@ move_rt_instructions(nir_shader *shader)
                          nir_metadata_all & (~nir_metadata_instr_index));
 }
 
+/* Replace load/store_hit_attrib_amd intrinsics with loads/stores of the
+ * caller-provided per-dword variables in hit_attribs (indexed by the
+ * intrinsic's .base). Remaining ray-hit-attribute variables are retyped to
+ * shader temps so the cleanup passes below can promote them to locals/SSA. */
+static void
+lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib)
+      attrib->data.mode = nir_var_shader_temp;
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_foreach_block (block, impl) {
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+
+         /* Emit the replacement after the intrinsic; the intrinsic itself is
+          * removed below once its uses are rewritten. */
+         b.cursor = nir_after_instr(instr);
+         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+         if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) {
+            nir_ssa_def *ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]);
+            nir_ssa_def_rewrite_uses(nir_instr_ssa_def(instr), ret);
+            nir_instr_remove(instr);
+         } else if (intrin->intrinsic == nir_intrinsic_store_hit_attrib_amd) {
+            nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
+            nir_instr_remove(instr);
+         }
+      }
+   }
+
+   nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
+
+   nir_lower_global_vars_to_local(shader);
+   nir_lower_vars_to_ssa(shader);
+}
+
 nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_pipeline_shader_stack_size *stack_sizes)
@@ -1429,6 +1566,11 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    else
       nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
 
+   nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
+   for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
+      hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader),
+                                                 glsl_uint_type(), "attribute");
+
    nir_loop *loop = nir_push_loop(&b);
 
    nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
@@ -1491,5 +1633,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
    nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
 
+   lower_hit_attribs(b.shader, hit_attribs);
+
    return b.shader;
 }