anv: support VK_PIPELINE_CREATE_RAY_TRACING_SKIP_*
authorIván Briano <ivan.briano@intel.com>
Mon, 26 Sep 2022 20:26:36 +0000 (13:26 -0700)
committerMarge Bot <emma+marge@anholt.net>
Thu, 20 Oct 2022 00:03:55 +0000 (00:03 +0000)
VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR and
VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR, when specified,
make TraceRay behave as if the corresponding shader flags were set, but
without affecting the value of IncomingRayFlags in shaders.

v2 (Lionel): Improve comments

Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19152>

src/intel/compiler/brw_compiler.h
src/intel/compiler/brw_nir_lower_rt_intrinsics.c
src/intel/compiler/brw_nir_lower_shader_calls.c
src/intel/compiler/brw_nir_rt.h
src/intel/vulkan/anv_pipeline.c

index 6a78965..ca9a797 100644 (file)
@@ -526,6 +526,12 @@ struct brw_cs_prog_key {
 
 struct brw_bs_prog_key {
    struct brw_base_prog_key base;
+
+   /* Represents enum enum brw_rt_ray_flags values given at pipeline creation
+    * to be combined with ray_flags handed to the traceRayEXT() calls by the
+    * shader.
+    */
+   uint32_t pipeline_ray_flags;
 };
 
 struct brw_ff_gs_prog_key {
index 4d793d0..fa46168 100644 (file)
@@ -221,7 +221,15 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
          }
 
          case nir_intrinsic_load_ray_flags:
-            sysval = nir_u2u32(b, world_ray_in.ray_flags);
+            /* We need to fetch the original ray flags we stored in the
+             * leaf pointer, because the actual ray flags we get here
+             * will include any flags passed on the pipeline at creation
+             * time, and the spec for IncomingRayFlagsKHR says:
+             *   Setting pipeline flags on the raytracing pipeline must not
+             *   cause any corresponding flags to be set in variables with
+             *   this decoration.
+             */
+            sysval = nir_u2u32(b, world_ray_in.inst_leaf_ptr);
             break;
 
          case nir_intrinsic_load_ray_geometry_index: {
index afb0fae..5bcc1f1 100644 (file)
@@ -153,6 +153,8 @@ store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
 static bool
 lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data)
 {
+   struct brw_bs_prog_key *key = data;
+
    if (instr->type != nir_instr_type_intrinsic)
       return false;
 
@@ -223,7 +225,10 @@ lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data
 
    struct brw_nir_rt_mem_ray_defs ray_defs = {
       .root_node_ptr = root_node_ptr,
-      .ray_flags = nir_u2u16(b, ray_flags),
+      /* Combine the shader value given to traceRayEXT() with the pipeline
+       * creation value VkPipelineCreateFlags.
+       */
+      .ray_flags = nir_ior_imm(b, nir_u2u16(b, ray_flags), key->pipeline_ray_flags),
       .ray_mask = cull_mask,
       .hit_group_sr_base_ptr = hit_sbt_addr,
       .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
@@ -233,6 +238,13 @@ lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data
       .dir = ray_dir,
       .t_far = ray_t_max,
       .shader_index_multiplier = sbt_stride,
+      /* The instance leaf pointer is unused in the top level BVH traversal
+       * since we always start from the root node. We can reuse that field to
+       * store the ray_flags handed to traceRayEXT(). This will be reloaded
+       * when the shader accesses gl_IncomingRayFlagsEXT (see
+       * nir_intrinsic_load_ray_flags brw_nir_lower_rt_intrinsic.c)
+       */
+      .inst_leaf_ptr = nir_u2u64(b, ray_flags),
    };
    brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
 
@@ -272,13 +284,13 @@ lower_shader_call_instr(struct nir_builder *b, nir_instr *instr, void *data)
 }
 
 bool
-brw_nir_lower_shader_calls(nir_shader *shader)
+brw_nir_lower_shader_calls(nir_shader *shader, struct brw_bs_prog_key *key)
 {
    return
       nir_shader_instructions_pass(shader,
                                    lower_shader_trace_ray_instr,
                                    nir_metadata_none,
-                                   NULL) |
+                                   key) |
       nir_shader_instructions_pass(shader,
                                    lower_shader_call_instr,
                                    nir_metadata_block_index |
index a3c6461..4215d34 100644 (file)
@@ -54,7 +54,7 @@ bool brw_nir_lower_ray_queries(nir_shader *shader,
 
 void brw_nir_lower_shader_returns(nir_shader *shader);
 
-bool brw_nir_lower_shader_calls(nir_shader *shader);
+bool brw_nir_lower_shader_calls(nir_shader *shader, struct brw_bs_prog_key *key);
 
 void brw_nir_lower_rt_intrinsics(nir_shader *shader,
                                  const struct intel_device_info *devinfo);
index 135020a..f2cdeb2 100644 (file)
@@ -639,11 +639,14 @@ populate_cs_prog_key(const struct anv_device *device,
 static void
 populate_bs_prog_key(const struct anv_device *device,
                      bool robust_buffer_access,
+                     uint32_t ray_flags,
                      struct brw_bs_prog_key *key)
 {
    memset(key, 0, sizeof(*key));
 
    populate_base_prog_key(device, robust_buffer_access, &key->base);
+
+   key->pipeline_ray_flags = ray_flags;
 }
 
 struct anv_pipeline_stage {
@@ -2466,12 +2469,12 @@ compile_upload_rt_shader(struct anv_ray_tracing_pipeline *pipeline,
                nir_address_format_64bit_global,
                BRW_BTD_STACK_ALIGN,
                &resume_shaders, &num_resume_shaders, mem_ctx);
-      NIR_PASS(_, nir, brw_nir_lower_shader_calls);
+      NIR_PASS(_, nir, brw_nir_lower_shader_calls, &stage->key.bs);
       NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);
    }
 
    for (unsigned i = 0; i < num_resume_shaders; i++) {
-      NIR_PASS(_,resume_shaders[i], brw_nir_lower_shader_calls);
+      NIR_PASS(_,resume_shaders[i], brw_nir_lower_shader_calls, &stage->key.bs);
       NIR_PASS_V(resume_shaders[i], brw_nir_lower_rt_intrinsics, devinfo);
    }
 
@@ -2578,6 +2581,25 @@ anv_pipeline_compute_ray_tracing_stacks(struct anv_ray_tracing_pipeline *pipelin
    }
 }
 
+static enum brw_rt_ray_flags
+anv_pipeline_get_pipeline_ray_flags(VkPipelineCreateFlags flags)
+{
+   uint32_t ray_flags = 0;
+
+   const bool rt_skip_triangles =
+      flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR;
+   const bool rt_skip_aabbs =
+      flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR;
+   assert(!(rt_skip_triangles && rt_skip_aabbs));
+
+   if (rt_skip_triangles)
+      ray_flags |= BRW_RT_RAY_FLAG_SKIP_TRIANGLES;
+   else if (rt_skip_aabbs)
+      ray_flags |= BRW_RT_RAY_FLAG_SKIP_AABBS;
+
+   return ray_flags;
+}
+
 static struct anv_pipeline_stage *
 anv_pipeline_init_ray_tracing_stages(struct anv_ray_tracing_pipeline *pipeline,
                                      const VkRayTracingPipelineCreateInfoKHR *info,
@@ -2591,6 +2613,9 @@ anv_pipeline_init_ray_tracing_stages(struct anv_ray_tracing_pipeline *pipeline,
    struct anv_pipeline_stage *stages =
       rzalloc_array(pipeline_ctx, struct anv_pipeline_stage, info->stageCount);
 
+   enum brw_rt_ray_flags ray_flags =
+      anv_pipeline_get_pipeline_ray_flags(pipeline->base.flags);
+
    for (uint32_t i = 0; i < info->stageCount; i++) {
       const VkPipelineShaderStageCreateInfo *sinfo = &info->pStages[i];
       if (vk_pipeline_shader_stage_is_null(sinfo))
@@ -2611,6 +2636,7 @@ anv_pipeline_init_ray_tracing_stages(struct anv_ray_tracing_pipeline *pipeline,
 
       populate_bs_prog_key(pipeline->base.device,
                            pipeline->base.device->robust_buffer_access,
+                           ray_flags,
                            &stages[i].key.bs);
 
       vk_pipeline_hash_shader_stage(sinfo, stages[i].shader_sha1);