radv: Build acceleration structures using BVH IR
authorFriedrich Vock <friedrich.vock@gmx.de>
Sat, 24 Sep 2022 20:10:02 +0000 (22:10 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 26 Sep 2022 22:25:22 +0000 (22:25 +0000)
Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18769>

src/amd/vulkan/bvh/build_helpers.h
src/amd/vulkan/bvh/build_interface.h
src/amd/vulkan/bvh/lbvh_internal.comp
src/amd/vulkan/bvh/leaf.comp
src/amd/vulkan/bvh/morton.comp
src/amd/vulkan/radv_acceleration_structure.c

index 9e37d40..6c226e0 100644 (file)
@@ -306,19 +306,19 @@ ir_type_to_bvh_type(uint32_t type)
 }
 
 AABB
-calculate_instance_node_bounds(radv_bvh_instance_node instance)
+calculate_instance_node_bounds(uint64_t base_ptr, mat3x4 otw_matrix)
 {
    AABB aabb;
-   radv_accel_struct_header header = DEREF(REF(radv_accel_struct_header)(instance.base_ptr));
+   radv_accel_struct_header header = DEREF(REF(radv_accel_struct_header)(base_ptr));
 
    for (uint32_t comp = 0; comp < 3; ++comp) {
-      aabb.min[comp] = instance.otw_matrix[comp][3];
-      aabb.max[comp] = instance.otw_matrix[comp][3];
+      aabb.min[comp] = otw_matrix[comp][3];
+      aabb.max[comp] = otw_matrix[comp][3];
       for (uint32_t col = 0; col < 3; ++col) {
-         aabb.min[comp] += min(instance.otw_matrix[comp][col] * header.aabb[0][col],
-                               instance.otw_matrix[comp][col] * header.aabb[1][col]);
-         aabb.max[comp] += max(instance.otw_matrix[comp][col] * header.aabb[0][col],
-                               instance.otw_matrix[comp][col] * header.aabb[1][col]);
+         aabb.min[comp] += min(otw_matrix[comp][col] * header.aabb[0][col],
+                               otw_matrix[comp][col] * header.aabb[1][col]);
+         aabb.max[comp] += max(otw_matrix[comp][col] * header.aabb[0][col],
+                               otw_matrix[comp][col] * header.aabb[1][col]);
       }
    }
    return aabb;
@@ -359,7 +359,7 @@ calculate_node_bounds(VOID_REF bvh, uint32_t id)
    }
    case radv_bvh_node_instance: {
       radv_bvh_instance_node instance = DEREF(REF(radv_bvh_instance_node)(node));
-      aabb = calculate_instance_node_bounds(instance);
+      aabb = calculate_instance_node_bounds(instance.base_ptr, instance.otw_matrix);
       break;
    }
    case radv_bvh_node_aabb: {
index 6a3e570..aaf6157 100644 (file)
@@ -62,7 +62,7 @@ struct lbvh_internal_args {
    REF(key_id_pair) src_ids;
    REF(key_id_pair) dst_ids;
    uint32_t dst_offset;
-   uint32_t fill_count;
+   uint32_t is_root_and_count;
 };
 
 #define RADV_COPY_MODE_COPY        0
index c8d6fc7..405bc3f 100644 (file)
@@ -46,22 +46,23 @@ main(void)
 {
    uint32_t global_id = gl_GlobalInvocationID.x;
 
-   bool fill_header = (args.fill_count & 0x80000000u) != 0;
-   uint32_t src_count = args.fill_count & 0x7FFFFFFFu;
+   bool is_root_node = (args.is_root_and_count & 0x80000000u) != 0;
+   uint32_t src_count = args.is_root_and_count & 0x7FFFFFFFu;
 
-   uint32_t src_index = global_id * 4;
-   uint32_t child_count = min(src_count - src_index, 4);
+   uint32_t src_index = global_id * 2;
+   uint32_t child_count = min(src_count - src_index, 2);
 
-   uint32_t dst_offset = args.dst_offset + global_id * SIZEOF(radv_bvh_box32_node);
+   uint32_t dst_offset = args.dst_offset + global_id * SIZEOF(radv_ir_box_node);
+   uint32_t current_id = pack_ir_node_id(dst_offset, radv_ir_node_internal);
 
-   REF(radv_bvh_box32_node) dst_node = REF(radv_bvh_box32_node)(OFFSET(args.bvh, dst_offset));
+   REF(radv_ir_box_node) dst_node = REF(radv_ir_box_node)(OFFSET(args.bvh, dst_offset));
 
    AABB total_bounds;
    total_bounds.min = vec3(INFINITY);
    total_bounds.max = vec3(-INFINITY);
 
    bool is_active = false;
-   for (uint32_t i = 0; i < 4; i++) {
+   for (uint32_t i = 0; i < 2; i++) {
       AABB bounds;
       bounds.min = vec3(NAN);
       bounds.max = vec3(NAN);
@@ -69,38 +70,35 @@ main(void)
       uint32_t child_id = DEREF(INDEX(key_id_pair, args.src_ids, src_index + i)).id;
 
       if (i < child_count && child_id != NULL_NODE_ID) {
-         bounds = calculate_node_bounds(args.bvh, child_id);
+         VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(child_id));
+         REF(radv_ir_node) child = REF(radv_ir_node)(node);
+         bounds = load_aabb(child);
+
          total_bounds.min = min(total_bounds.min, bounds.min);
          total_bounds.max = max(total_bounds.max, bounds.max);
          is_active = true;
+
+         DEREF(child).parent = current_id;
       } else {
          child_id = NULL_NODE_ID;
       }
 
       DEREF(dst_node).children[i] = child_id;
-
-      DEREF(dst_node).coords[i][0][0] = bounds.min.x;
-      DEREF(dst_node).coords[i][0][1] = bounds.min.y;
-      DEREF(dst_node).coords[i][0][2] = bounds.min.z;
-      DEREF(dst_node).coords[i][1][0] = bounds.max.x;
-      DEREF(dst_node).coords[i][1][1] = bounds.max.y;
-      DEREF(dst_node).coords[i][1][2] = bounds.max.z;
    }
+   DEREF(dst_node).base.aabb[0][0] = total_bounds.min.x;
+   DEREF(dst_node).base.aabb[0][1] = total_bounds.min.y;
+   DEREF(dst_node).base.aabb[0][2] = total_bounds.min.z;
+   DEREF(dst_node).base.aabb[1][0] = total_bounds.max.x;
+   DEREF(dst_node).base.aabb[1][1] = total_bounds.max.y;
+   DEREF(dst_node).base.aabb[1][2] = total_bounds.max.z;
 
-   uint32_t node_id = pack_node_id(dst_offset, radv_bvh_node_internal);
    /* An internal node is considered inactive if it has no children. Set the resulting scratch node
     * id to NULL_NODE_ID for more internal nodes to become inactive.
     */
-   DEREF(INDEX(key_id_pair, args.dst_ids, global_id)).id = is_active ? node_id : NULL_NODE_ID;
-
-   if (fill_header) {
-      REF(radv_accel_struct_header) header = REF(radv_accel_struct_header)(args.bvh);
+   DEREF(INDEX(key_id_pair, args.dst_ids, global_id)).id = is_active ? current_id : NULL_NODE_ID;
 
-      DEREF(header).aabb[0][0] = total_bounds.min.x;
-      DEREF(header).aabb[0][1] = total_bounds.min.y;
-      DEREF(header).aabb[0][2] = total_bounds.min.z;
-      DEREF(header).aabb[1][0] = total_bounds.max.x;
-      DEREF(header).aabb[1][1] = total_bounds.max.y;
-      DEREF(header).aabb[1][2] = total_bounds.max.z;
+   if (is_root_node) {
+      /* We're writing the root node, which doesn't have a parent */
+      DEREF(dst_node).base.parent = 0xFFFFFFFF;
    }
 }
index 252c474..45799ea 100644 (file)
@@ -185,30 +185,31 @@ TYPE(AccelerationStructureInstance, 8);
 bool
 build_instance(inout AABB bounds, VOID_REF src_ptr, VOID_REF dst_ptr, uint32_t global_id)
 {
-   REF(radv_bvh_instance_node) node = REF(radv_bvh_instance_node)(dst_ptr);
+   REF(radv_ir_instance_node) node = REF(radv_ir_instance_node)(dst_ptr);
 
    AccelerationStructureInstance instance = DEREF(REF(AccelerationStructureInstance)(src_ptr));
    if (instance.accelerationStructureReference == 0)
       return false;
 
-   mat4 transform = mat4(instance.transform);
-
-   /* We store everything as mat3x4 for layout reasons but the conceptual matrix
-    * is really a mat4x3. So transpose it temporarily for the invertion. */
-   mat4 inv_transform = transpose(inverse(transpose(transform)));
-   DEREF(node).wto_matrix = mat3x4(inv_transform);
-   DEREF(node).otw_matrix = mat3x4(transform);
+   DEREF(node).otw_matrix = instance.transform;
 
    radv_accel_struct_header instance_header =
       DEREF(REF(radv_accel_struct_header)(instance.accelerationStructureReference));
    DEREF(node).base_ptr = instance.accelerationStructureReference;
 
-   bounds = calculate_instance_node_bounds(DEREF(node));
+   bounds = calculate_instance_node_bounds(DEREF(node).base_ptr, DEREF(node).otw_matrix);
 
    DEREF(node).custom_instance_and_mask = instance.custom_instance_and_mask;
    DEREF(node).sbt_offset_and_flags = instance.sbt_offset_and_flags;
    DEREF(node).instance_id = global_id;
 
+   DEREF(node).base.aabb[0][0] = bounds.min.x;
+   DEREF(node).base.aabb[0][1] = bounds.min.y;
+   DEREF(node).base.aabb[0][2] = bounds.min.z;
+   DEREF(node).base.aabb[1][0] = bounds.max.x;
+   DEREF(node).base.aabb[1][1] = bounds.max.y;
+   DEREF(node).base.aabb[1][2] = bounds.max.z;
+
    return true;
 }
 
@@ -223,14 +224,14 @@ main(void)
    uint32_t dst_stride;
    uint32_t node_type;
    if (args.geometry_type == VK_GEOMETRY_TYPE_TRIANGLES_KHR) {
-      dst_stride = SIZEOF(radv_bvh_triangle_node);
-      node_type = radv_bvh_node_triangle;
+      dst_stride = SIZEOF(radv_ir_triangle_node);
+      node_type = radv_ir_node_triangle;
    } else if (args.geometry_type == VK_GEOMETRY_TYPE_AABBS_KHR) {
-      dst_stride = SIZEOF(radv_bvh_aabb_node);
-      node_type = radv_bvh_node_aabb;
+      dst_stride = SIZEOF(radv_ir_aabb_node);
+      node_type = radv_ir_node_aabb;
    } else {
-      dst_stride = SIZEOF(radv_bvh_instance_node);
-      node_type = radv_bvh_node_instance;
+      dst_stride = SIZEOF(radv_ir_instance_node);
+      node_type = radv_ir_node_instance;
    }
 
    uint32_t dst_offset = args.dst_offset + global_id * dst_stride;
@@ -255,7 +256,7 @@ main(void)
             vertices.vertex[i] = transform * vertices.vertex[i];
       }
 
-      REF(radv_bvh_triangle_node) node = REF(radv_bvh_triangle_node)(dst_ptr);
+      REF(radv_ir_triangle_node) node = REF(radv_ir_triangle_node)(dst_ptr);
 
       bounds.min = vec3(INFINITY);
       bounds.max = vec3(-INFINITY);
@@ -267,6 +268,13 @@ main(void)
             bounds.max[comp] = max(bounds.max[comp], vertices.vertex[coord][comp]);
          }
 
+      DEREF(node).base.aabb[0][0] = bounds.min.x;
+      DEREF(node).base.aabb[0][1] = bounds.min.y;
+      DEREF(node).base.aabb[0][2] = bounds.min.z;
+      DEREF(node).base.aabb[1][0] = bounds.max.x;
+      DEREF(node).base.aabb[1][1] = bounds.max.y;
+      DEREF(node).base.aabb[1][2] = bounds.max.z;
+
       DEREF(node).triangle_id = global_id;
       DEREF(node).geometry_id_and_flags = args.geometry_id;
       DEREF(node).id = 9;
@@ -274,12 +282,12 @@ main(void)
    } else if (args.geometry_type == VK_GEOMETRY_TYPE_AABBS_KHR) {
       VOID_REF src_ptr = OFFSET(args.data, src_offset);
 
-      REF(radv_bvh_aabb_node) node = REF(radv_bvh_aabb_node)(dst_ptr);
+      REF(radv_ir_aabb_node) node = REF(radv_ir_aabb_node)(dst_ptr);
 
       for (uint32_t vec = 0; vec < 2; vec++)
          for (uint32_t comp = 0; comp < 3; comp++) {
             float coord = DEREF(INDEX(float, src_ptr, comp + vec * 3));
-            DEREF(node).aabb[vec][comp] = coord;
+            DEREF(node).base.aabb[vec][comp] = coord;
 
             if (vec == 0)
                bounds.min[comp] = coord;
@@ -299,7 +307,7 @@ main(void)
       is_active = build_instance(bounds, src_ptr, dst_ptr, global_id);
    }
 
-   DEREF(id_ptr).id = is_active ? pack_node_id(dst_offset, node_type) : NULL_NODE_ID;
+   DEREF(id_ptr).id = is_active ? pack_ir_node_id(dst_offset, node_type) : NULL_NODE_ID;
 
    min_float_emulated(INDEX(int32_t, args.bounds, 0), bounds.min.x);
    min_float_emulated(INDEX(int32_t, args.bounds, 1), bounds.min.y);
index db0a843..578b05e 100644 (file)
@@ -73,7 +73,7 @@ main(void)
 
    uint32_t key;
    if (id != NULL_NODE_ID) {
-      AABB bounds = calculate_node_bounds(args.bvh, id);
+      AABB bounds = load_aabb(REF(radv_ir_node)OFFSET(args.bvh, ir_id_to_offset(id)));
       vec3 center = (bounds.min + bounds.max) * 0.5;
 
       AABB bvh_bounds;
index 390a01f..7e0d190 100644 (file)
@@ -102,7 +102,7 @@ radv_GetAccelerationStructureBuildSizesKHR(
    /* Initialize to 1 to have enought space for the root node. */
    uint64_t internal_nodes = 1;
    while (children > 1) {
-      children = DIV_ROUND_UP(children, 4);
+      children = DIV_ROUND_UP(children, 2);
       internal_nodes += children;
    }
 
@@ -126,6 +126,11 @@ radv_GetAccelerationStructureBuildSizesKHR(
 
    scratchSize += requirements.internal_size + SCRATCH_TOTAL_BOUNDS_SIZE;
 
+   /* IR leaf nodes */
+   scratchSize += boxes * sizeof(struct radv_ir_aabb_node) + instances * sizeof(struct radv_ir_instance_node) + triangles * sizeof(struct radv_ir_triangle_node);
+   /* IR internal nodes */
+   scratchSize += internal_nodes * sizeof(struct radv_ir_box_node);
+
    scratchSize = MAX2(4096, scratchSize);
    pSizeInfo->updateScratchSize = scratchSize;
    pSizeInfo->buildScratchSize = scratchSize;
@@ -386,15 +391,17 @@ radv_device_init_accel_struct_build_state(struct radv_device *device)
 }
 
 struct bvh_state {
+   VkDeviceAddress bvh_ir;
+
    uint32_t node_offset;
    uint32_t node_count;
    uint32_t scratch_offset;
    uint32_t buffer_1_offset;
    uint32_t buffer_2_offset;
 
-   uint32_t leaf_node_offset;
    uint32_t leaf_node_count;
    uint32_t internal_node_count;
+   uint32_t leaf_node_size;
 };
 
 static void
@@ -407,18 +414,12 @@ build_leaves(VkCommandBuffer commandBuffer, uint32_t infoCount,
    radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                         cmd_buffer->device->meta_state.accel_struct_build.leaf_pipeline);
    for (uint32_t i = 0; i < infoCount; ++i) {
-      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
-                       pInfos[i].dstAccelerationStructure);
-
       struct leaf_args leaf_consts = {
-         .bvh = accel_struct->va,
+         .bvh = bvh_states[i].bvh_ir,
          .bounds = pInfos[i].scratchData.deviceAddress,
          .ids = pInfos[i].scratchData.deviceAddress + SCRATCH_TOTAL_BOUNDS_SIZE,
-         .dst_offset =
-            ALIGN(sizeof(struct radv_accel_struct_header), 64) + sizeof(struct radv_bvh_box32_node),
+         .dst_offset = 0,
       };
-      bvh_states[i].node_offset = leaf_consts.dst_offset;
-      bvh_states[i].leaf_node_offset = leaf_consts.dst_offset;
 
       for (unsigned j = 0; j < pInfos[i].geometryCount; ++j) {
          const VkAccelerationStructureGeometryKHR *geom =
@@ -431,6 +432,7 @@ build_leaves(VkCommandBuffer commandBuffer, uint32_t infoCount,
          leaf_consts.geometry_type = geom->geometryType;
          leaf_consts.geometry_id = j | (geom->flags << 28);
          unsigned prim_size;
+         unsigned output_prim_size;
          switch (geom->geometryType) {
          case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
             assert(pInfos[i].type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
@@ -452,7 +454,8 @@ build_leaves(VkCommandBuffer commandBuffer, uint32_t infoCount,
             leaf_consts.vertex_format = geom->geometry.triangles.vertexFormat;
             leaf_consts.index_format = geom->geometry.triangles.indexType;
 
-            prim_size = sizeof(struct radv_bvh_triangle_node);
+            prim_size = sizeof(struct radv_ir_triangle_node);
+            output_prim_size = sizeof(struct radv_bvh_triangle_node);
             break;
          case VK_GEOMETRY_TYPE_AABBS_KHR:
             assert(pInfos[i].type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
@@ -461,7 +464,8 @@ build_leaves(VkCommandBuffer commandBuffer, uint32_t infoCount,
                geom->geometry.aabbs.data.deviceAddress + buildRangeInfo->primitiveOffset;
             leaf_consts.stride = geom->geometry.aabbs.stride;
 
-            prim_size = sizeof(struct radv_bvh_aabb_node);
+            prim_size = sizeof(struct radv_ir_aabb_node);
+            output_prim_size = sizeof(struct radv_bvh_aabb_node);
             break;
          case VK_GEOMETRY_TYPE_INSTANCES_KHR:
             assert(pInfos[i].type == VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR);
@@ -474,7 +478,8 @@ build_leaves(VkCommandBuffer commandBuffer, uint32_t infoCount,
             else
                leaf_consts.stride = sizeof(VkAccelerationStructureInstanceKHR);
 
-            prim_size = sizeof(struct radv_bvh_instance_node);
+            prim_size = sizeof(struct radv_ir_instance_node);
+            output_prim_size = sizeof(struct radv_bvh_instance_node);
             break;
          default:
             unreachable("Unknown geometryType");
@@ -489,6 +494,7 @@ build_leaves(VkCommandBuffer commandBuffer, uint32_t infoCount,
 
          bvh_states[i].leaf_node_count += buildRangeInfo->primitiveCount;
          bvh_states[i].node_count += buildRangeInfo->primitiveCount;
+         bvh_states[i].leaf_node_size = output_prim_size;
       }
       bvh_states[i].node_offset = leaf_consts.dst_offset;
    }
@@ -506,11 +512,8 @@ morton_generate(VkCommandBuffer commandBuffer, uint32_t infoCount,
                         cmd_buffer->device->meta_state.accel_struct_build.morton_pipeline);
 
    for (uint32_t i = 0; i < infoCount; ++i) {
-      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
-                       pInfos[i].dstAccelerationStructure);
-
       const struct morton_args consts = {
-         .bvh = accel_struct->va,
+         .bvh = bvh_states[i].bvh_ir,
          .bounds = pInfos[i].scratchData.deviceAddress,
          .ids = pInfos[i].scratchData.deviceAddress + SCRATCH_TOTAL_BOUNDS_SIZE,
       };
@@ -584,9 +587,6 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, uint32_t infoCount,
    for (unsigned iter = 0; progress; ++iter) {
       progress = false;
       for (uint32_t i = 0; i < infoCount; ++i) {
-         RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
-                          pInfos[i].dstAccelerationStructure);
-
          if (iter && bvh_states[i].node_count == 1)
             continue;
 
@@ -595,7 +595,7 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, uint32_t infoCount,
 
          progress = true;
 
-         uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 4));
+         uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 2));
          bool final_iter = dst_node_count == 1;
 
          uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
@@ -605,34 +605,97 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, uint32_t infoCount,
             (src_scratch_offset == buffer_1_offset) ? buffer_2_offset : buffer_1_offset;
 
          uint32_t dst_node_offset = bvh_states[i].node_offset;
-         if (final_iter) {
-            dst_node_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64);
 
-            /* Make sure we build the BVH so the hardcoded root node is valid. */
-            STATIC_ASSERT(RADV_BVH_ROOT_NODE ==
-                          DIV_ROUND_UP(sizeof(struct radv_accel_struct_header), 64) * 8 +
-                             radv_bvh_node_internal);
-         }
+         /* Make sure we build the BVH so the hardcoded root node is valid. */
+         STATIC_ASSERT(RADV_BVH_ROOT_NODE ==
+                       DIV_ROUND_UP(sizeof(struct radv_accel_struct_header), 64) * 8 +
+                          radv_bvh_node_internal);
 
          const struct lbvh_internal_args consts = {
-            .bvh = accel_struct->va,
+            .bvh = bvh_states[i].bvh_ir,
             .src_ids = pInfos[i].scratchData.deviceAddress + src_scratch_offset,
             .dst_ids = pInfos[i].scratchData.deviceAddress + dst_scratch_offset,
             .dst_offset = dst_node_offset,
-            .fill_count = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
+            .is_root_and_count = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
          };
 
          radv_CmdPushConstants(
             commandBuffer, cmd_buffer->device->meta_state.accel_struct_build.lbvh_internal_p_layout,
             VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
          radv_unaligned_dispatch(cmd_buffer, dst_node_count, 1, 1);
-         if (!final_iter)
-            bvh_states[i].node_offset += dst_node_count * 128;
+         bvh_states[i].node_offset += dst_node_count * sizeof(struct radv_ir_box_node);
          bvh_states[i].node_count = dst_node_count;
          bvh_states[i].internal_node_count += dst_node_count;
          bvh_states[i].scratch_offset = dst_scratch_offset;
       }
    }
+   cmd_buffer->state.flush_bits |= flush_bits;
+}
+
+static void
+convert_leaf_nodes(VkCommandBuffer commandBuffer, uint32_t infoCount,
+                   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
+                   struct bvh_state *bvh_states)
+{
+   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
+   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+                        cmd_buffer->device->meta_state.accel_struct_build.convert_leaf_pipeline);
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (!pInfos[i].geometryCount)
+         continue;
+
+      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
+                       pInfos[i].dstAccelerationStructure);
+
+      const struct convert_leaf_args args = {
+         .intermediate_bvh = bvh_states[i].bvh_ir,
+         .output_bvh = accel_struct->va,
+         .geometry_type = pInfos->pGeometries ? pInfos->pGeometries[0].geometryType
+                                              : pInfos->ppGeometries[0]->geometryType,
+      };
+      radv_CmdPushConstants(commandBuffer,
+                            cmd_buffer->device->meta_state.accel_struct_build.convert_leaf_p_layout,
+                            VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(args), &args);
+      radv_unaligned_dispatch(cmd_buffer, bvh_states[i].leaf_node_count, 1, 1);
+   }
+   /* This is the final access to the leaf nodes, no need to flush */
+}
+
+static void
+convert_internal_nodes(VkCommandBuffer commandBuffer, uint32_t infoCount,
+                       const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
+                       struct bvh_state *bvh_states)
+{
+   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
+   radv_CmdBindPipeline(
+      commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+      cmd_buffer->device->meta_state.accel_struct_build.convert_internal_pipeline);
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
+                       pInfos[i].dstAccelerationStructure);
+
+      VkGeometryTypeKHR geometry_type = VK_GEOMETRY_TYPE_TRIANGLES_KHR;
+
+      /* If the geometry count is 0, then the size does not matter
+       * because it will be multiplied with 0.
+       */
+      if (pInfos[i].geometryCount)
+         geometry_type = pInfos->pGeometries ? pInfos->pGeometries[0].geometryType
+                                             : pInfos->ppGeometries[0]->geometryType;
+
+      const struct convert_internal_args args = {
+         .intermediate_bvh = bvh_states[i].bvh_ir,
+         .output_bvh = accel_struct->va,
+         .leaf_node_count = bvh_states[i].leaf_node_count,
+         .internal_node_count = bvh_states[i].internal_node_count,
+         .geometry_type = geometry_type,
+      };
+      radv_CmdPushConstants(
+         commandBuffer, cmd_buffer->device->meta_state.accel_struct_build.convert_internal_p_layout,
+         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(args), &args);
+      radv_unaligned_dispatch(cmd_buffer, bvh_states[i].internal_node_count, 1, 1);
+   }
+   /* This is the final access to the leaf nodes, no need to flush */
 }
 
 VKAPI_ATTR void VKAPI_CALL
@@ -666,6 +729,24 @@ radv_CmdBuildAccelerationStructuresKHR(
 
    cmd_buffer->state.flush_bits |= flush_bits;
 
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      uint32_t leaf_node_count = 0;
+      for (uint32_t j = 0; j < pInfos[i].geometryCount; ++j) {
+         leaf_node_count += ppBuildRangeInfos[i][j].primitiveCount;
+      }
+
+      radix_sort_vk_memory_requirements_t requirements;
+      radix_sort_vk_get_memory_requirements(
+         cmd_buffer->device->meta_state.accel_struct_build.radix_sort, leaf_node_count,
+         &requirements);
+
+      /* Calculate size of other scratch metadata */
+      VkDeviceSize bvh_ir_offset = requirements.internal_size + SCRATCH_TOTAL_BOUNDS_SIZE;
+      bvh_ir_offset += 2 * MAX2(leaf_node_count * KEY_ID_PAIR_SIZE, requirements.keyvals_size);
+
+      bvh_states[i].bvh_ir = pInfos[i].scratchData.deviceAddress + bvh_ir_offset;
+   }
+
    build_leaves(commandBuffer, infoCount, pInfos, ppBuildRangeInfos, bvh_states, flush_bits);
 
    morton_generate(commandBuffer, infoCount, pInfos, bvh_states, flush_bits);
@@ -674,6 +755,10 @@ radv_CmdBuildAccelerationStructuresKHR(
 
    lbvh_build_internal(commandBuffer, infoCount, pInfos, bvh_states, flush_bits);
 
+   convert_leaf_nodes(commandBuffer, infoCount, pInfos, bvh_states);
+
+   convert_internal_nodes(commandBuffer, infoCount, pInfos, bvh_states);
+
    for (uint32_t i = 0; i < infoCount; ++i) {
       RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
                        pInfos[i].dstAccelerationStructure);
@@ -685,9 +770,14 @@ radv_CmdBuildAccelerationStructuresKHR(
       uint64_t geometry_infos_size =
          pInfos[i].geometryCount * sizeof(struct radv_accel_struct_geometry_info);
 
-      header.instance_offset = bvh_states[i].leaf_node_offset;
+      header.instance_offset =
+         align(sizeof(struct radv_accel_struct_header), 64) + sizeof(struct radv_bvh_box32_node);
       header.instance_count = is_tlas ? bvh_states[i].leaf_node_count : 0;
-      header.compacted_size = bvh_states[i].node_offset + geometry_infos_size;
+      header.compacted_size =
+         align(sizeof(struct radv_accel_struct_header), 64) +
+         bvh_states[i].leaf_node_count * bvh_states[i].leaf_node_size +
+         bvh_states[i].internal_node_count * sizeof(struct radv_bvh_box32_node) +
+         geometry_infos_size;
 
       header.copy_dispatch_size[0] = DIV_ROUND_UP(header.compacted_size, 16 * 64);
       header.copy_dispatch_size[1] = 1;
@@ -722,11 +812,13 @@ radv_CmdBuildAccelerationStructuresKHR(
                             radv_buffer_get_va(accel_struct->bo) + accel_struct->mem_offset + base,
                             (const char *)&header + base, sizeof(header) - base);
 
+      VkDeviceSize geometry_infos_offset = header.compacted_size - geometry_infos_size;
+
       struct radv_buffer accel_struct_buffer;
       radv_buffer_init(&accel_struct_buffer, cmd_buffer->device, accel_struct->bo,
                        accel_struct->size, accel_struct->mem_offset);
       radv_CmdUpdateBuffer(commandBuffer, radv_buffer_to_handle(&accel_struct_buffer),
-                           bvh_states[i].node_offset, geometry_infos_size, geometry_infos);
+                           geometry_infos_offset, geometry_infos_size, geometry_infos);
       radv_buffer_finish(&accel_struct_buffer);
 
       free(geometry_infos);