intel/fs: fix shader call lowering pass
authorLionel Landwerlin <lionel.g.landwerlin@intel.com>
Tue, 29 Jun 2021 09:40:39 +0000 (12:40 +0300)
committerMarge Bot <emma+marge@anholt.net>
Mon, 22 Nov 2021 08:17:26 +0000 (08:17 +0000)
Now that we removed the intel intrinsic (btd_resume_intel) and just use
the generic one (nir_intrinsic_rt_resume), we can skip that intrinsic in
the intel shader call lowering pass and deal with it in the intel RT
intrinsic lowering instead.

v2: rewrite with nir_shader_instructions_pass() (Jason)

v3: handle everything in switch (Jason)

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Fixes: 423c47de991643 ("nir: drop the btd_resume_intel intrinsic")
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12113>

src/intel/compiler/brw_nir_lower_shader_calls.c

index 4f88f10..38c4e0a 100644 (file)
@@ -124,143 +124,124 @@ store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
    nir_btd_stack_push_intel(b, offset);
 }
 
-bool
-brw_nir_lower_shader_calls(nir_shader *shader)
+static bool
+lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
 {
-   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-   bool progress = false;
-
-   nir_builder _b, *b = &_b;
-   nir_builder_init(&_b, impl);
-
-   nir_foreach_block_safe(block, impl) {
-      nir_foreach_instr_safe(instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
-         nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
-         if (call->intrinsic != nir_intrinsic_rt_trace_ray &&
-             call->intrinsic != nir_intrinsic_rt_execute_callable &&
-             call->intrinsic != nir_intrinsic_rt_resume)
-            continue;
-
-         b->cursor = nir_before_instr(instr);
-
-         progress = true;
-
-         switch (call->intrinsic) {
-         case nir_intrinsic_rt_trace_ray: {
-            store_resume_addr(b, call);
-
-            nir_ssa_def *as_addr = call->src[0].ssa;
-            nir_ssa_def *ray_flags = call->src[1].ssa;
-            /* From the SPIR-V spec:
-             *
-             *    "Only the 8 least-significant bits of Cull Mask are used by
-             *    this instruction - other bits are ignored.
-             *
-             *    Only the 4 least-significant bits of SBT Offset and SBT
-             *    Stride are used by this instruction - other bits are
-             *    ignored.
-             *
-             *    Only the 16 least-significant bits of Miss Index are used by
-             *    this instruction - other bits are ignored."
-             */
-            nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
-            nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
-            nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
-            nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
-            nir_ssa_def *ray_orig = call->src[6].ssa;
-            nir_ssa_def *ray_t_min = call->src[7].ssa;
-            nir_ssa_def *ray_dir = call->src[8].ssa;
-            nir_ssa_def *ray_t_max = call->src[9].ssa;
-
-            /* The hardware packet takes the address to the root node in the
-             * acceleration structure, not the acceleration structure itself.
-             * To find that, we have to read the root node offset from the
-             * acceleration structure which is the first QWord.
-             */
-            nir_ssa_def *root_node_ptr =
-               nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
-
-            /* The hardware packet requires an address to the first element of
-             * the hit SBT.
-             *
-             * In order to calculate this, we must multiply the "SBT Offset"
-             * provided to OpTraceRay by the SBT stride provided for the hit
-             * SBT in the call to vkCmdTraceRay() and add that to the base
-             * address of the hit SBT.  This stride is not to be confused with
-             * the "SBT Stride" provided to OpTraceRay which is in units of
-             * this stride.  It's a rather terrible overload of the word
-             * "stride".  The hardware docs calls the SPIR-V stride value the
-             * "shader index multiplier" which is a much more sane name.
-             */
-            nir_ssa_def *hit_sbt_stride_B =
-               nir_load_ray_hit_sbt_stride_intel(b);
-            nir_ssa_def *hit_sbt_offset_B =
-               nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
-            nir_ssa_def *hit_sbt_addr =
-               nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
-                           nir_u2u64(b, hit_sbt_offset_B));
-
-            /* The hardware packet takes an address to the miss BSR. */
-            nir_ssa_def *miss_sbt_stride_B =
-               nir_load_ray_miss_sbt_stride_intel(b);
-            nir_ssa_def *miss_sbt_offset_B =
-               nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
-            nir_ssa_def *miss_sbt_addr =
-               nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
-                           nir_u2u64(b, miss_sbt_offset_B));
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
 
-            struct brw_nir_rt_mem_ray_defs ray_defs = {
-               .root_node_ptr = root_node_ptr,
-               .ray_flags = nir_u2u16(b, ray_flags),
-               .ray_mask = cull_mask,
-               .hit_group_sr_base_ptr = hit_sbt_addr,
-               .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
-               .miss_sr_ptr = miss_sbt_addr,
-               .orig = ray_orig,
-               .t_near = ray_t_min,
-               .dir = ray_dir,
-               .t_far = ray_t_max,
-               .shader_index_multiplier = sbt_stride,
-            };
-            brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
-            nir_trace_ray_initial_intel(b);
-            break;
-         }
-
-         case nir_intrinsic_rt_execute_callable: {
-            store_resume_addr(b, call);
-
-            nir_ssa_def *sbt_offset32 =
-               nir_imul(b, call->src[0].ssa,
-                        nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
-            nir_ssa_def *sbt_addr =
-               nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
-                           nir_u2u64(b, sbt_offset32));
-            brw_nir_btd_spawn(b, sbt_addr);
-            break;
-         }
-
-         default:
-            unreachable("Invalid intrinsic");
-         }
-
-         nir_instr_remove(&call->instr);
-      }
+   /* Leave nir_intrinsic_rt_resume to be lowered by
+    * brw_nir_lower_rt_intrinsics()
+    */
+   nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+
+   switch (call->intrinsic) {
+   case nir_intrinsic_rt_trace_ray: {
+      store_resume_addr(b, call);
+
+      nir_ssa_def *as_addr = call->src[0].ssa;
+      nir_ssa_def *ray_flags = call->src[1].ssa;
+      /* From the SPIR-V spec:
+       *
+       *    "Only the 8 least-significant bits of Cull Mask are used by this
+       *    instruction - other bits are ignored.
+       *
+       *    Only the 4 least-significant bits of SBT Offset and SBT Stride are
+       *    used by this instruction - other bits are ignored.
+       *
+       *    Only the 16 least-significant bits of Miss Index are used by this
+       *    instruction - other bits are ignored."
+       */
+      nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
+      nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
+      nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
+      nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
+      nir_ssa_def *ray_orig = call->src[6].ssa;
+      nir_ssa_def *ray_t_min = call->src[7].ssa;
+      nir_ssa_def *ray_dir = call->src[8].ssa;
+      nir_ssa_def *ray_t_max = call->src[9].ssa;
+
+      /* The hardware packet takes the address to the root node in the
+       * acceleration structure, not the acceleration structure itself. To
+       * find that, we have to read the root node offset from the acceleration
+       * structure which is the first QWord.
+       */
+      nir_ssa_def *root_node_ptr =
+         nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
+
+      /* The hardware packet requires an address to the first element of the
+       * hit SBT.
+       *
+       * In order to calculate this, we must multiply the "SBT Offset"
+       * provided to OpTraceRay by the SBT stride provided for the hit SBT in
+       * the call to vkCmdTraceRay() and add that to the base address of the
+       * hit SBT. This stride is not to be confused with the "SBT Stride"
+       * provided to OpTraceRay which is in units of this stride. It's a
+       * rather terrible overload of the word "stride". The hardware docs
+       * calls the SPIR-V stride value the "shader index multiplier" which is
+       * a much more sane name.
+       */
+      nir_ssa_def *hit_sbt_stride_B =
+         nir_load_ray_hit_sbt_stride_intel(b);
+      nir_ssa_def *hit_sbt_offset_B =
+         nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
+      nir_ssa_def *hit_sbt_addr =
+         nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
+                     nir_u2u64(b, hit_sbt_offset_B));
+
+      /* The hardware packet takes an address to the miss BSR. */
+      nir_ssa_def *miss_sbt_stride_B =
+         nir_load_ray_miss_sbt_stride_intel(b);
+      nir_ssa_def *miss_sbt_offset_B =
+         nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
+      nir_ssa_def *miss_sbt_addr =
+         nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
+                     nir_u2u64(b, miss_sbt_offset_B));
+
+      struct brw_nir_rt_mem_ray_defs ray_defs = {
+         .root_node_ptr = root_node_ptr,
+         .ray_flags = nir_u2u16(b, ray_flags),
+         .ray_mask = cull_mask,
+         .hit_group_sr_base_ptr = hit_sbt_addr,
+         .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
+         .miss_sr_ptr = miss_sbt_addr,
+         .orig = ray_orig,
+         .t_near = ray_t_min,
+         .dir = ray_dir,
+         .t_far = ray_t_max,
+         .shader_index_multiplier = sbt_stride,
+      };
+      brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
+      nir_trace_ray_initial_intel(b);
+      return true;
    }
 
-   nir_foreach_block_safe(block, impl) {
-      nir_foreach_instr_safe(instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
+   case nir_intrinsic_rt_execute_callable: {
+      store_resume_addr(b, call);
+
+      nir_ssa_def *sbt_offset32 =
+         nir_imul(b, call->src[0].ssa,
+                     nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
+      nir_ssa_def *sbt_addr =
+         nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
+                     nir_u2u64(b, sbt_offset32));
+      brw_nir_btd_spawn(b, sbt_addr);
+      return true;
+   }
 
-      }
+   default:
+      return false;
    }
+}
 
-   return progress;
+bool
+brw_nir_lower_shader_calls(nir_shader *shader)
+{
+   return nir_shader_instructions_pass(shader,
+                                       lower_shader_calls_instr,
+                                       nir_metadata_block_index |
+                                       nir_metadata_dominance,
+                                       NULL);
 }
 
 /** Creates a trivial return shader