intel/fs: fix metadata preservation on trace_ray intrinsic
author Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Tue, 12 Apr 2022 18:59:58 +0000 (21:59 +0300)
committer Marge Bot <emma+marge@anholt.net>
Wed, 13 Apr 2022 11:24:49 +0000 (11:24 +0000)
c78be5da300 ("intel/fs: lower ray query intrinsics") introduced a
helper function using nir_(push|pop)_if, which invalidates the
dominance & block_index metadata when replacing
nir_intrinsic_rt_trace_ray.
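
For illustration, a minimal sketch (not part of this change; the
condition and body are hypothetical) of why such a helper breaks those
analyses: nir_push_if() ends the block at the builder's cursor and
inserts new blocks, so block_index/dominance info computed before the
pass no longer matches the shader:

    static void
    emit_guarded_code(nir_builder *b, nir_ssa_def *cond)
    {
       nir_if *nif = nir_push_if(b, cond); /* splits the current block */
       /* ... conditionally executed instructions ... */
       nir_pop_if(b, nif);                 /* opens a fresh block after the if */
    }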

We can still keep the dominance/block_index metadata for the lowering
of nir_intrinsic_rt_execute_callable though, since that lowering only
replaces the intrinsic with straight-line code.

This change therefore splits the lowering into two functions, each
declaring the metadata it actually preserves.
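
For reference, the metadata contract behind this (names from nir.h): a
callback run through nir_shader_instructions_pass() declares which
analyses it keeps valid. A sketch, with hypothetical callbacks, of the
pattern used below:

    bool progress =
       nir_shader_instructions_pass(shader, adds_control_flow_cb,
                                    nir_metadata_none, NULL) |
       nir_shader_instructions_pass(shader, replaces_in_place_cb,
                                    nir_metadata_block_index |
                                    nir_metadata_dominance, NULL);

Note the bitwise '|' rather than '||' so the second pass still runs
when the first one already reported progress.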

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Fixes: c78be5da300 ("intel/fs: lower ray query intrinsics")
Reviewed-by: Marcin Ślusarz <marcin.slusarz@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15910>

src/intel/compiler/brw_nir_lower_shader_calls.c

index fa04851059029c08869fef9f489882ebc46f1e2b..23a09c544730f54e780ee3062e4b410441f36b6c 100644 (file)
@@ -125,7 +125,7 @@ store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
 }
 
 static bool
-lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
+lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data)
 {
    if (instr->type != nir_instr_type_intrinsic)
       return false;
@@ -134,117 +134,130 @@ lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
     * brw_nir_lower_rt_intrinsics()
     */
    nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+   if (call->intrinsic != nir_intrinsic_rt_trace_ray)
+      return false;
 
-   switch (call->intrinsic) {
-   case nir_intrinsic_rt_trace_ray: {
-      b->cursor = nir_instr_remove(instr);
-
-      store_resume_addr(b, call);
-
-      nir_ssa_def *as_addr = call->src[0].ssa;
-      nir_ssa_def *ray_flags = call->src[1].ssa;
-      /* From the SPIR-V spec:
-       *
-       *    "Only the 8 least-significant bits of Cull Mask are used by this
-       *    instruction - other bits are ignored.
-       *
-       *    Only the 4 least-significant bits of SBT Offset and SBT Stride are
-       *    used by this instruction - other bits are ignored.
-       *
-       *    Only the 16 least-significant bits of Miss Index are used by this
-       *    instruction - other bits are ignored."
-       */
-      nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
-      nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
-      nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
-      nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
-      nir_ssa_def *ray_orig = call->src[6].ssa;
-      nir_ssa_def *ray_t_min = call->src[7].ssa;
-      nir_ssa_def *ray_dir = call->src[8].ssa;
-      nir_ssa_def *ray_t_max = call->src[9].ssa;
-
-      nir_ssa_def *root_node_ptr =
-         brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
-
-      /* The hardware packet requires an address to the first element of the
-       * hit SBT.
-       *
-       * In order to calculate this, we must multiply the "SBT Offset"
-       * provided to OpTraceRay by the SBT stride provided for the hit SBT in
-       * the call to vkCmdTraceRay() and add that to the base address of the
-       * hit SBT. This stride is not to be confused with the "SBT Stride"
-       * provided to OpTraceRay which is in units of this stride. It's a
-       * rather terrible overload of the word "stride". The hardware docs
-       * calls the SPIR-V stride value the "shader index multiplier" which is
-       * a much more sane name.
-       */
-      nir_ssa_def *hit_sbt_stride_B =
-         nir_load_ray_hit_sbt_stride_intel(b);
-      nir_ssa_def *hit_sbt_offset_B =
-         nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
-      nir_ssa_def *hit_sbt_addr =
-         nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
-                     nir_u2u64(b, hit_sbt_offset_B));
-
-      /* The hardware packet takes an address to the miss BSR. */
-      nir_ssa_def *miss_sbt_stride_B =
-         nir_load_ray_miss_sbt_stride_intel(b);
-      nir_ssa_def *miss_sbt_offset_B =
-         nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
-      nir_ssa_def *miss_sbt_addr =
-         nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
-                     nir_u2u64(b, miss_sbt_offset_B));
-
-      struct brw_nir_rt_mem_ray_defs ray_defs = {
-         .root_node_ptr = root_node_ptr,
-         .ray_flags = nir_u2u16(b, ray_flags),
-         .ray_mask = cull_mask,
-         .hit_group_sr_base_ptr = hit_sbt_addr,
-         .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
-         .miss_sr_ptr = miss_sbt_addr,
-         .orig = ray_orig,
-         .t_near = ray_t_min,
-         .dir = ray_dir,
-         .t_far = ray_t_max,
-         .shader_index_multiplier = sbt_stride,
-      };
-      brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
-      nir_trace_ray_intel(b,
-                          nir_load_btd_global_arg_addr_intel(b),
-                          nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD),
-                          nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
-                          .synchronous = false);
-      return true;
-   }
+   b->cursor = nir_instr_remove(instr);
 
-   case nir_intrinsic_rt_execute_callable: {
-      b->cursor = nir_instr_remove(instr);
+   store_resume_addr(b, call);
 
-      store_resume_addr(b, call);
+   nir_ssa_def *as_addr = call->src[0].ssa;
+   nir_ssa_def *ray_flags = call->src[1].ssa;
+   /* From the SPIR-V spec:
+    *
+    *    "Only the 8 least-significant bits of Cull Mask are used by this
+    *    instruction - other bits are ignored.
+    *
+    *    Only the 4 least-significant bits of SBT Offset and SBT Stride are
+    *    used by this instruction - other bits are ignored.
+    *
+    *    Only the 16 least-significant bits of Miss Index are used by this
+    *    instruction - other bits are ignored."
+    */
+   nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
+   nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
+   nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
+   nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
+   nir_ssa_def *ray_orig = call->src[6].ssa;
+   nir_ssa_def *ray_t_min = call->src[7].ssa;
+   nir_ssa_def *ray_dir = call->src[8].ssa;
+   nir_ssa_def *ray_t_max = call->src[9].ssa;
+
+   nir_ssa_def *root_node_ptr =
+      brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
+
+   /* The hardware packet requires an address to the first element of the
+    * hit SBT.
+    *
+    * In order to calculate this, we must multiply the "SBT Offset"
+    * provided to OpTraceRay by the SBT stride provided for the hit SBT in
+    * the call to vkCmdTraceRay() and add that to the base address of the
+    * hit SBT. This stride is not to be confused with the "SBT Stride"
+    * provided to OpTraceRay which is in units of this stride. It's a
+    * rather terrible overload of the word "stride". The hardware docs
+    * calls the SPIR-V stride value the "shader index multiplier" which is
+    * a much more sane name.
+    */
+   nir_ssa_def *hit_sbt_stride_B =
+      nir_load_ray_hit_sbt_stride_intel(b);
+   nir_ssa_def *hit_sbt_offset_B =
+      nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
+   nir_ssa_def *hit_sbt_addr =
+      nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
+                  nir_u2u64(b, hit_sbt_offset_B));
+
+   /* The hardware packet takes an address to the miss BSR. */
+   nir_ssa_def *miss_sbt_stride_B =
+      nir_load_ray_miss_sbt_stride_intel(b);
+   nir_ssa_def *miss_sbt_offset_B =
+      nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
+   nir_ssa_def *miss_sbt_addr =
+      nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
+                  nir_u2u64(b, miss_sbt_offset_B));
+
+   struct brw_nir_rt_mem_ray_defs ray_defs = {
+      .root_node_ptr = root_node_ptr,
+      .ray_flags = nir_u2u16(b, ray_flags),
+      .ray_mask = cull_mask,
+      .hit_group_sr_base_ptr = hit_sbt_addr,
+      .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
+      .miss_sr_ptr = miss_sbt_addr,
+      .orig = ray_orig,
+      .t_near = ray_t_min,
+      .dir = ray_dir,
+      .t_far = ray_t_max,
+      .shader_index_multiplier = sbt_stride,
+   };
+   brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
+
+   nir_trace_ray_intel(b,
+                       nir_load_btd_global_arg_addr_intel(b),
+                       nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD),
+                       nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
+                       .synchronous = false);
+   return true;
+}
 
-      nir_ssa_def *sbt_offset32 =
-         nir_imul(b, call->src[0].ssa,
-                     nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
-      nir_ssa_def *sbt_addr =
-         nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
-                     nir_u2u64(b, sbt_offset32));
-      brw_nir_btd_spawn(b, sbt_addr);
-      return true;
-   }
+static bool
+lower_shader_call_instr(struct nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
 
-   default:
+   /* Leave nir_intrinsic_rt_resume to be lowered by
+    * brw_nir_lower_rt_intrinsics()
+    */
+   nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+   if (call->intrinsic != nir_intrinsic_rt_execute_callable)
       return false;
-   }
+
+   b->cursor = nir_instr_remove(instr);
+
+   store_resume_addr(b, call);
+
+   nir_ssa_def *sbt_offset32 =
+      nir_imul(b, call->src[0].ssa,
+               nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
+   nir_ssa_def *sbt_addr =
+      nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
+               nir_u2u64(b, sbt_offset32));
+   brw_nir_btd_spawn(b, sbt_addr);
+   return true;
 }
 
 bool
 brw_nir_lower_shader_calls(nir_shader *shader)
 {
-   return nir_shader_instructions_pass(shader,
-                                       lower_shader_calls_instr,
-                                       nir_metadata_block_index |
-                                       nir_metadata_dominance,
-                                       NULL);
+   return
+      nir_shader_instructions_pass(shader,
+                                   lower_shader_trace_ray_instr,
+                                   nir_metadata_none,
+                                   NULL) |
+      nir_shader_instructions_pass(shader,
+                                   lower_shader_call_instr,
+                                   nir_metadata_block_index |
+                                   nir_metadata_dominance,
+                                   NULL);
 }
 
 /** Creates a trivial return shader