}
static bool
-lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
+lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
/* Leave nir_intrinsic_rt_resume to be lowered by
 * brw_nir_lower_rt_intrinsics()
*/
nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+ if (call->intrinsic != nir_intrinsic_rt_trace_ray)
+ return false;
- switch (call->intrinsic) {
- case nir_intrinsic_rt_trace_ray: {
- b->cursor = nir_instr_remove(instr);
-
- store_resume_addr(b, call);
-
- nir_ssa_def *as_addr = call->src[0].ssa;
- nir_ssa_def *ray_flags = call->src[1].ssa;
- /* From the SPIR-V spec:
- *
- * "Only the 8 least-significant bits of Cull Mask are used by this
- * instruction - other bits are ignored.
- *
- * Only the 4 least-significant bits of SBT Offset and SBT Stride are
- * used by this instruction - other bits are ignored.
- *
- * Only the 16 least-significant bits of Miss Index are used by this
- * instruction - other bits are ignored."
- */
- nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
- nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
- nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
- nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
- nir_ssa_def *ray_orig = call->src[6].ssa;
- nir_ssa_def *ray_t_min = call->src[7].ssa;
- nir_ssa_def *ray_dir = call->src[8].ssa;
- nir_ssa_def *ray_t_max = call->src[9].ssa;
-
- nir_ssa_def *root_node_ptr =
- brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
-
- /* The hardware packet requires an address to the first element of the
- * hit SBT.
- *
- * In order to calculate this, we must multiply the "SBT Offset"
- * provided to OpTraceRay by the SBT stride provided for the hit SBT in
- * the call to vkCmdTraceRay() and add that to the base address of the
- * hit SBT. This stride is not to be confused with the "SBT Stride"
- * provided to OpTraceRay which is in units of this stride. It's a
- * rather terrible overload of the word "stride". The hardware docs
- * calls the SPIR-V stride value the "shader index multiplier" which is
- * a much more sane name.
- */
- nir_ssa_def *hit_sbt_stride_B =
- nir_load_ray_hit_sbt_stride_intel(b);
- nir_ssa_def *hit_sbt_offset_B =
- nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
- nir_ssa_def *hit_sbt_addr =
- nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
- nir_u2u64(b, hit_sbt_offset_B));
-
- /* The hardware packet takes an address to the miss BSR. */
- nir_ssa_def *miss_sbt_stride_B =
- nir_load_ray_miss_sbt_stride_intel(b);
- nir_ssa_def *miss_sbt_offset_B =
- nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
- nir_ssa_def *miss_sbt_addr =
- nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
- nir_u2u64(b, miss_sbt_offset_B));
-
- struct brw_nir_rt_mem_ray_defs ray_defs = {
- .root_node_ptr = root_node_ptr,
- .ray_flags = nir_u2u16(b, ray_flags),
- .ray_mask = cull_mask,
- .hit_group_sr_base_ptr = hit_sbt_addr,
- .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
- .miss_sr_ptr = miss_sbt_addr,
- .orig = ray_orig,
- .t_near = ray_t_min,
- .dir = ray_dir,
- .t_far = ray_t_max,
- .shader_index_multiplier = sbt_stride,
- };
- brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
- nir_trace_ray_intel(b,
- nir_load_btd_global_arg_addr_intel(b),
- nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD),
- nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
- .synchronous = false);
- return true;
- }
+ b->cursor = nir_instr_remove(instr);
- case nir_intrinsic_rt_execute_callable: {
- b->cursor = nir_instr_remove(instr);
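+ /* Record the resume point: the portion of the shader after this call
+ * runs as a separate resume shader once the trace completes.
+ */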
+ store_resume_addr(b, call);
- store_resume_addr(b, call);
+ nir_ssa_def *as_addr = call->src[0].ssa;
+ nir_ssa_def *ray_flags = call->src[1].ssa;
+ /* From the SPIR-V spec:
+ *
+ * "Only the 8 least-significant bits of Cull Mask are used by this
+ * instruction - other bits are ignored.
+ *
+ * Only the 4 least-significant bits of SBT Offset and SBT Stride are
+ * used by this instruction - other bits are ignored.
+ *
+ * Only the 16 least-significant bits of Miss Index are used by this
+ * instruction - other bits are ignored."
+ */
+ nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
+ nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
+ nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
+ nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
+ nir_ssa_def *ray_orig = call->src[6].ssa;
+ nir_ssa_def *ray_t_min = call->src[7].ssa;
+ nir_ssa_def *ray_dir = call->src[8].ssa;
+ nir_ssa_def *ray_t_max = call->src[9].ssa;
+
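+ /* Resolve the acceleration structure address to the root BVH node
+ * pointer that the hardware starts traversal from.
+ */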
+ nir_ssa_def *root_node_ptr =
+ brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
+
+ /* The hardware packet requires the address of the first element of the
+ * hit SBT.
+ *
+ * To calculate this, we multiply the "SBT Offset" provided to
+ * OpTraceRay by the SBT stride provided for the hit SBT in the call to
+ * vkCmdTraceRaysKHR() and add that to the base address of the hit SBT.
+ * This stride is not to be confused with the "SBT Stride" provided to
+ * OpTraceRay, which is in units of this stride. It's a rather terrible
+ * overload of the word "stride". The hardware docs call the SPIR-V
+ * stride value the "shader index multiplier", which is a much saner
+ * name.
+ */
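+ /* For example (illustrative values): with a 64B hit-group stride
+ * passed at trace time and an SBT Offset of 2, the hit record base
+ * computed below lands 2 * 64 = 128B past the hit SBT base address.
+ */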
+ nir_ssa_def *hit_sbt_stride_B =
+ nir_load_ray_hit_sbt_stride_intel(b);
+ nir_ssa_def *hit_sbt_offset_B =
+ nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
+ nir_ssa_def *hit_sbt_addr =
+ nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
+ nir_u2u64(b, hit_sbt_offset_B));
+
+ /* The hardware packet takes the address of the miss BSR (bindless
+ * shader record).
+ */
+ nir_ssa_def *miss_sbt_stride_B =
+ nir_load_ray_miss_sbt_stride_intel(b);
+ nir_ssa_def *miss_sbt_offset_B =
+ nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
+ nir_ssa_def *miss_sbt_addr =
+ nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
+ nir_u2u64(b, miss_sbt_offset_B));
+
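+ /* Pack everything into the mem-ray layout consumed by the
+ * fixed-function traversal hardware; brw_nir_rt_store_mem_ray()
+ * writes it out for the world-level BVH below.
+ */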
+ struct brw_nir_rt_mem_ray_defs ray_defs = {
+ .root_node_ptr = root_node_ptr,
+ .ray_flags = nir_u2u16(b, ray_flags),
+ .ray_mask = cull_mask,
+ .hit_group_sr_base_ptr = hit_sbt_addr,
+ .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
+ .miss_sr_ptr = miss_sbt_addr,
+ .orig = ray_orig,
+ .t_near = ray_t_min,
+ .dir = ray_dir,
+ .t_far = ray_t_max,
+ .shader_index_multiplier = sbt_stride,
+ };
+ brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
+
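+ /* Fire the actual trace message. It is asynchronous: this shader
+ * does not wait for the result, and execution continues in the
+ * resume shader whose address was stored above.
+ */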
+ nir_trace_ray_intel(b,
+ nir_load_btd_global_arg_addr_intel(b),
+ nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD),
+ nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
+ .synchronous = false);
+ return true;
+}
- nir_ssa_def *sbt_offset32 =
- nir_imul(b, call->src[0].ssa,
- nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
- nir_ssa_def *sbt_addr =
- nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
- nir_u2u64(b, sbt_offset32));
- brw_nir_btd_spawn(b, sbt_addr);
- return true;
- }
+static bool
+lower_shader_call_instr(struct nir_builder *b, nir_instr *instr, void *data)
+{
+ if (instr->type != nir_instr_type_intrinsic)
+ return false;
- default:
+ /* Leave nir_intrinsic_rt_resume to be lowered by
+ * brw_nir_lower_rt_intrinsics()
+ */
+ nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+ if (call->intrinsic != nir_intrinsic_rt_execute_callable)
return false;
- }
+
+ b->cursor = nir_instr_remove(instr);
+
+ store_resume_addr(b, call);
+
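+ /* Callable shader record address: SBT base + index * stride. */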
+ nir_ssa_def *sbt_offset32 =
+ nir_imul(b, call->src[0].ssa,
+ nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
+ nir_ssa_def *sbt_addr =
+ nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
+ nir_u2u64(b, sbt_offset32));
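+ /* Spawn the callable shader through BTD; it eventually returns by
+ * re-spawning this shader at the resume address stored above.
+ */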
+ brw_nir_btd_spawn(b, sbt_addr);
+ return true;
}
bool
brw_nir_lower_shader_calls(nir_shader *shader)
{
- return nir_shader_instructions_pass(shader,
- lower_shader_calls_instr,
- nir_metadata_block_index |
- nir_metadata_dominance,
- NULL);
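+ /* Bitwise | on purpose: both passes must run even if the first
+ * one already reported progress.
+ */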
+ return
+ nir_shader_instructions_pass(shader,
+ lower_shader_trace_ray_instr,
+ nir_metadata_none,
+ NULL) |
+ nir_shader_instructions_pass(shader,
+ lower_shader_call_instr,
+ nir_metadata_block_index |
+ nir_metadata_dominance,
+ NULL);
}
/** Creates a trivial return shader