radv: Perform multiple sorts in parallel
author     Konstantin Seurer <konstantin.seurer@gmail.com>
Wed, 16 Aug 2023 09:50:18 +0000 (11:50 +0200)
committer  Marge Bot <emma+marge@anholt.net>
Thu, 14 Sep 2023 15:12:44 +0000 (15:12 +0000)
This was the last part that didn't scale with multiple infos. Reducing
the number of barriers in this case improves DOOM Eternal performance by
50% (running at low resolution).
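
Schematically, the change hoists the barriers out of the per-info loop so that
all infos share a single barrier per sort phase. A simplified sketch (the
dispatch_* and barrier() helpers are illustrative shorthand, not functions in
this patch):

   /* Before: every info issued its own barriers between phases. */
   for (uint32_t i = 0; i < infoCount; i++) {
      dispatch_histogram(i);
      barrier();
      dispatch_prefix(i);
      barrier();
      dispatch_scatter_passes(i); /* with barriers between the scatter passes */
   }

   /* After: one barrier per phase, amortized over all infos. */
   for (uint32_t i = 0; i < infoCount; i++)
      dispatch_histogram(i);
   barrier();
   for (uint32_t i = 0; i < infoCount; i++)
      dispatch_prefix(i);
   barrier();
   for (uint32_t pass = 0; pass < passes; pass++) {
      if (pass)
         barrier();
      for (uint32_t i = 0; i < infoCount; i++)
         dispatch_scatter(i, pass);
   }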

Reviewed-by: Friedrich Vock <friedrich.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24720>

src/amd/vulkan/radv_acceleration_structure.c

index 81c19b9..271cad0 100644
@@ -600,6 +600,13 @@ struct bvh_state {
    struct acceleration_structure_layout accel_struct;
    struct scratch_layout scratch;
    struct build_config config;
+
+   /* Radix sort state */
+   uint32_t scatter_blocks;
+   uint32_t count_ru_scatter;
+   uint32_t histo_blocks;
+   uint32_t count_ru_histo;
+   struct rs_push_scatter push_scatter;
 };
 
 static uint32_t
@@ -726,75 +733,79 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
 
    radix_sort_vk_t *rs = cmd_buffer->device->meta_state.accel_struct_build.radix_sort;
 
-   for (uint32_t i = 0; i < infoCount; ++i) {
-      uint32_t count = bvh_states[i].node_count;
-      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
-      uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
-      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+   /*
+    * OVERVIEW
+    *
+    *   1. Pad the keyvals in `scatter_even`.
+    *   2. Zero the `histograms` and `partitions`.
+    *      --- BARRIER ---
+    *   3. HISTOGRAM is dispatched before PREFIX.
+    *      --- BARRIER ---
+    *   4. PREFIX is dispatched before the first SCATTER.
+    *      --- BARRIER ---
+    *   5. One or more SCATTER dispatches.
+    *
+    * Note that the `partitions` buffer can be zeroed anytime before the first
+    * scatter.
+    */
+
+   /* How many passes? */
+   uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
+   uint32_t keyval_bits = keyval_bytes * 8;
+   uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
+   uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
 
-      /* Anything to do? */
-      if (!count) {
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (bvh_states[i].node_count)
+         bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
+      else
          bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
-         continue;
-      }
-
-      /*
-       * OVERVIEW
-       *
-       *   1. Pad the keyvals in `scatter_even`.
-       *   2. Zero the `histograms` and `partitions`.
-       *      --- BARRIER ---
-       *   3. HISTOGRAM is dispatched before PREFIX.
-       *      --- BARRIER ---
-       *   4. PREFIX is dispatched before the first SCATTER.
-       *      --- BARRIER ---
-       *   5. One or more SCATTER dispatches.
-       *
-       * Note that the `partitions` buffer can be zeroed anytime before the first
-       * scatter.
-       */
-
-      /* How many passes? */
-      uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
-      uint32_t keyval_bits = keyval_bytes * 8;
-      uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
-      uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
+   }
 
-      bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
+   /*
+    * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
+    *
+    * Pad fractional blocks with max-valued keyvals.
+    *
+    * Zero the histograms and partitions buffer.
+    *
+    * This assumes the partitions follow the histograms.
+    */
+
+   /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */
+
+   /* How many scatter blocks? */
+   uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
+   uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
+
+   /*
+    * How many histogram blocks?
+    *
+    * Note that it's OK to have more max-valued digits counted by the histogram
+    * than sorted by the scatters because the sort is stable.
+    */
+   uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
+   uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
+
+   uint32_t pass_idx = (keyval_bytes - passes);
 
-      /*
-       * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
-       *
-       * Pad fractional blocks with max-valued keyvals.
-       *
-       * Zero the histograms and partitions buffer.
-       *
-       * This assumes the partitions follow the histograms.
-       */
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (!bvh_states[i].node_count)
+         continue;
 
-      /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */
+      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
 
-      /* How many scatter blocks? */
-      uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
-      uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
-      uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs;
-      uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs;
+      bvh_states[i].scatter_blocks = (bvh_states[i].node_count + scatter_block_kvs - 1) / scatter_block_kvs;
+      bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs;
 
-      /*
-       * How many histogram blocks?
-       *
-       * Note that it's OK to have more max-valued digits counted by the histogram
-       * than sorted by the scatters because the sort is stable.
-       */
-      uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
-      uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
-      uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
-      uint32_t count_ru_histo = histo_blocks * histo_block_kvs;
+      bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
+      bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs;
 
       /* Fill with max values */
-      if (count_ru_histo > count) {
-         radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + count * keyval_bytes,
-                          (count_ru_histo - count) * keyval_bytes, 0xFFFFFFFF);
+      if (bvh_states[i].count_ru_histo > bvh_states[i].node_count) {
+         radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + bvh_states[i].node_count * keyval_bytes,
+                          (bvh_states[i].count_ru_histo - bvh_states[i].node_count) * keyval_bytes, 0xFFFFFFFF);
       }
 
       /*
@@ -807,28 +818,35 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
        * Note that the last workgroup doesn't read/write a partition so it doesn't
        * need to be initialized.
        */
-      uint32_t histo_partition_count = passes + scatter_blocks - 1;
-      uint32_t pass_idx = (keyval_bytes - passes);
+      uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1;
 
       uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
 
       radv_fill_buffer(cmd_buffer, NULL, NULL, internal_addr + rs->internal.histograms.offset + fill_base,
                        histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), 0);
+   }
 
-      /*
-       * Pipeline: HISTOGRAM
-       *
-       * TODO(allanmac): All subgroups should try to process approximately the same
-       * number of blocks in order to minimize tail effects.  This was implemented
-       * and reverted but should be reimplemented and benchmarked later.
-       */
-      vk_barrier_transfer_w_to_compute_r(commandBuffer);
+   /*
+    * Pipeline: HISTOGRAM
+    *
+    * TODO(allanmac): All subgroups should try to process approximately the same
+    * number of blocks in order to minimize tail effects.  This was implemented
+    * and reverted but should be reimplemented and benchmarked later.
+    */
+   vk_barrier_transfer_w_to_compute_r(commandBuffer);
+
+   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
+
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (!bvh_states[i].node_count)
+         continue;
 
-      uint64_t devaddr_histograms = internal_addr + rs->internal.histograms.offset;
+      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
 
       /* Dispatch histogram */
       struct rs_push_histogram push_histogram = {
-         .devaddr_histograms = devaddr_histograms,
+         .devaddr_histograms = internal_addr + rs->internal.histograms.offset,
          .devaddr_keyvals = keyvals_even_addr,
          .passes = passes,
       };
@@ -836,83 +854,87 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
       radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
                             sizeof(push_histogram), &push_histogram);
 
-      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
+      vk_common_CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1);
+   }
 
-      vk_common_CmdDispatch(commandBuffer, histo_blocks, 1, 1);
+   /*
+    * Pipeline: PREFIX
+    *
+    * Launch one workgroup per pass.
+    */
+   vk_barrier_compute_w_to_compute_r(commandBuffer);
 
-      /*
-       * Pipeline: PREFIX
-       *
-       * Launch one workgroup per pass.
-       */
-      vk_barrier_compute_w_to_compute_r(commandBuffer);
+   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
+
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (!bvh_states[i].node_count)
+         continue;
+
+      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
 
       struct rs_push_prefix push_prefix = {
-         .devaddr_histograms = devaddr_histograms,
+         .devaddr_histograms = internal_addr + rs->internal.histograms.offset,
       };
 
       radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0,
                             sizeof(push_prefix), &push_prefix);
 
-      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
-
       vk_common_CmdDispatch(commandBuffer, passes, 1, 1);
+   }
 
-      /* Pipeline: SCATTER */
-      vk_barrier_compute_w_to_compute_r(commandBuffer);
+   /* Pipeline: SCATTER */
+   vk_barrier_compute_w_to_compute_r(commandBuffer);
 
-      uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
-      uint64_t devaddr_partitions = internal_addr + rs->internal.partitions.offset;
+   uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
 
-      struct rs_push_scatter push_scatter = {
+   for (uint32_t i = 0; i < infoCount; i++) {
+      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
+      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+
+      bvh_states[i].push_scatter = (struct rs_push_scatter){
          .devaddr_keyvals_even = keyvals_even_addr,
          .devaddr_keyvals_odd = keyvals_odd_addr,
-         .devaddr_partitions = devaddr_partitions,
-         .devaddr_histograms = devaddr_histograms + histogram_offset,
-         .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2,
+         .devaddr_partitions = internal_addr + rs->internal.partitions.offset,
+         .devaddr_histograms = internal_addr + rs->internal.histograms.offset + histogram_offset,
       };
+   }
 
-      {
-         uint32_t pass_dword = pass_idx / 4;
+   bool is_even = true;
 
-         radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.scatter[pass_dword].even,
-                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_scatter), &push_scatter);
+   while (true) {
+      uint32_t pass_dword = pass_idx / 4;
 
-         radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
-                              rs->pipelines.named.scatter[pass_dword].even);
-      }
+      /* Bind new pipeline */
+      VkPipeline p =
+         is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
+      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
 
-      bool is_even = true;
+      /* Update push constants that changed */
+      VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
+                                    : rs->pipeline_layouts.named.scatter[pass_dword].odd;
 
-      while (true) {
-         vk_common_CmdDispatch(commandBuffer, scatter_blocks, 1, 1);
+      for (uint32_t i = 0; i < infoCount; i++) {
+         if (!bvh_states[i].node_count)
+            continue;
 
-         /* Continue? */
-         if (++pass_idx >= keyval_bytes)
-            break;
+         bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
 
-         vk_barrier_compute_w_to_compute_r(commandBuffer);
+         radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter),
+                               &bvh_states[i].push_scatter);
 
-         is_even ^= true;
-         push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
-         push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
+         vk_common_CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1);
 
-         uint32_t pass_dword = pass_idx / 4;
+         bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
+      }
 
-         /* Update push constants that changed */
-         VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
-                                       : rs->pipeline_layouts.named.scatter[pass_dword].odd;
-         radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT,
-                               offsetof(struct rs_push_scatter, devaddr_histograms),
-                               sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
-                               &push_scatter.devaddr_histograms);
+      /* Continue? */
+      if (++pass_idx >= keyval_bytes)
+         break;
 
-         /* Bind new pipeline */
-         VkPipeline p =
-            is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
+      vk_barrier_compute_w_to_compute_r(commandBuffer);
 
-         radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
-      }
+      is_even ^= true;
    }
 
    cmd_buffer->state.flush_bits |= flush_bits;
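
For readers following the pass arithmetic in morton_sort(), a worked example;
the concrete values (keyval width, MORTON_BIT_SIZE, radix width) are
assumptions for illustration, not necessarily RADV's actual configuration:

   /* Hypothetical configuration: 64-bit keyvals (keyval_dwords == 2),
    * 8-bit radix digits and a 32-bit morton key. */
   uint32_t keyval_bytes = 2 * (uint32_t)sizeof(uint32_t); /* 8  */
   uint32_t keyval_bits = keyval_bytes * 8;                /* 64 */
   uint32_t key_bits = MIN2(32, keyval_bits);              /* 32 */
   uint32_t passes = (key_bits + 8 - 1) / 8;               /* 4  */

   /* The scatters ping-pong between the even and odd keyval buffers, starting
    * with the even one, so an even pass count (4 here) leaves the sorted
    * keyvals in sort_buffer_offset[passes & 1] == sort_buffer_offset[0]; an
    * odd count would leave them in sort_buffer_offset[1]. pass_idx starts at
    * keyval_bytes - passes == 4 and the scatter loop runs pass_idx = 4, 5, 6, 7
    * before stopping at keyval_bytes == 8, i.e. exactly `passes` dispatches
    * covering only the digit positions that can hold key bits. */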