nir/lower_shader_calls: add an option structure for future optimizations
authorLionel Landwerlin <lionel.g.landwerlin@intel.com>
Wed, 19 Oct 2022 13:33:20 +0000 (16:33 +0300)
committerMarge Bot <emma+marge@anholt.net>
Wed, 26 Oct 2022 12:53:25 +0000 (12:53 +0000)
Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16556>

src/amd/vulkan/radv_pipeline_rt.c
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_shader_calls.c
src/intel/vulkan/anv_pipeline.c

index 23dd710..0cb765e 100644 (file)
@@ -1620,9 +1620,13 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
        */
       NIR_PASS_V(nir_stage, move_rt_instructions);
 
+      const nir_lower_shader_calls_options opts = {
+         .address_format = nir_address_format_32bit_offset,
+         .stack_alignment = 16,
+      };
       uint32_t num_resume_shaders = 0;
       nir_shader **resume_shaders = NULL;
-      nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
+      nir_lower_shader_calls(nir_stage, &opts, &resume_shaders,
                              &num_resume_shaders, nir_stage);
 
       vars.stage_idx = i;
index 66c3047..c77cd52 100644 (file)
@@ -4834,10 +4834,17 @@ bool nir_lower_explicit_io(nir_shader *shader,
                            nir_variable_mode modes,
                            nir_address_format);
 
+typedef struct nir_lower_shader_calls_options {
+   /* Address format used for load/store operations on the call stack. */
+   nir_address_format address_format;
+
+   /* Stack alignment */
+   unsigned stack_alignment;
+} nir_lower_shader_calls_options;
+
 bool
 nir_lower_shader_calls(nir_shader *shader,
-                       nir_address_format address_format,
-                       unsigned stack_alignment,
+                       const nir_lower_shader_calls_options *options,
                        nir_shader ***resume_shaders_out,
                        uint32_t *num_resume_shaders_out,
                        void *mem_ctx);
index 8ec5a4b..e34c8c0 100644 (file)
@@ -1424,8 +1424,7 @@ nir_opt_remove_respills(nir_shader *shader)
  */
 bool
 nir_lower_shader_calls(nir_shader *shader,
-                       nir_address_format address_format,
-                       unsigned stack_alignment,
+                       const nir_lower_shader_calls_options *options,
                        nir_shader ***resume_shaders_out,
                        uint32_t *num_resume_shaders_out,
                        void *mem_ctx)
@@ -1461,7 +1460,7 @@ nir_lower_shader_calls(nir_shader *shader,
    }
 
    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
-              num_calls, stack_alignment);
+              num_calls, options->stack_alignment);
 
    NIR_PASS_V(shader, nir_opt_remove_phis);
 
@@ -1494,9 +1493,12 @@ nir_lower_shader_calls(nir_shader *shader,
    for (unsigned i = 0; i < num_calls; i++)
       NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
 
-   NIR_PASS_V(shader, nir_lower_stack_to_scratch, address_format);
-   for (unsigned i = 0; i < num_calls; i++)
-      NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch, address_format);
+   NIR_PASS_V(shader, nir_lower_stack_to_scratch,
+              options->address_format);
+   for (unsigned i = 0; i < num_calls; i++) {
+      NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
+                 options->address_format);
+   }
 
    *resume_shaders_out = resume_shaders;
    *num_resume_shaders_out = num_calls;
index f2cdeb2..b20efaa 100644 (file)
@@ -2465,9 +2465,12 @@ compile_upload_rt_shader(struct anv_ray_tracing_pipeline *pipeline,
    nir_shader **resume_shaders = NULL;
    uint32_t num_resume_shaders = 0;
    if (nir->info.stage != MESA_SHADER_COMPUTE) {
-      NIR_PASS(_, nir, nir_lower_shader_calls,
-               nir_address_format_64bit_global,
-               BRW_BTD_STACK_ALIGN,
+      const nir_lower_shader_calls_options opts = {
+         .address_format = nir_address_format_64bit_global,
+         .stack_alignment = BRW_BTD_STACK_ALIGN,
+      };
+
+      NIR_PASS(_, nir, nir_lower_shader_calls, &opts,
                &resume_shaders, &num_resume_shaders, mem_ctx);
       NIR_PASS(_, nir, brw_nir_lower_shader_calls, &stage->key.bs);
       NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);