anv: refactor ray tracing dispatch
authorLionel Landwerlin <lionel.g.landwerlin@intel.com>
Fri, 25 Nov 2022 20:01:10 +0000 (22:01 +0200)
committerMarge Bot <emma+marge@anholt.net>
Fri, 2 Dec 2022 09:28:23 +0000 (09:28 +0000)
Preparing for vkCmdTraceRaysIndirect2KHR

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Ivan Briano <ivan.briano@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20011>

src/intel/genxml/meson.build
src/intel/vulkan/genX_cmd_buffer.c

index 2f2d6c4..c1c68fa 100644 (file)
@@ -73,6 +73,12 @@ genX_bits_included_symbols = [
   'CLEAR_COLOR',
   'VERTEX_BUFFER_STATE::Buffer Starting Address',
   'CPS_STATE',
+  'RT_DISPATCH_GLOBALS::Hit Group Table',
+  'RT_DISPATCH_GLOBALS::Miss Group Table',
+  'RT_DISPATCH_GLOBALS::Callable Group Table',
+  'RT_DISPATCH_GLOBALS::Launch Width',
+  'RT_DISPATCH_GLOBALS::Launch Height',
+  'RT_DISPATCH_GLOBALS::Launch Depth',
 ]
 
 genX_bits_h = custom_target(
index 7c5a4d0..f449634 100644 (file)
@@ -5617,17 +5617,71 @@ vk_sdar_to_shader_table(const VkStridedDeviceAddressRegionKHR *region)
    };
 }
 
+struct trace_params {
+   /* If is_sbt_indirect, use indirect_sbts_addr to build RT_DISPATCH_GLOBALS
+    * with mi_builder.
+    */
+   bool is_sbt_indirect;
+   const VkStridedDeviceAddressRegionKHR *raygen_sbt;
+   const VkStridedDeviceAddressRegionKHR *miss_sbt;
+   const VkStridedDeviceAddressRegionKHR *hit_sbt;
+   const VkStridedDeviceAddressRegionKHR *callable_sbt;
+
+   /* A pointer to a VkTraceRaysIndirectCommand2KHR structure */
+   uint64_t indirect_sbts_addr;
+
+   /* If is_indirect, use launch_size_addr to program the dispatch size. */
+   bool is_launch_size_indirect;
+   uint32_t launch_size[3];
+
+   /* A pointer a uint32_t[3] */
+   uint64_t launch_size_addr;
+};
+
+static struct anv_state
+cmd_buffer_emit_rt_dispatch_globals(struct anv_cmd_buffer *cmd_buffer,
+                                    struct trace_params *params)
+{
+   assert(!params->is_sbt_indirect);
+   assert(params->miss_sbt != NULL);
+   assert(params->hit_sbt != NULL);
+   assert(params->callable_sbt != NULL);
+
+   struct anv_cmd_ray_tracing_state *rt = &cmd_buffer->state.rt;
+
+   struct anv_state rtdg_state =
+      anv_cmd_buffer_alloc_dynamic_state(cmd_buffer,
+                                         BRW_RT_PUSH_CONST_OFFSET +
+                                         sizeof(struct anv_push_constants),
+                                         64);
+
+   struct GENX(RT_DISPATCH_GLOBALS) rtdg = {
+      .MemBaseAddress     = (struct anv_address) {
+         .bo = rt->scratch.bo,
+         .offset = rt->scratch.layout.ray_stack_start,
+      },
+      .CallStackHandler   = anv_shader_bin_get_bsr(
+         cmd_buffer->device->rt_trivial_return, 0),
+      .AsyncRTStackSize   = rt->scratch.layout.ray_stack_stride / 64,
+      .NumDSSRTStacks     = rt->scratch.layout.stack_ids_per_dss,
+      .MaxBVHLevels       = BRW_RT_MAX_BVH_LEVELS,
+      .Flags              = RT_DEPTH_TEST_LESS_EQUAL,
+      .HitGroupTable      = vk_sdar_to_shader_table(params->hit_sbt),
+      .MissGroupTable     = vk_sdar_to_shader_table(params->miss_sbt),
+      .SWStackSize        = rt->scratch.layout.sw_stack_size / 64,
+      .LaunchWidth        = params->launch_size[0],
+      .LaunchHeight       = params->launch_size[1],
+      .LaunchDepth        = params->launch_size[2],
+      .CallableGroupTable = vk_sdar_to_shader_table(params->callable_sbt),
+   };
+   GENX(RT_DISPATCH_GLOBALS_pack)(NULL, rtdg_state.map, &rtdg);
+
+   return rtdg_state;
+}
+
 static void
 cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
-                      const VkStridedDeviceAddressRegionKHR *raygen_sbt,
-                      const VkStridedDeviceAddressRegionKHR *miss_sbt,
-                      const VkStridedDeviceAddressRegionKHR *hit_sbt,
-                      const VkStridedDeviceAddressRegionKHR *callable_sbt,
-                      bool is_indirect,
-                      uint32_t launch_width,
-                      uint32_t launch_height,
-                      uint32_t launch_depth,
-                      uint64_t launch_size_addr)
+                      struct trace_params *params)
 {
    struct anv_device *device = cmd_buffer->device;
    struct anv_cmd_ray_tracing_state *rt = &cmd_buffer->state.rt;
@@ -5637,8 +5691,10 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
       return;
 
    /* If we have a known degenerate launch size, just bail */
-   if (!is_indirect &&
-       (launch_width == 0 || launch_height == 0 || launch_depth == 0))
+   if (!params->is_launch_size_indirect &&
+       (params->launch_size[0] == 0 ||
+        params->launch_size[1] == 0 ||
+        params->launch_size[2] == 0))
       return;
 
    genX(cmd_buffer_config_l3)(cmd_buffer, pipeline->base.l3_config);
@@ -5662,34 +5718,12 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
 
    /* Allocate and set up our RT_DISPATCH_GLOBALS */
    struct anv_state rtdg_state =
-      anv_cmd_buffer_alloc_dynamic_state(cmd_buffer,
-                                         BRW_RT_PUSH_CONST_OFFSET +
-                                         sizeof(struct anv_push_constants),
-                                         64);
-
-   struct GENX(RT_DISPATCH_GLOBALS) rtdg = {
-      .MemBaseAddress = (struct anv_address) {
-         .bo = rt->scratch.bo,
-         .offset = rt->scratch.layout.ray_stack_start,
-      },
-      .CallStackHandler =
-         anv_shader_bin_get_bsr(cmd_buffer->device->rt_trivial_return, 0),
-      .AsyncRTStackSize = rt->scratch.layout.ray_stack_stride / 64,
-      .NumDSSRTStacks = rt->scratch.layout.stack_ids_per_dss,
-      .MaxBVHLevels = BRW_RT_MAX_BVH_LEVELS,
-      .Flags = RT_DEPTH_TEST_LESS_EQUAL,
-      .HitGroupTable = vk_sdar_to_shader_table(hit_sbt),
-      .MissGroupTable = vk_sdar_to_shader_table(miss_sbt),
-      .SWStackSize = rt->scratch.layout.sw_stack_size / 64,
-      .LaunchWidth = launch_width,
-      .LaunchHeight = launch_height,
-      .LaunchDepth = launch_depth,
-      .CallableGroupTable = vk_sdar_to_shader_table(callable_sbt),
-   };
-   GENX(RT_DISPATCH_GLOBALS_pack)(NULL, rtdg_state.map, &rtdg);
+      cmd_buffer_emit_rt_dispatch_globals(cmd_buffer, params);
 
-   /* Push constants go after the RT_DISPATCH_GLOBALS */
+   assert(rtdg_state.alloc_size >= (BRW_RT_PUSH_CONST_OFFSET +
+                                    sizeof(struct anv_push_constants)));
    assert(GENX(RT_DISPATCH_GLOBALS_length) * 4 <= BRW_RT_PUSH_CONST_OFFSET);
+   /* Push constants go after the RT_DISPATCH_GLOBALS */
    memcpy(rtdg_state.map + BRW_RT_PUSH_CONST_OFFSET,
           &cmd_buffer->state.rt.base.push_constants,
           sizeof(struct anv_push_constants));
@@ -5700,7 +5734,7 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
 
    uint8_t local_size_log2[3];
    uint32_t global_size[3] = {};
-   if (is_indirect) {
+   if (params->is_launch_size_indirect) {
       /* Pick a local size that's probably ok.  We assume most TraceRays calls
        * will use a two-dimensional dispatch size.  Worst case, our initial
        * dispatch will be a little slower than it has to be.
@@ -5713,21 +5747,20 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
       mi_builder_init(&b, cmd_buffer->device->info, &cmd_buffer->batch);
 
       struct mi_value launch_size[3] = {
-         mi_mem32(anv_address_from_u64(launch_size_addr + 0)),
-         mi_mem32(anv_address_from_u64(launch_size_addr + 4)),
-         mi_mem32(anv_address_from_u64(launch_size_addr + 8)),
+         mi_mem32(anv_address_from_u64(params->launch_size_addr + 0)),
+         mi_mem32(anv_address_from_u64(params->launch_size_addr + 4)),
+         mi_mem32(anv_address_from_u64(params->launch_size_addr + 8)),
       };
 
-      /* Store the original launch size into RT_DISPATCH_GLOBALS
-       *
-       * TODO: Pull values from genX_bits.h once RT_DISPATCH_GLOBALS gets
-       * moved into a genX version.
-       */
-      mi_store(&b, mi_mem32(anv_address_add(rtdg_addr, 52)),
+      /* Store the original launch size into RT_DISPATCH_GLOBALS */
+      mi_store(&b, mi_mem32(anv_address_add(rtdg_addr,
+                                            GENX(RT_DISPATCH_GLOBALS_LaunchWidth_start) / 8)),
                mi_value_ref(&b, launch_size[0]));
-      mi_store(&b, mi_mem32(anv_address_add(rtdg_addr, 56)),
+      mi_store(&b, mi_mem32(anv_address_add(rtdg_addr,
+                                            GENX(RT_DISPATCH_GLOBALS_LaunchHeight_start) / 8)),
                mi_value_ref(&b, launch_size[1]));
-      mi_store(&b, mi_mem32(anv_address_add(rtdg_addr, 60)),
+      mi_store(&b, mi_mem32(anv_address_add(rtdg_addr,
+                                            GENX(RT_DISPATCH_GLOBALS_LaunchDepth_start) / 8)),
                mi_value_ref(&b, launch_size[2]));
 
       /* Compute the global dispatch size */
@@ -5752,15 +5785,14 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
       mi_store(&b, mi_reg32(GPGPU_DISPATCHDIMY), launch_size[1]);
       mi_store(&b, mi_reg32(GPGPU_DISPATCHDIMZ), launch_size[2]);
    } else {
-      uint32_t launch_size[3] = { launch_width, launch_height, launch_depth };
-      calc_local_trace_size(local_size_log2, launch_size);
+      calc_local_trace_size(local_size_log2, params->launch_size);
 
       for (unsigned i = 0; i < 3; i++) {
          /* We have to be a bit careful here because DIV_ROUND_UP adds to the
           * numerator value may overflow.  Cast to uint64_t to avoid this.
           */
          uint32_t local_size = 1 << local_size_log2[i];
-         global_size[i] = DIV_ROUND_UP((uint64_t)launch_size[i], local_size);
+         global_size[i] = DIV_ROUND_UP((uint64_t)params->launch_size[i], local_size);
       }
    }
 
@@ -5799,7 +5831,7 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
       brw_cs_get_dispatch_info(device->info, cs_prog_data, NULL);
 
    anv_batch_emit(&cmd_buffer->batch, GENX(COMPUTE_WALKER), cw) {
-      cw.IndirectParameterEnable        = is_indirect;
+      cw.IndirectParameterEnable        = params->is_launch_size_indirect;
       cw.PredicateEnable                = cmd_buffer->state.conditional_render_enabled;
       cw.SIMDSize                       = dispatch.simd_size / 16;
       cw.LocalXMaximum                  = (1 << local_size_log2[0]) - 1;
@@ -5828,7 +5860,7 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
 
       struct brw_rt_raygen_trampoline_params trampoline_params = {
          .rt_disp_globals_addr = anv_address_physical(rtdg_addr),
-         .raygen_bsr_addr = raygen_sbt->deviceAddress,
+         .raygen_bsr_addr = params->raygen_sbt->deviceAddress,
          .is_indirect = false, /* Only for raygen_bsr_addr */
          .local_group_size_log2 = {
             local_size_log2[0],
@@ -5853,15 +5885,21 @@ genX(CmdTraceRaysKHR)(
     uint32_t                                    depth)
 {
    ANV_FROM_HANDLE(anv_cmd_buffer, cmd_buffer, commandBuffer);
+   struct trace_params params = {
+      .is_sbt_indirect         = false,
+      .raygen_sbt              = pRaygenShaderBindingTable,
+      .miss_sbt                = pMissShaderBindingTable,
+      .hit_sbt                 = pHitShaderBindingTable,
+      .callable_sbt            = pCallableShaderBindingTable,
+      .is_launch_size_indirect = false,
+      .launch_size             = {
+         width,
+         height,
+         depth,
+      },
+   };
 
-   cmd_buffer_trace_rays(cmd_buffer,
-                         pRaygenShaderBindingTable,
-                         pMissShaderBindingTable,
-                         pHitShaderBindingTable,
-                         pCallableShaderBindingTable,
-                         false /* is_indirect */,
-                         width, height, depth,
-                         0 /* launch_size_addr */);
+   cmd_buffer_trace_rays(cmd_buffer, &params);
 }
 
 void
@@ -5874,15 +5912,17 @@ genX(CmdTraceRaysIndirectKHR)(
     VkDeviceAddress                             indirectDeviceAddress)
 {
    ANV_FROM_HANDLE(anv_cmd_buffer, cmd_buffer, commandBuffer);
+   struct trace_params params = {
+      .is_sbt_indirect         = false,
+      .raygen_sbt              = pRaygenShaderBindingTable,
+      .miss_sbt                = pMissShaderBindingTable,
+      .hit_sbt                 = pHitShaderBindingTable,
+      .callable_sbt            = pCallableShaderBindingTable,
+      .is_launch_size_indirect = true,
+      .launch_size_addr        = indirectDeviceAddress,
+   };
 
-   cmd_buffer_trace_rays(cmd_buffer,
-                         pRaygenShaderBindingTable,
-                         pMissShaderBindingTable,
-                         pHitShaderBindingTable,
-                         pCallableShaderBindingTable,
-                         true /* is_indirect */,
-                         0, 0, 0, /* width, height, depth, */
-                         indirectDeviceAddress);
+   cmd_buffer_trace_rays(cmd_buffer, &params);
 }
 #endif /* GFX_VERx10 >= 125 */