struct acceleration_structure_layout accel_struct;
struct scratch_layout scratch;
struct build_config config;
+
+ /* Radix sort state */
+ uint32_t scatter_blocks;
+ uint32_t count_ru_scatter;
+ uint32_t histo_blocks;
+ uint32_t count_ru_histo;
+ struct rs_push_scatter push_scatter;
};
static uint32_t
radix_sort_vk_t *rs = cmd_buffer->device->meta_state.accel_struct_build.radix_sort;
- for (uint32_t i = 0; i < infoCount; ++i) {
- uint32_t count = bvh_states[i].node_count;
- uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
- uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
- uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+ /*
+ * OVERVIEW
+ *
+ * 1. Pad the keyvals in `scatter_even`.
+ * 2. Zero the `histograms` and `partitions`.
+ * --- BARRIER ---
+ * 3. HISTOGRAM is dispatched before PREFIX.
+ * --- BARRIER ---
+ * 4. PREFIX is dispatched before the first SCATTER.
+ * --- BARRIER ---
+ * 5. One or more SCATTER dispatches.
+ *
+ * Note that the `partitions` buffer can be zeroed anytime before the first
+ * scatter.
+ */
+
+ /* How many passes? */
+ uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
+ uint32_t keyval_bits = keyval_bytes * 8;
+ uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
+ uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
- /* Anything to do? */
- if (!count) {
+ for (uint32_t i = 0; i < infoCount; ++i) {
+ if (bvh_states[i].node_count)
+ bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
+ else
bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
- continue;
- }
-
- /*
- * OVERVIEW
- *
- * 1. Pad the keyvals in `scatter_even`.
- * 2. Zero the `histograms` and `partitions`.
- * --- BARRIER ---
- * 3. HISTOGRAM is dispatched before PREFIX.
- * --- BARRIER ---
- * 4. PREFIX is dispatched before the first SCATTER.
- * --- BARRIER ---
- * 5. One or more SCATTER dispatches.
- *
- * Note that the `partitions` buffer can be zeroed anytime before the first
- * scatter.
- */
-
- /* How many passes? */
- uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
- uint32_t keyval_bits = keyval_bytes * 8;
- uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
- uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
+ }
- bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
+ /*
+ * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
+ *
+ * Pad fractional blocks with max-valued keyvals.
+ *
+ * Zero the histograms and partitions buffer.
+ *
+ * This assumes the partitions follow the histograms.
+ */
+
+ /* FIXME(allanmac): Consider precomputing some of these values and hanging them off `rs`. */
+
+ /* How many scatter blocks? */
+ uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
+ uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
+
+ /*
+ * How many histogram blocks?
+ *
+ * Note that it's OK to have more max-valued digits counted by the histogram
+ * than sorted by the scatters because the sort is stable.
+ */
+ uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
+ uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
+
+ uint32_t pass_idx = (keyval_bytes - passes);
- /*
- * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
- *
- * Pad fractional blocks with max-valued keyvals.
- *
- * Zero the histograms and partitions buffer.
- *
- * This assumes the partitions follow the histograms.
- */
+ for (uint32_t i = 0; i < infoCount; ++i) {
+ if (!bvh_states[i].node_count)
+ continue;
- /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */
+ uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+ uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
- /* How many scatter blocks? */
- uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
- uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
- uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs;
- uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs;
+ bvh_states[i].scatter_blocks = (bvh_states[i].node_count + scatter_block_kvs - 1) / scatter_block_kvs;
+ bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs;
- /*
- * How many histogram blocks?
- *
- * Note that it's OK to have more max-valued digits counted by the histogram
- * than sorted by the scatters because the sort is stable.
- */
- uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
- uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
- uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
- uint32_t count_ru_histo = histo_blocks * histo_block_kvs;
+ bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
+ bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs;
/* Fill with max values */
- if (count_ru_histo > count) {
- radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + count * keyval_bytes,
- (count_ru_histo - count) * keyval_bytes, 0xFFFFFFFF);
+ if (bvh_states[i].count_ru_histo > bvh_states[i].node_count) {
+ radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + bvh_states[i].node_count * keyval_bytes,
+ (bvh_states[i].count_ru_histo - bvh_states[i].node_count) * keyval_bytes, 0xFFFFFFFF);
}
/*
* Note that the last workgroup doesn't read/write a partition so it doesn't
* need to be initialized.
*/
- uint32_t histo_partition_count = passes + scatter_blocks - 1;
- uint32_t pass_idx = (keyval_bytes - passes);
+ uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1;
uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
radv_fill_buffer(cmd_buffer, NULL, NULL, internal_addr + rs->internal.histograms.offset + fill_base,
histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), 0);
+ }
- /*
- * Pipeline: HISTOGRAM
- *
- * TODO(allanmac): All subgroups should try to process approximately the same
- * number of blocks in order to minimize tail effects. This was implemented
- * and reverted but should be reimplemented and benchmarked later.
- */
- vk_barrier_transfer_w_to_compute_r(commandBuffer);
+ /*
+ * Pipeline: HISTOGRAM
+ *
+ * TODO(allanmac): All subgroups should try to process approximately the same
+ * number of blocks in order to minimize tail effects. This was implemented
+ * and reverted but should be reimplemented and benchmarked later.
+ */
+ vk_barrier_transfer_w_to_compute_r(commandBuffer);
+
+ radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
+
+ for (uint32_t i = 0; i < infoCount; ++i) {
+ if (!bvh_states[i].node_count)
+ continue;
- uint64_t devaddr_histograms = internal_addr + rs->internal.histograms.offset;
+ uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+ uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
/* Dispatch histogram */
struct rs_push_histogram push_histogram = {
- .devaddr_histograms = devaddr_histograms,
+ .devaddr_histograms = internal_addr + rs->internal.histograms.offset,
.devaddr_keyvals = keyvals_even_addr,
.passes = passes,
};
radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
sizeof(push_histogram), &push_histogram);
- radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
+ vk_common_CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1);
+ }
- vk_common_CmdDispatch(commandBuffer, histo_blocks, 1, 1);
+ /*
+ * Pipeline: PREFIX
+ *
+ * Launch one workgroup per pass.
+ */
+ vk_barrier_compute_w_to_compute_r(commandBuffer);
- /*
- * Pipeline: PREFIX
- *
- * Launch one workgroup per pass.
- */
- vk_barrier_compute_w_to_compute_r(commandBuffer);
+ radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
+
+ for (uint32_t i = 0; i < infoCount; ++i) {
+ if (!bvh_states[i].node_count)
+ continue;
+
+ uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
struct rs_push_prefix push_prefix = {
- .devaddr_histograms = devaddr_histograms,
+ .devaddr_histograms = internal_addr + rs->internal.histograms.offset,
};
radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0,
sizeof(push_prefix), &push_prefix);
- radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
-
vk_common_CmdDispatch(commandBuffer, passes, 1, 1);
+ }
- /* Pipeline: SCATTER */
- vk_barrier_compute_w_to_compute_r(commandBuffer);
+ /* Pipeline: SCATTER */
+ vk_barrier_compute_w_to_compute_r(commandBuffer);
- uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
- uint64_t devaddr_partitions = internal_addr + rs->internal.partitions.offset;
+ uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
- struct rs_push_scatter push_scatter = {
+ for (uint32_t i = 0; i < infoCount; i++) {
+ uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+ uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
+ uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+
+ bvh_states[i].push_scatter = (struct rs_push_scatter){
.devaddr_keyvals_even = keyvals_even_addr,
.devaddr_keyvals_odd = keyvals_odd_addr,
- .devaddr_partitions = devaddr_partitions,
- .devaddr_histograms = devaddr_histograms + histogram_offset,
- .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2,
+ .devaddr_partitions = internal_addr + rs->internal.partitions.offset,
+ .devaddr_histograms = internal_addr + rs->internal.histograms.offset + histogram_offset,
};
+ }
- {
- uint32_t pass_dword = pass_idx / 4;
+ bool is_even = true;
- radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.scatter[pass_dword].even,
- VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_scatter), &push_scatter);
+ while (true) {
+ uint32_t pass_dword = pass_idx / 4;
- radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
- rs->pipelines.named.scatter[pass_dword].even);
- }
+ /* Bind new pipeline */
+ VkPipeline p =
+ is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
+ radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
- bool is_even = true;
+ /* Update push constants that changed */
+ VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
+ : rs->pipeline_layouts.named.scatter[pass_dword].odd;
- while (true) {
- vk_common_CmdDispatch(commandBuffer, scatter_blocks, 1, 1);
+ for (uint32_t i = 0; i < infoCount; i++) {
+ if (!bvh_states[i].node_count)
+ continue;
- /* Continue? */
- if (++pass_idx >= keyval_bytes)
- break;
+ bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
- vk_barrier_compute_w_to_compute_r(commandBuffer);
+ radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter),
+ &bvh_states[i].push_scatter);
- is_even ^= true;
- push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
- push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
+ vk_common_CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1);
- uint32_t pass_dword = pass_idx / 4;
+ bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
+ }
- /* Update push constants that changed */
- VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
- : rs->pipeline_layouts.named.scatter[pass_dword].odd;
- radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT,
- offsetof(struct rs_push_scatter, devaddr_histograms),
- sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
- &push_scatter.devaddr_histograms);
+ /* Continue? */
+ if (++pass_idx >= keyval_bytes)
+ break;
- /* Bind new pipeline */
- VkPipeline p =
- is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
+ vk_barrier_compute_w_to_compute_r(commandBuffer);
- radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
- }
+ is_even ^= true;
}
cmd_buffer->state.flush_bits |= flush_bits;