radv/radix_sort: Vendor the radix sort dispatch code
author Konstantin Seurer <konstantin.seurer@gmail.com>
Wed, 16 Aug 2023 09:09:25 +0000 (11:09 +0200)
committer Marge Bot <emma+marge@anholt.net>
Thu, 14 Sep 2023 15:12:44 +0000 (15:12 +0000)
This needs to be done so we can optimize it for occupancy when building
multiple acceleration structures in parallel. Changes to the original
code:

- Change // to /* */
- clang-format
- Replace vkCmd calls with calls to the driver entrypoints
- Add a lightweight info struct
- Use radv_fill_buffer directly

Reviewed-by: Friedrich Vock <friedrich.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24720>

src/amd/vulkan/radv_acceleration_structure.c
src/amd/vulkan/radv_private.h

index 64be510..81c19b9 100644
@@ -27,7 +27,9 @@
 #include "nir_builder.h"
 #include "radv_cs.h"
 
+#include "radix_sort/common/vk/barrier.h"
 #include "radix_sort/radv_radix_sort.h"
+#include "radix_sort/shaders/push.h"
 
 #include "bvh/build_interface.h"
 #include "bvh/bvh.h"
@@ -76,6 +78,7 @@ static const uint32_t header_spv[] = {
 };
 
 #define KEY_ID_PAIR_SIZE 8
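+/* Number of keyval bits that hold the Morton code; only these need sorting. */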
+#define MORTON_BIT_SIZE  24
 
 enum internal_build_type {
    INTERNAL_BUILD_TYPE_LBVH,
@@ -385,17 +388,6 @@ cleanup:
    return result;
 }
 
-static void
-radix_sort_fill_buffer(VkCommandBuffer commandBuffer, radix_sort_vk_buffer_info_t const *buffer_info,
-                       VkDeviceSize offset, VkDeviceSize size, uint32_t data)
-{
-   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
-
-   assert(size != VK_WHOLE_SIZE);
-
-   radv_fill_buffer(cmd_buffer, NULL, NULL, buffer_info->devaddr + buffer_info->offset + offset, size, data);
-}
-
 VkResult
 radv_device_init_null_accel_struct(struct radv_device *device)
 {
@@ -579,12 +571,6 @@ radv_device_init_accel_struct_build_state(struct radv_device *device)
 
    device->meta_state.accel_struct_build.radix_sort =
       radv_create_radix_sort_u64(radv_device_to_handle(device), &device->meta_state.alloc, device->meta_state.cache);
-
-   struct radix_sort_vk_sort_devaddr_info *radix_sort_info = &device->meta_state.accel_struct_build.radix_sort_info;
-   radix_sort_info->ext = NULL;
-   radix_sort_info->key_bits = 24;
-   radix_sort_info->fill_buffer = radix_sort_fill_buffer;
-
 exit:
    mtx_unlock(&device->meta_state.mtx);
    return result;
@@ -735,28 +721,198 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
             const VkAccelerationStructureBuildGeometryInfoKHR *pInfos, struct bvh_state *bvh_states,
             enum radv_cmd_flush_bits flush_bits)
 {
+   /* Copyright 2019 The Fuchsia Authors. */
    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
+
+   radix_sort_vk_t *rs = cmd_buffer->device->meta_state.accel_struct_build.radix_sort;
+
    for (uint32_t i = 0; i < infoCount; ++i) {
-      struct radix_sort_vk_sort_devaddr_info info = cmd_buffer->device->meta_state.accel_struct_build.radix_sort_info;
-      info.count = bvh_states[i].node_count;
+      uint32_t count = bvh_states[i].node_count;
+      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
+      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+
+      /* Anything to do? */
+      if (!count) {
+         bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
+         continue;
+      }
+
+      /*
+       * OVERVIEW
+       *
+       *   1. Pad the keyvals in `scatter_even`.
+       *   2. Zero the `histograms` and `partitions`.
+       *      --- BARRIER ---
+       *   3. HISTOGRAM is dispatched before PREFIX.
+       *      --- BARRIER ---
+       *   4. PREFIX is dispatched before the first SCATTER.
+       *      --- BARRIER ---
+       *   5. One or more SCATTER dispatches.
+       *
+       * Note that the `partitions` buffer can be zeroed anytime before the first
+       * scatter.
+       */
+
+      /* How many passes? */
+      uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
+      uint32_t keyval_bits = keyval_bytes * 8;
+      uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
+      uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
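+      /*
+       * Example: with 64-bit keyvals (keyval_dwords == 2) and radix-256
+       * digits (RS_RADIX_LOG2 == 8), keyval_bytes == 8 and the 24 Morton
+       * key bits are consumed in ceil(24 / 8) == 3 scatter passes.
+       */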
+
+      bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
+
+      /*
+       * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
+       *
+       * Pad fractional blocks with max-valued keyvals.
+       *
+       * Zero the histograms and partitions buffer.
+       *
+       * This assumes the partitions follow the histograms.
+       */
+
+      /* FIXME(allanmac): Consider precomputing some of these values and hanging them off `rs`. */
+
+      /* How many scatter blocks? */
+      uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
+      uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
+      uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs;
+      uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs;
+
+      /*
+       * How many histogram blocks?
+       *
+       * Note that it's OK to have more max-valued digits counted by the histogram
+       * than sorted by the scatters because the sort is stable.
+       */
+      uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
+      uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
+      uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
+      uint32_t count_ru_histo = histo_blocks * histo_block_kvs;
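+      /*
+       * E.g. with a hypothetical 512-lane scatter workgroup and 4 block
+       * rows, one scatter block covers 2048 keyvals, so count == 1500 is
+       * rounded up to count_ru_scatter == 2048, and the histogram rounding
+       * may pad it further. The max-valued padding sorts stably to the end,
+       * past the first `count` keyvals that later build steps consume.
+       */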
+
+      /* Fill with max values */
+      if (count_ru_histo > count) {
+         radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + count * keyval_bytes,
+                          (count_ru_histo - count) * keyval_bytes, 0xFFFFFFFF);
+      }
+
+      /*
+       * Zero histograms and invalidate partitions.
+       *
+       * Note that the partition invalidation only needs to be performed once
+       * because the even/odd scatter dispatches rely on the previous pass to
+       * leave the partitions in an invalid state.
+       *
+       * Note that the last workgroup doesn't read/write a partition so it doesn't
+       * need to be initialized.
+       */
+      uint32_t histo_partition_count = passes + scatter_blocks - 1;
+      uint32_t pass_idx = (keyval_bytes - passes);
+
+      uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
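+      /*
+       * Layout note: the internal buffer holds one RS_RADIX_SIZE-dword
+       * histogram per keyval byte, with the partitions placed right after
+       * them (see the assumption above). Only the last `passes` histograms
+       * are used, so the fill below skips the first `pass_idx` histograms
+       * and clears them plus the (scatter_blocks - 1) partitions in one go.
+       */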
+
+      radv_fill_buffer(cmd_buffer, NULL, NULL, internal_addr + rs->internal.histograms.offset + fill_base,
+                       histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), 0);
+
+      /*
+       * Pipeline: HISTOGRAM
+       *
+       * TODO(allanmac): All subgroups should try to process approximately the same
+       * number of blocks in order to minimize tail effects.  This was implemented
+       * and reverted but should be reimplemented and benchmarked later.
+       */
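+      /*
+       * Each of the histo_blocks workgroups accumulates, for each of the
+       * `passes` key bytes, a RS_RADIX_SIZE-bin digit histogram over its
+       * block of keyvals into devaddr_histograms.
+       */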
+      vk_barrier_transfer_w_to_compute_r(commandBuffer);
+
+      uint64_t devaddr_histograms = internal_addr + rs->internal.histograms.offset;
+
+      /* Dispatch histogram */
+      struct rs_push_histogram push_histogram = {
+         .devaddr_histograms = devaddr_histograms,
+         .devaddr_keyvals = keyvals_even_addr,
+         .passes = passes,
+      };
+
+      radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
+                            sizeof(push_histogram), &push_histogram);
+
+      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
+
+      vk_common_CmdDispatch(commandBuffer, histo_blocks, 1, 1);
 
-      info.keyvals_even.buffer = VK_NULL_HANDLE;
-      info.keyvals_even.offset = 0;
-      info.keyvals_even.devaddr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      /*
+       * Pipeline: PREFIX
+       *
+       * Launch one workgroup per pass.
+       */
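+      /*
+       * Each workgroup scans one digit histogram into the exclusive prefix
+       * sums that become the scatter base offsets for that pass.
+       */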
+      vk_barrier_compute_w_to_compute_r(commandBuffer);
+
+      struct rs_push_prefix push_prefix = {
+         .devaddr_histograms = devaddr_histograms,
+      };
+
+      radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0,
+                            sizeof(push_prefix), &push_prefix);
 
-      info.keyvals_odd = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
+      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
 
-      info.internal.buffer = VK_NULL_HANDLE;
-      info.internal.offset = 0;
-      info.internal.devaddr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+      vk_common_CmdDispatch(commandBuffer, passes, 1, 1);
 
-      VkDeviceAddress result_addr;
-      radix_sort_vk_sort_devaddr(cmd_buffer->device->meta_state.accel_struct_build.radix_sort, &info,
-                                 radv_device_to_handle(cmd_buffer->device), commandBuffer, &result_addr);
+      /* Pipeline: SCATTER */
+      vk_barrier_compute_w_to_compute_r(commandBuffer);
+
+      uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
+      uint64_t devaddr_partitions = internal_addr + rs->internal.partitions.offset;
+
+      struct rs_push_scatter push_scatter = {
+         .devaddr_keyvals_even = keyvals_even_addr,
+         .devaddr_keyvals_odd = keyvals_odd_addr,
+         .devaddr_partitions = devaddr_partitions,
+         .devaddr_histograms = devaddr_histograms + histogram_offset,
+         .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2,
+      };
 
-      assert(result_addr == info.keyvals_even.devaddr || result_addr == info.keyvals_odd);
+      {
+         uint32_t pass_dword = pass_idx / 4;
 
-      bvh_states[i].scratch_offset = (uint32_t)(result_addr - pInfos[i].scratchData.deviceAddress);
+         radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.scatter[pass_dword].even,
+                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_scatter), &push_scatter);
+
+         radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+                              rs->pipelines.named.scatter[pass_dword].even);
+      }
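+      /*
+       * The scatter dispatches ping-pong between the even and odd keyval
+       * buffers; after all `passes` dispatches the sorted keyvals land in
+       * sort_buffer_offset[passes & 1], matching scratch_offset above.
+       */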
+
+      bool is_even = true;
+
+      while (true) {
+         vk_common_CmdDispatch(commandBuffer, scatter_blocks, 1, 1);
+
+         /* Continue? */
+         if (++pass_idx >= keyval_bytes)
+            break;
+
+         vk_barrier_compute_w_to_compute_r(commandBuffer);
+
+         is_even ^= true;
+         push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
+         push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
+
+         uint32_t pass_dword = pass_idx / 4;
+
+         /* Update push constants that changed */
+         VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
+                                       : rs->pipeline_layouts.named.scatter[pass_dword].odd;
+         radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT,
+                               offsetof(struct rs_push_scatter, devaddr_histograms),
+                               sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
+                               &push_scatter.devaddr_histograms);
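+         /*
+          * Note: this partial update relies on pass_offset immediately
+          * following devaddr_histograms in struct rs_push_scatter.
+          */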
+
+         /* Bind new pipeline */
+         VkPipeline p =
+            is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
+
+         radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
+      }
    }
 
    cmd_buffer->state.flush_bits |= flush_bits;
index 4cd9623..caa32e0 100644
@@ -731,7 +731,6 @@ struct radv_meta_state {
       VkPipeline copy_pipeline;
 
       struct radix_sort_vk *radix_sort;
-      struct radix_sort_vk_sort_devaddr_info radix_sort_info;
 
       struct {
          VkBuffer buffer;