radv/rt: use prolog for raytracing shaders
author Daniel Schürmann <daniel@schuermann.dev>
Tue, 21 Feb 2023 16:37:04 +0000 (17:37 +0100)
committer Marge Bot <emma+marge@anholt.net>
Thu, 16 Mar 2023 01:40:30 +0000 (01:40 +0000)
Co-authored-by: Friedrich Vock <friedrich.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21780>

src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_rt_shader.c
src/amd/vulkan/radv_shader_args.c

index 5e0c9ca..32f936d 100644 (file)
@@ -611,7 +611,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    VkPipelineShaderStageCreateInfo stage = {
       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
       .pNext = NULL,
-      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
+      .stage = VK_SHADER_STAGE_RAYGEN_BIT_KHR,
       .module = vk_shader_module_to_handle(&module),
       .pName = "main",
    };
@@ -664,13 +664,18 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
          goto shader_fail;
    }
 
-   radv_compute_pipeline_init(&rt_pipeline->base, pipeline_layout);
-
    rt_pipeline->stack_size = compute_rt_stack_size(pCreateInfo, rt_pipeline->groups);
    rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE] = radv_create_rt_prolog(device);
 
-   *pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base);
+   combine_config(&rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config,
+                  &rt_pipeline->base.base.shaders[MESA_SHADER_RAYGEN]->config);
+
+   postprocess_rt_config(&rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config,
+                         device->physical_device->rt_wave_size);
 
+   radv_compute_pipeline_init(&rt_pipeline->base, pipeline_layout);
+
+   *pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base);
 shader_fail:
    ralloc_free(shader);
 pipeline_fail:
index 9edaa6c..0399021 100644 (file)
@@ -97,9 +97,6 @@ struct rt_variables {
    /* global address of the SBT entry used for the shader */
    nir_variable *shader_record_ptr;
 
-   nir_variable *launch_size;
-   nir_variable *launch_id;
-
    /* trace_ray arguments */
    nir_variable *accel_struct;
    nir_variable *cull_mask_and_flags;
@@ -137,10 +134,6 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR
    vars.shader_record_ptr =
       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
 
-   const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3);
-   vars.launch_size = nir_variable_create(shader, nir_var_shader_temp, uvec3_type, "launch_size");
-   vars.launch_id = nir_variable_create(shader, nir_var_shader_temp, uvec3_type, "launch_id");
-
    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
    vars.accel_struct =
       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
@@ -187,8 +180,6 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
    _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
-   _mesa_hash_table_insert(var_remap, src->launch_size, dst->launch_size);
-   _mesa_hash_table_insert(var_remap, src->launch_id, dst->launch_id);
 
    _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
    _mesa_hash_table_insert(var_remap, src->cull_mask_and_flags, dst->cull_mask_and_flags);
@@ -403,14 +394,6 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                ret = nir_load_var(&b_shader, vars->shader_record_ptr);
                break;
             }
-            case nir_intrinsic_load_ray_launch_id: {
-               ret = nir_load_var(&b_shader, vars->launch_id);
-               break;
-            }
-            case nir_intrinsic_load_ray_launch_size: {
-               ret = nir_load_var(&b_shader, vars->launch_size);
-               break;
-            }
             case nir_intrinsic_load_ray_t_min: {
                ret = nir_load_var(&b_shader, vars->tmin);
                break;
@@ -1429,8 +1412,6 @@ build_traversal_shader(struct radv_device *device,
    nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
    nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
    nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
-   nir_store_var(&b, vars.launch_size, nir_load_ray_launch_size(&b), 0x7);
-   nir_store_var(&b, vars.launch_id, nir_load_ray_launch_id(&b), 0x7);
 
    struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
 
@@ -1594,7 +1575,7 @@ nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_ray_tracing_module *groups, const struct radv_pipeline_key *key)
 {
-   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
+   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, "rt_combined");
    b.shader->info.internal = false;
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
@@ -1604,17 +1585,6 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, SBT_GENERAL_IDX);
    nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1);
 
-   nir_store_var(&b, vars.launch_id, nir_load_global_invocation_id(&b, 32), 0x7);
-   nir_ssa_def *launch_size_addr = nir_load_ray_launch_size_addr_amd(&b);
-   nir_ssa_def *xy = nir_build_load_smem_amd(&b, 2, launch_size_addr, nir_imm_int(&b, 0));
-   nir_ssa_def *z = nir_build_load_smem_amd(&b, 1, launch_size_addr, nir_imm_int(&b, 8));
-   nir_ssa_def *xyz[3] = {
-      nir_channel(&b, xy, 0),
-      nir_channel(&b, xy, 1),
-      z,
-   };
-   nir_store_var(&b, vars.launch_size, nir_vec(&b, xyz, 3), 0x7);
-
    nir_loop *loop = nir_push_loop(&b);
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
 
index 456dc31..5b9ba75 100644 (file)
@@ -638,10 +638,6 @@ radv_declare_shader_args(const struct radv_device *device, const struct radv_pip
    case MESA_SHADER_TASK:
       declare_global_input_sgprs(info, &user_sgpr_info, args);
 
-      if (info->cs.is_rt_shader) {
-         ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.sbt_descriptors);
-      }
-
       if (info->cs.uses_grid_size) {
          if (args->load_grid_size_from_user_sgpr)
             ac_add_arg(&args->ac, AC_ARG_SGPR, 3, AC_ARG_INT, &args->ac.num_work_groups);
@@ -649,15 +645,13 @@ radv_declare_shader_args(const struct radv_device *device, const struct radv_pip
             ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.num_work_groups);
       }
 
-      if (info->cs.uses_ray_launch_size) {
+      if (info->cs.is_rt_shader) {
+         ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_DESC_PTR, &args->ac.sbt_descriptors);
          ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.ray_launch_size_addr);
-      }
-
-      if (info->cs.uses_dynamic_rt_callable_stack) {
-         ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
-                    &args->ac.rt_dynamic_callable_stack_base);
          ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR,
                     &args->ac.rt_traversal_shader_addr);
+         ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                    &args->ac.rt_dynamic_callable_stack_base);
       }
 
       if (info->vs.needs_draw_id) {
@@ -934,12 +928,12 @@ radv_declare_shader_args(const struct radv_device *device, const struct radv_pip
       if (args->ac.ray_launch_size_addr.used) {
          set_loc_shader_ptr(args, AC_UD_CS_RAY_LAUNCH_SIZE_ADDR, &user_sgpr_idx);
       }
-      if (args->ac.rt_dynamic_callable_stack_base.used) {
-         set_loc_shader(args, AC_UD_CS_RAY_DYNAMIC_CALLABLE_STACK_BASE, &user_sgpr_idx, 1);
-      }
       if (args->ac.rt_traversal_shader_addr.used) {
          set_loc_shader_ptr(args, AC_UD_CS_TRAVERSAL_SHADER_ADDR, &user_sgpr_idx);
       }
+      if (args->ac.rt_dynamic_callable_stack_base.used) {
+         set_loc_shader(args, AC_UD_CS_RAY_DYNAMIC_CALLABLE_STACK_BASE, &user_sgpr_idx, 1);
+      }
       if (args->ac.draw_id.used) {
          set_loc_shader(args, AC_UD_CS_TASK_DRAW_ID, &user_sgpr_idx, 1);
       }