From 1e0e4657f97cbf9ce080b4aa0fe01bc83eb8aa56 Mon Sep 17 00:00:00 2001 From: Lionel Landwerlin Date: Thu, 1 Dec 2022 17:09:22 +0200 Subject: [PATCH] spirv/nir: wire ray interection triangle position fetch Signed-off-by: Lionel Landwerlin Reviewed-by: Caio Oliveira Part-of: --- src/compiler/nir/nir.c | 4 +++ src/compiler/nir/nir.h | 1 + src/compiler/nir/nir_divergence_analysis.c | 1 + src/compiler/nir/nir_gather_info.c | 1 + src/compiler/nir/nir_intrinsics.py | 1 + src/compiler/nir/nir_lower_system_values.c | 14 +++++++++ src/compiler/shader_enums.c | 1 + src/compiler/shader_enums.h | 1 + src/compiler/shader_info.h | 1 + src/compiler/spirv/spirv_to_nir.c | 47 ++++++++++++++++++------------ src/compiler/spirv/vtn_variables.c | 4 +++ 11 files changed, 57 insertions(+), 19 deletions(-) diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c index a67abd4..6ba22c7 100644 --- a/src/compiler/nir/nir.c +++ b/src/compiler/nir/nir.c @@ -2486,6 +2486,8 @@ nir_intrinsic_from_system_value(gl_system_value val) return nir_intrinsic_load_ray_instance_custom_index; case SYSTEM_VALUE_CULL_MASK: return nir_intrinsic_load_cull_mask; + case SYSTEM_VALUE_RAY_TRIANGLE_VERTEX_POSITIONS: + return nir_intrinsic_load_ray_triangle_vertex_positions; case SYSTEM_VALUE_MESH_VIEW_COUNT: return nir_intrinsic_load_mesh_view_count; case SYSTEM_VALUE_FRAG_SHADING_RATE: @@ -2639,6 +2641,8 @@ nir_system_value_from_intrinsic(nir_intrinsic_op intrin) return SYSTEM_VALUE_RAY_INSTANCE_CUSTOM_INDEX; case nir_intrinsic_load_cull_mask: return SYSTEM_VALUE_CULL_MASK; + case nir_intrinsic_load_ray_triangle_vertex_positions: + return SYSTEM_VALUE_RAY_TRIANGLE_VERTEX_POSITIONS; case nir_intrinsic_load_frag_shading_rate: return SYSTEM_VALUE_FRAG_SHADING_RATE; case nir_intrinsic_load_mesh_view_count: diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index b8c1ee7..327da71 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -252,6 +252,7 @@ typedef enum { nir_ray_query_value_flags, nir_ray_query_value_world_ray_direction, nir_ray_query_value_world_ray_origin, + nir_ray_query_value_intersection_triangle_vertex_positions } nir_ray_query_value; typedef union { diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c index 0d9452e..ebc7ae0 100644 --- a/src/compiler/nir/nir_divergence_analysis.c +++ b/src/compiler/nir/nir_divergence_analysis.c @@ -697,6 +697,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr) case nir_intrinsic_report_ray_intersection: case nir_intrinsic_rq_proceed: case nir_intrinsic_rq_load: + case nir_intrinsic_load_ray_triangle_vertex_positions: is_divergent = true; break; diff --git a/src/compiler/nir/nir_gather_info.c b/src/compiler/nir/nir_gather_info.c index d170466..23d7724 100644 --- a/src/compiler/nir/nir_gather_info.c +++ b/src/compiler/nir/nir_gather_info.c @@ -758,6 +758,7 @@ gather_intrinsic_info(nir_intrinsic_instr *instr, nir_shader *shader, case nir_intrinsic_load_mesh_view_count: case nir_intrinsic_load_gs_header_ir3: case nir_intrinsic_load_tcs_header_ir3: + case nir_intrinsic_load_ray_triangle_vertex_positions: BITSET_SET(shader->info.system_values_read, nir_system_value_from_intrinsic(instr->intrinsic)); break; diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 5a52503..1b0da45 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -893,6 +893,7 @@ system_value("ray_geometry_index", 1) system_value("ray_instance_custom_index", 1) system_value("shader_record_ptr", 1, bit_sizes=[64]) system_value("cull_mask", 1) +system_value("ray_triangle_vertex_positions", 3, indices=[COLUMN]) # Driver-specific viewport scale/offset parameters. # diff --git a/src/compiler/nir/nir_lower_system_values.c b/src/compiler/nir/nir_lower_system_values.c index b2e570c..6357ce9 100644 --- a/src/compiler/nir/nir_lower_system_values.c +++ b/src/compiler/nir/nir_lower_system_values.c @@ -148,6 +148,7 @@ lower_system_value_instr(nir_builder *b, nir_instr *instr, void *_state) case SYSTEM_VALUE_RAY_OBJECT_TO_WORLD: case SYSTEM_VALUE_RAY_WORLD_TO_OBJECT: case SYSTEM_VALUE_MESH_VIEW_INDICES: + case SYSTEM_VALUE_RAY_TRIANGLE_VERTEX_POSITIONS: /* These are all single-element arrays in our implementation, and * the sysval load below just drops the 0 array index. */ @@ -250,6 +251,19 @@ lower_system_value_instr(nir_builder *b, nir_instr *instr, void *_state) assert(cols[i]->num_components == num_rows); } return nir_select_from_ssa_def_array(b, cols, num_cols, column); + } else if (glsl_type_is_array(var->type)) { + unsigned num_elems = glsl_get_length(var->type); + const struct glsl_type *elem_type = glsl_get_array_element(var->type); + assert(glsl_get_components(elem_type) == intrin->dest.ssa.num_components); + + nir_ssa_def *elems[4]; + assert(ARRAY_SIZE(elems) >= num_elems); + for (unsigned i = 0; i < num_elems; i++) { + elems[i] = nir_load_system_value(b, sysval_op, i, + intrin->dest.ssa.num_components, + intrin->dest.ssa.bit_size); + } + return nir_select_from_ssa_def_array(b, elems, num_elems, column); } else { return nir_load_system_value(b, sysval_op, 0, intrin->dest.ssa.num_components, diff --git a/src/compiler/shader_enums.c b/src/compiler/shader_enums.c index d144318..112196c 100644 --- a/src/compiler/shader_enums.c +++ b/src/compiler/shader_enums.c @@ -380,6 +380,7 @@ gl_system_value_name(gl_system_value sysval) ENUM(SYSTEM_VALUE_RAY_FLAGS), ENUM(SYSTEM_VALUE_RAY_GEOMETRY_INDEX), ENUM(SYSTEM_VALUE_CULL_MASK), + ENUM(SYSTEM_VALUE_RAY_TRIANGLE_VERTEX_POSITIONS), ENUM(SYSTEM_VALUE_MESH_VIEW_COUNT), ENUM(SYSTEM_VALUE_MESH_VIEW_INDICES), ENUM(SYSTEM_VALUE_GS_HEADER_IR3), diff --git a/src/compiler/shader_enums.h b/src/compiler/shader_enums.h index a798678..993f2d7 100644 --- a/src/compiler/shader_enums.h +++ b/src/compiler/shader_enums.h @@ -841,6 +841,7 @@ typedef enum SYSTEM_VALUE_RAY_GEOMETRY_INDEX, SYSTEM_VALUE_RAY_INSTANCE_CUSTOM_INDEX, SYSTEM_VALUE_CULL_MASK, + SYSTEM_VALUE_RAY_TRIANGLE_VERTEX_POSITIONS, /*@}*/ /** diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h index c511b40..35655bd 100644 --- a/src/compiler/shader_info.h +++ b/src/compiler/shader_info.h @@ -97,6 +97,7 @@ struct spirv_supported_capabilities { bool ray_query; bool ray_tracing; bool ray_traversal_primitive_culling; + bool ray_tracing_position_fetch; bool runtime_descriptor_array; bool shader_clock; bool shader_viewport_index_layer; diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 3dbb77f..19886bf 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4908,6 +4908,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, spv_check_supported(fragment_density, cap); break; + case SpvCapabilityRayTracingPositionFetchKHR: + case SpvCapabilityRayQueryPositionFetchKHR: + spv_check_supported(ray_tracing_position_fetch, cap); + break; + default: vtn_fail("Unhandled capability: %s (%u)", spirv_capability_to_string(cap), cap); @@ -5771,24 +5776,26 @@ spirv_to_nir_type_ray_query_intrinsic(struct vtn_builder *b, switch (opcode) { #define CASE(_spv, _nir, _type) case SpvOpRayQueryGet##_spv: \ return (struct ray_query_value) { .nir_value = nir_ray_query_value_##_nir, .glsl_type = _type } - CASE(RayTMinKHR, tmin, glsl_floatN_t_type(32)); - CASE(RayFlagsKHR, flags, glsl_uint_type()); - CASE(WorldRayDirectionKHR, world_ray_direction, glsl_vec_type(3)); - CASE(WorldRayOriginKHR, world_ray_origin, glsl_vec_type(3)); - CASE(IntersectionTypeKHR, intersection_type, glsl_uint_type()); - CASE(IntersectionTKHR, intersection_t, glsl_floatN_t_type(32)); - CASE(IntersectionInstanceCustomIndexKHR, intersection_instance_custom_index, glsl_int_type()); - CASE(IntersectionInstanceIdKHR, intersection_instance_id, glsl_int_type()); - CASE(IntersectionInstanceShaderBindingTableRecordOffsetKHR, intersection_instance_sbt_index, glsl_uint_type()); - CASE(IntersectionGeometryIndexKHR, intersection_geometry_index, glsl_int_type()); - CASE(IntersectionPrimitiveIndexKHR, intersection_primitive_index, glsl_int_type()); - CASE(IntersectionBarycentricsKHR, intersection_barycentrics, glsl_vec_type(2)); - CASE(IntersectionFrontFaceKHR, intersection_front_face, glsl_bool_type()); - CASE(IntersectionCandidateAABBOpaqueKHR, intersection_candidate_aabb_opaque, glsl_bool_type()); - CASE(IntersectionObjectToWorldKHR, intersection_object_to_world, glsl_matrix_type(glsl_get_base_type(glsl_float_type()), 3, 4)); - CASE(IntersectionWorldToObjectKHR, intersection_world_to_object, glsl_matrix_type(glsl_get_base_type(glsl_float_type()), 3, 4)); - CASE(IntersectionObjectRayOriginKHR, intersection_object_ray_origin, glsl_vec_type(3)); - CASE(IntersectionObjectRayDirectionKHR, intersection_object_ray_direction, glsl_vec_type(3)); + CASE(RayTMinKHR, tmin, glsl_floatN_t_type(32)); + CASE(RayFlagsKHR, flags, glsl_uint_type()); + CASE(WorldRayDirectionKHR, world_ray_direction, glsl_vec_type(3)); + CASE(WorldRayOriginKHR, world_ray_origin, glsl_vec_type(3)); + CASE(IntersectionTypeKHR, intersection_type, glsl_uint_type()); + CASE(IntersectionTKHR, intersection_t, glsl_floatN_t_type(32)); + CASE(IntersectionInstanceCustomIndexKHR, intersection_instance_custom_index, glsl_int_type()); + CASE(IntersectionInstanceIdKHR, intersection_instance_id, glsl_int_type()); + CASE(IntersectionInstanceShaderBindingTableRecordOffsetKHR, intersection_instance_sbt_index, glsl_uint_type()); + CASE(IntersectionGeometryIndexKHR, intersection_geometry_index, glsl_int_type()); + CASE(IntersectionPrimitiveIndexKHR, intersection_primitive_index, glsl_int_type()); + CASE(IntersectionBarycentricsKHR, intersection_barycentrics, glsl_vec_type(2)); + CASE(IntersectionFrontFaceKHR, intersection_front_face, glsl_bool_type()); + CASE(IntersectionCandidateAABBOpaqueKHR, intersection_candidate_aabb_opaque, glsl_bool_type()); + CASE(IntersectionObjectToWorldKHR, intersection_object_to_world, glsl_matrix_type(glsl_get_base_type(glsl_float_type()), 3, 4)); + CASE(IntersectionWorldToObjectKHR, intersection_world_to_object, glsl_matrix_type(glsl_get_base_type(glsl_float_type()), 3, 4)); + CASE(IntersectionObjectRayOriginKHR, intersection_object_ray_origin, glsl_vec_type(3)); + CASE(IntersectionObjectRayDirectionKHR, intersection_object_ray_direction, glsl_vec_type(3)); + CASE(IntersectionTriangleVertexPositionsKHR, intersection_triangle_vertex_positions, glsl_array_type(glsl_vec_type(3), 3, + glsl_get_explicit_stride(glsl_vec_type(3)))); #undef CASE default: vtn_fail_with_opcode("Unhandled opcode", opcode); @@ -5803,7 +5810,7 @@ ray_query_load_intrinsic_create(struct vtn_builder *b, SpvOp opcode, struct ray_query_value value = spirv_to_nir_type_ray_query_intrinsic(b, opcode); - if (glsl_type_is_matrix(value.glsl_type)) { + if (glsl_type_is_array_or_matrix(value.glsl_type)) { const struct glsl_type *elem_type = glsl_get_array_element(value.glsl_type); const unsigned elems = glsl_get_length(value.glsl_type); @@ -5879,6 +5886,7 @@ vtn_handle_ray_query_intrinsic(struct vtn_builder *b, SpvOp opcode, case SpvOpRayQueryGetIntersectionObjectRayOriginKHR: case SpvOpRayQueryGetIntersectionObjectToWorldKHR: case SpvOpRayQueryGetIntersectionWorldToObjectKHR: + case SpvOpRayQueryGetIntersectionTriangleVertexPositionsKHR: ray_query_load_intrinsic_create(b, opcode, w, vtn_ssa_value(b, w[3])->def, nir_i2b(&b->nb, vtn_ssa_value(b, w[4])->def)); @@ -6353,6 +6361,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpRayQueryGetWorldRayOriginKHR: case SpvOpRayQueryGetIntersectionObjectToWorldKHR: case SpvOpRayQueryGetIntersectionWorldToObjectKHR: + case SpvOpRayQueryGetIntersectionTriangleVertexPositionsKHR: vtn_handle_ray_query_intrinsic(b, opcode, w, count); break; diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index 2a7ebcf..b0384a4 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -1182,6 +1182,10 @@ vtn_get_builtin_location(struct vtn_builder *b, *location = SYSTEM_VALUE_FRAG_INVOCATION_COUNT; set_mode_system_value(b, mode); break; + case SpvBuiltInHitTriangleVertexPositionsKHR: + *location = SYSTEM_VALUE_RAY_TRIANGLE_VERTEX_POSITIONS; + set_mode_system_value(b, mode); + break; default: vtn_fail("Unsupported builtin: %s (%u)", -- 2.7.4