nir_btd_stack_push_intel(b, offset);
}
-bool
-brw_nir_lower_shader_calls(nir_shader *shader)
+static bool
+lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
{
- nir_function_impl *impl = nir_shader_get_entrypoint(shader);
- bool progress = false;
-
- nir_builder _b, *b = &_b;
- nir_builder_init(&_b, impl);
-
- nir_foreach_block_safe(block, impl) {
- nir_foreach_instr_safe(instr, block) {
- if (instr->type != nir_instr_type_intrinsic)
- continue;
-
- nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
- if (call->intrinsic != nir_intrinsic_rt_trace_ray &&
- call->intrinsic != nir_intrinsic_rt_execute_callable &&
- call->intrinsic != nir_intrinsic_rt_resume)
- continue;
-
- b->cursor = nir_before_instr(instr);
-
- progress = true;
-
- switch (call->intrinsic) {
- case nir_intrinsic_rt_trace_ray: {
- store_resume_addr(b, call);
-
- nir_ssa_def *as_addr = call->src[0].ssa;
- nir_ssa_def *ray_flags = call->src[1].ssa;
- /* From the SPIR-V spec:
- *
- * "Only the 8 least-significant bits of Cull Mask are used by
- * this instruction - other bits are ignored.
- *
- * Only the 4 least-significant bits of SBT Offset and SBT
- * Stride are used by this instruction - other bits are
- * ignored.
- *
- * Only the 16 least-significant bits of Miss Index are used by
- * this instruction - other bits are ignored."
- */
- nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
- nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
- nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
- nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
- nir_ssa_def *ray_orig = call->src[6].ssa;
- nir_ssa_def *ray_t_min = call->src[7].ssa;
- nir_ssa_def *ray_dir = call->src[8].ssa;
- nir_ssa_def *ray_t_max = call->src[9].ssa;
-
- /* The hardware packet takes the address to the root node in the
- * acceleration structure, not the acceleration structure itself.
- * To find that, we have to read the root node offset from the
- * acceleration structure which is the first QWord.
- */
- nir_ssa_def *root_node_ptr =
- nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
-
- /* The hardware packet requires an address to the first element of
- * the hit SBT.
- *
- * In order to calculate this, we must multiply the "SBT Offset"
- * provided to OpTraceRay by the SBT stride provided for the hit
- * SBT in the call to vkCmdTraceRay() and add that to the base
- * address of the hit SBT. This stride is not to be confused with
- * the "SBT Stride" provided to OpTraceRay which is in units of
- * this stride. It's a rather terrible overload of the word
- * "stride". The hardware docs calls the SPIR-V stride value the
- * "shader index multiplier" which is a much more sane name.
- */
- nir_ssa_def *hit_sbt_stride_B =
- nir_load_ray_hit_sbt_stride_intel(b);
- nir_ssa_def *hit_sbt_offset_B =
- nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
- nir_ssa_def *hit_sbt_addr =
- nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
- nir_u2u64(b, hit_sbt_offset_B));
-
- /* The hardware packet takes an address to the miss BSR. */
- nir_ssa_def *miss_sbt_stride_B =
- nir_load_ray_miss_sbt_stride_intel(b);
- nir_ssa_def *miss_sbt_offset_B =
- nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
- nir_ssa_def *miss_sbt_addr =
- nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
- nir_u2u64(b, miss_sbt_offset_B));
+ if (instr->type != nir_instr_type_intrinsic)
+ return false;
- struct brw_nir_rt_mem_ray_defs ray_defs = {
- .root_node_ptr = root_node_ptr,
- .ray_flags = nir_u2u16(b, ray_flags),
- .ray_mask = cull_mask,
- .hit_group_sr_base_ptr = hit_sbt_addr,
- .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
- .miss_sr_ptr = miss_sbt_addr,
- .orig = ray_orig,
- .t_near = ray_t_min,
- .dir = ray_dir,
- .t_far = ray_t_max,
- .shader_index_multiplier = sbt_stride,
- };
- brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
- nir_trace_ray_initial_intel(b);
- break;
- }
-
- case nir_intrinsic_rt_execute_callable: {
- store_resume_addr(b, call);
-
- nir_ssa_def *sbt_offset32 =
- nir_imul(b, call->src[0].ssa,
- nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
- nir_ssa_def *sbt_addr =
- nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
- nir_u2u64(b, sbt_offset32));
- brw_nir_btd_spawn(b, sbt_addr);
- break;
- }
-
- default:
- unreachable("Invalid intrinsic");
- }
-
- nir_instr_remove(&call->instr);
- }
+ /* Leave nir_intrinsic_rt_resume to be lowered by
+ * brw_nir_lower_rt_intrinsics()
+ */
+ nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+
+ switch (call->intrinsic) {
+ case nir_intrinsic_rt_trace_ray: {
+ store_resume_addr(b, call);
+
+ nir_ssa_def *as_addr = call->src[0].ssa;
+ nir_ssa_def *ray_flags = call->src[1].ssa;
+ /* From the SPIR-V spec:
+ *
+ * "Only the 8 least-significant bits of Cull Mask are used by this
+ * instruction - other bits are ignored.
+ *
+ * Only the 4 least-significant bits of SBT Offset and SBT Stride are
+ * used by this instruction - other bits are ignored.
+ *
+ * Only the 16 least-significant bits of Miss Index are used by this
+ * instruction - other bits are ignored."
+ */
+ nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
+ nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
+ nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
+ nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
+ nir_ssa_def *ray_orig = call->src[6].ssa;
+ nir_ssa_def *ray_t_min = call->src[7].ssa;
+ nir_ssa_def *ray_dir = call->src[8].ssa;
+ nir_ssa_def *ray_t_max = call->src[9].ssa;
+
+ /* The hardware packet takes the address to the root node in the
+ * acceleration structure, not the acceleration structure itself. To
+ * find that, we have to read the root node offset from the acceleration
+ * structure which is the first QWord.
+ */
+ nir_ssa_def *root_node_ptr =
+ nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
+
+ /* The hardware packet requires an address to the first element of the
+ * hit SBT.
+ *
+ * In order to calculate this, we must multiply the "SBT Offset"
+ * provided to OpTraceRay by the SBT stride provided for the hit SBT in
+ * the call to vkCmdTraceRay() and add that to the base address of the
+ * hit SBT. This stride is not to be confused with the "SBT Stride"
+ * provided to OpTraceRay which is in units of this stride. It's a
+ * rather terrible overload of the word "stride". The hardware docs
+ * calls the SPIR-V stride value the "shader index multiplier" which is
+ * a much more sane name.
+ */
+ nir_ssa_def *hit_sbt_stride_B =
+ nir_load_ray_hit_sbt_stride_intel(b);
+ nir_ssa_def *hit_sbt_offset_B =
+ nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
+ nir_ssa_def *hit_sbt_addr =
+ nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
+ nir_u2u64(b, hit_sbt_offset_B));
+
+ /* The hardware packet takes an address to the miss BSR. */
+ nir_ssa_def *miss_sbt_stride_B =
+ nir_load_ray_miss_sbt_stride_intel(b);
+ nir_ssa_def *miss_sbt_offset_B =
+ nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
+ nir_ssa_def *miss_sbt_addr =
+ nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
+ nir_u2u64(b, miss_sbt_offset_B));
+
+ struct brw_nir_rt_mem_ray_defs ray_defs = {
+ .root_node_ptr = root_node_ptr,
+ .ray_flags = nir_u2u16(b, ray_flags),
+ .ray_mask = cull_mask,
+ .hit_group_sr_base_ptr = hit_sbt_addr,
+ .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
+ .miss_sr_ptr = miss_sbt_addr,
+ .orig = ray_orig,
+ .t_near = ray_t_min,
+ .dir = ray_dir,
+ .t_far = ray_t_max,
+ .shader_index_multiplier = sbt_stride,
+ };
+ brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
+ nir_trace_ray_initial_intel(b);
+ return true;
}
- nir_foreach_block_safe(block, impl) {
- nir_foreach_instr_safe(instr, block) {
- if (instr->type != nir_instr_type_intrinsic)
- continue;
-
+ case nir_intrinsic_rt_execute_callable: {
+ store_resume_addr(b, call);
+
+ nir_ssa_def *sbt_offset32 =
+ nir_imul(b, call->src[0].ssa,
+ nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
+ nir_ssa_def *sbt_addr =
+ nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
+ nir_u2u64(b, sbt_offset32));
+ brw_nir_btd_spawn(b, sbt_addr);
+ return true;
+ }
- }
+ default:
+ return false;
}
+}
- return progress;
+bool
+brw_nir_lower_shader_calls(nir_shader *shader)
+{
+ return nir_shader_instructions_pass(shader,
+ lower_shader_calls_instr,
+ nir_metadata_block_index |
+ nir_metadata_dominance,
+ NULL);
}
/** Creates a trivial return shader