radv/rt: Handle no-null shader flags
authorKonstantin Seurer <konstantin.seurer@gmail.com>
Mon, 12 Dec 2022 19:26:28 +0000 (20:26 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 13 Dec 2022 23:30:28 +0000 (23:30 +0000)
If those flags are set, we can assume that idx is not 0.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20286>

src/amd/vulkan/radv_rt_shader.c

index 5976715..4450620 100644 (file)
@@ -597,10 +597,15 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_HIT, SBT_CLOSEST_HIT_IDX);
 
                nir_ssa_def *should_return =
-                  nir_ior(&b_shader,
-                          nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags),
-                                        SpvRayFlagsSkipClosestHitShaderKHRMask),
-                          nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
+                  nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags),
+                                SpvRayFlagsSkipClosestHitShaderKHRMask);
+
+               if (!(vars->create_info->flags &
+                     VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
+                  should_return =
+                     nir_ior(&b_shader, should_return,
+                             nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
+               }
 
                /* should_return is set if we had a hit but we won't be calling the closest hit
                 * shader and hence need to return immediately to the calling shader. */
@@ -619,10 +624,15 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                nir_ssa_def *miss_index = nir_load_var(&b_shader, vars->miss_index);
                load_sbt_entry(&b_shader, vars, miss_index, SBT_MISS, SBT_GENERAL_IDX);
 
-               /* In case of a NULL miss shader, do nothing and just return. */
-               nir_push_if(&b_shader, nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
-               insert_rt_return(&b_shader, vars);
-               nir_pop_if(&b_shader, NULL);
+               if (!(vars->create_info->flags &
+                     VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
+                  /* In case of a NULL miss shader, do nothing and just return. */
+                  nir_push_if(&b_shader,
+                              nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
+                  insert_rt_return(&b_shader, vars);
+                  nir_pop_if(&b_shader, NULL);
+               }
+
                break;
             }
             default:
@@ -1078,7 +1088,9 @@ visit_any_hit_shaders(struct radv_device *device,
 {
    nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
 
-   nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+   if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
+      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
@@ -1099,7 +1111,9 @@ visit_any_hit_shaders(struct radv_device *device,
       vars->stage_idx = shader_id;
       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
    }
-   nir_pop_if(b, NULL);
+
+   if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
+      nir_pop_if(b, NULL);
 }
 
 struct traversal_data {
@@ -1206,7 +1220,10 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
    nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
 
-   nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
+   if (!(data->vars->create_info->flags &
+         VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
+      nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
+
    for (unsigned i = 0; i < data->createInfo->groupCount; ++i) {
       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &data->createInfo->pGroups[i];
       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
@@ -1238,7 +1255,10 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       inner_vars.stage_idx = shader_id;
       insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
    }
-   nir_pop_if(b, NULL);
+
+   if (!(data->vars->create_info->flags &
+         VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
+      nir_pop_if(b, NULL);
 
    nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
    {