radv: move assigning driver locations to the new linking helpers
authorSamuel Pitoiset <samuel.pitoiset@gmail.com>
Fri, 19 Aug 2022 09:40:51 +0000 (11:40 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 22 Aug 2022 13:45:28 +0000 (13:45 +0000)
Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-By: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18138>

src/amd/vulkan/radv_pipeline.c

index 07f6eaf..841a5ce 100644 (file)
@@ -2784,6 +2784,28 @@ radv_pipeline_link_vs(const struct radv_device *device, struct radv_pipeline_sta
    }
 
    radv_pipeline_link_shaders(device, vs_stage->nir, next_stage->nir, pipeline_key);
+
+   nir_foreach_shader_in_variable(var, vs_stage->nir) {
+      var->data.driver_location = var->data.location;
+   }
+
+   if (next_stage->nir->info.stage == MESA_SHADER_TESS_CTRL) {
+      nir_linked_io_var_info vs2tcs =
+         nir_assign_linked_io_var_locations(vs_stage->nir, next_stage->nir);
+
+      vs_stage->info.vs.num_linked_outputs = vs2tcs.num_linked_io_vars;
+      next_stage->info.tcs.num_linked_inputs = vs2tcs.num_linked_io_vars;
+   } else if (next_stage->nir->info.stage == MESA_SHADER_GEOMETRY) {
+      nir_linked_io_var_info vs2gs =
+         nir_assign_linked_io_var_locations(vs_stage->nir, next_stage->nir);
+
+      vs_stage->info.vs.num_linked_outputs = vs2gs.num_linked_io_vars;
+      next_stage->info.gs.num_linked_inputs = vs2gs.num_linked_io_vars;
+   } else {
+      nir_foreach_shader_out_variable(var, vs_stage->nir) {
+         var->data.driver_location = var->data.location;
+      }
+   }
 }
 
 static void
@@ -2795,6 +2817,14 @@ radv_pipeline_link_tcs(const struct radv_device *device, struct radv_pipeline_st
    assert(tes_stage->nir->info.stage == MESA_SHADER_TESS_EVAL);
 
    radv_pipeline_link_shaders(device, tcs_stage->nir, tes_stage->nir, pipeline_key);
+
+   nir_linked_io_var_info tcs2tes =
+      nir_assign_linked_io_var_locations(tcs_stage->nir, tes_stage->nir);
+
+   tcs_stage->info.tcs.num_linked_outputs = tcs2tes.num_linked_io_vars;
+   tcs_stage->info.tcs.num_linked_patch_outputs = tcs2tes.num_linked_patch_io_vars;
+   tes_stage->info.tes.num_linked_inputs = tcs2tes.num_linked_io_vars;
+   tes_stage->info.tes.num_linked_patch_inputs = tcs2tes.num_linked_patch_io_vars;
 }
 
 static void
@@ -2811,6 +2841,18 @@ radv_pipeline_link_tes(const struct radv_device *device, struct radv_pipeline_st
    }
 
    radv_pipeline_link_shaders(device, tes_stage->nir, next_stage->nir, pipeline_key);
+
+   if (next_stage->nir->info.stage == MESA_SHADER_GEOMETRY) {
+      nir_linked_io_var_info tes2gs =
+         nir_assign_linked_io_var_locations(tes_stage->nir, next_stage->nir);
+
+      tes_stage->info.tes.num_linked_outputs = tes2gs.num_linked_io_vars;
+      next_stage->info.gs.num_linked_inputs = tes2gs.num_linked_io_vars;
+   } else {
+      nir_foreach_shader_out_variable(var, tes_stage->nir) {
+         var->data.driver_location = var->data.location;
+      }
+   }
 }
 
 static void
@@ -2822,6 +2864,10 @@ radv_pipeline_link_gs(const struct radv_device *device, struct radv_pipeline_sta
    assert(fs_stage->nir->info.stage == MESA_SHADER_FRAGMENT);
 
    radv_pipeline_link_shaders(device, gs_stage->nir, fs_stage->nir, pipeline_key);
+
+   nir_foreach_shader_out_variable(var, gs_stage->nir) {
+      var->data.driver_location = var->data.location;
+   }
 }
 
 static void
@@ -2854,6 +2900,13 @@ radv_pipeline_link_mesh(const struct radv_device *device, struct radv_pipeline_s
    }
 
    radv_pipeline_link_shaders(device, mesh_stage->nir, fs_stage->nir, pipeline_key);
+
+   /* ac_nir_lower_ngg ignores driver locations for mesh shaders, but set them to all zero just to
+    * be on the safe side.
+    */
+   nir_foreach_shader_out_variable(var, mesh_stage->nir) {
+      var->data.driver_location = 0;
+   }
 }
 
 static void
@@ -2890,6 +2943,10 @@ radv_pipeline_link_fs(struct radv_pipeline_stage *fs_stage,
       NIR_PASS(_, fs_stage->nir, nir_remove_dead_variables, nir_var_shader_temp, NULL);
       NIR_PASS(_, fs_stage->nir, nir_opt_dce);
    }
+
+   nir_foreach_shader_out_variable(var, fs_stage->nir) {
+      var->data.driver_location = var->data.location + var->data.index;
+   }
 }
 
 static void
@@ -2936,72 +2993,6 @@ radv_graphics_pipeline_link(const struct radv_pipeline *pipeline,
    }
 }
 
-static void
-radv_set_driver_locations(struct radv_pipeline *pipeline, struct radv_pipeline_stage *stages,
-                          gl_shader_stage last_vgt_api_stage)
-{
-   if (stages[MESA_SHADER_FRAGMENT].nir) {
-      nir_foreach_shader_out_variable(var, stages[MESA_SHADER_FRAGMENT].nir)
-      {
-         var->data.driver_location = var->data.location + var->data.index;
-      }
-   }
-
-   if (stages[MESA_SHADER_MESH].nir) {
-      /* ac_nir_lower_ngg ignores driver locations for mesh shaders,
-       * but set them to all zero just to be on the safe side.
-       */
-      nir_foreach_shader_out_variable(var, stages[MESA_SHADER_MESH].nir) {
-         var->data.driver_location = 0;
-      }
-      return;
-   }
-
-   if (!stages[MESA_SHADER_VERTEX].nir)
-      return;
-
-   bool has_tess = stages[MESA_SHADER_TESS_CTRL].nir;
-   bool has_gs = stages[MESA_SHADER_GEOMETRY].nir;
-
-   nir_foreach_shader_in_variable (var, stages[MESA_SHADER_VERTEX].nir) {
-      var->data.driver_location = var->data.location;
-   }
-
-   if (has_tess) {
-      nir_linked_io_var_info vs2tcs = nir_assign_linked_io_var_locations(
-         stages[MESA_SHADER_VERTEX].nir, stages[MESA_SHADER_TESS_CTRL].nir);
-      nir_linked_io_var_info tcs2tes = nir_assign_linked_io_var_locations(
-         stages[MESA_SHADER_TESS_CTRL].nir, stages[MESA_SHADER_TESS_EVAL].nir);
-
-      stages[MESA_SHADER_VERTEX].info.vs.num_linked_outputs = vs2tcs.num_linked_io_vars;
-      stages[MESA_SHADER_TESS_CTRL].info.tcs.num_linked_inputs = vs2tcs.num_linked_io_vars;
-      stages[MESA_SHADER_TESS_CTRL].info.tcs.num_linked_outputs = tcs2tes.num_linked_io_vars;
-      stages[MESA_SHADER_TESS_CTRL].info.tcs.num_linked_patch_outputs = tcs2tes.num_linked_patch_io_vars;
-      stages[MESA_SHADER_TESS_EVAL].info.tes.num_linked_inputs = tcs2tes.num_linked_io_vars;
-      stages[MESA_SHADER_TESS_EVAL].info.tes.num_linked_patch_inputs = tcs2tes.num_linked_patch_io_vars;
-
-      if (has_gs) {
-         nir_linked_io_var_info tes2gs = nir_assign_linked_io_var_locations(
-            stages[MESA_SHADER_TESS_EVAL].nir, stages[MESA_SHADER_GEOMETRY].nir);
-
-         stages[MESA_SHADER_TESS_EVAL].info.tes.num_linked_outputs = tes2gs.num_linked_io_vars;
-         stages[MESA_SHADER_GEOMETRY].info.gs.num_linked_inputs = tes2gs.num_linked_io_vars;
-      }
-   } else if (has_gs) {
-      nir_linked_io_var_info vs2gs = nir_assign_linked_io_var_locations(
-         stages[MESA_SHADER_VERTEX].nir, stages[MESA_SHADER_GEOMETRY].nir);
-
-      stages[MESA_SHADER_VERTEX].info.vs.num_linked_outputs = vs2gs.num_linked_io_vars;
-      stages[MESA_SHADER_GEOMETRY].info.gs.num_linked_inputs = vs2gs.num_linked_io_vars;
-   }
-
-   assert(last_vgt_api_stage != MESA_SHADER_NONE);
-   nir_foreach_shader_out_variable(var, stages[last_vgt_api_stage].nir)
-   {
-      var->data.driver_location = var->data.location;
-   }
-}
-
 static struct radv_pipeline_key
 radv_generate_pipeline_key(const struct radv_pipeline *pipeline, VkPipelineCreateFlags flags)
 {
@@ -4717,7 +4708,6 @@ radv_create_shaders(struct radv_pipeline *pipeline, struct radv_pipeline_layout
    }
 
    radv_graphics_pipeline_link(pipeline, pipeline_key, stages);
-   radv_set_driver_locations(pipeline, stages, *last_vgt_api_stage);
 
    for (int i = 0; i < MESA_VULKAN_SHADER_STAGES; ++i) {
       if (stages[i].nir) {