nir: calculate number of vertices in nir_create_passthrough_gs
authorantonino <antonino.maniscalco@collabora.com>
Fri, 3 Feb 2023 10:23:09 +0000 (11:23 +0100)
committerMarge Bot <emma+marge@anholt.net>
Wed, 29 Mar 2023 19:18:40 +0000 (19:18 +0000)
`nir_create_passthrough_gs` has been changed to take the type of primitive
as opposed to the number of vertices as an argument.

Acked-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Reviewed-by: Erik Faye-Lund <erik.faye-lund@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21238>

src/compiler/nir/nir.h
src/compiler/nir/nir_passthrough_gs.c
src/gallium/drivers/zink/zink_program.c

index 9228513..978f5a9 100644 (file)
@@ -5086,8 +5086,7 @@ nir_shader * nir_create_passthrough_tcs(const nir_shader_compiler_options *optio
                                         const nir_shader *vs, uint8_t patch_vertices);
 nir_shader * nir_create_passthrough_gs(const nir_shader_compiler_options *options,
                                        const nir_shader *prev_stage,
-                                       enum shader_prim primitive_type,
-                                       unsigned vertices);
+                                       enum shader_prim primitive_type);
 
 bool nir_lower_fragcolor(nir_shader *shader, unsigned max_cbufs);
 bool nir_lower_fragcoord_wtrans(nir_shader *shader);
index 82ff849..b9921e4 100644 (file)
 #include "nir.h"
 #include "nir_builder.h"
 
+static unsigned int
+gs_in_prim_for_topology(enum shader_prim prim)
+{
+   switch (prim) {
+   case SHADER_PRIM_QUADS:
+      return SHADER_PRIM_LINES_ADJACENCY;
+   default:
+      return prim;
+   }
+}
+
+static enum shader_prim
+gs_out_prim_for_topology(enum shader_prim prim)
+{
+   switch (prim) {
+   case SHADER_PRIM_POINTS:
+      return SHADER_PRIM_POINTS;
+   case SHADER_PRIM_LINES:
+   case SHADER_PRIM_LINE_LOOP:
+   case SHADER_PRIM_LINES_ADJACENCY:
+   case SHADER_PRIM_LINE_STRIP_ADJACENCY:
+   case SHADER_PRIM_LINE_STRIP:
+      return SHADER_PRIM_LINE_STRIP;
+   case SHADER_PRIM_TRIANGLES:
+   case SHADER_PRIM_TRIANGLE_STRIP:
+   case SHADER_PRIM_TRIANGLE_FAN:
+   case SHADER_PRIM_TRIANGLES_ADJACENCY:
+   case SHADER_PRIM_TRIANGLE_STRIP_ADJACENCY:
+   case SHADER_PRIM_POLYGON:
+      return SHADER_PRIM_TRIANGLE_STRIP;
+   case SHADER_PRIM_QUADS:
+   case SHADER_PRIM_QUAD_STRIP:
+   case SHADER_PRIM_PATCHES:
+   default:
+      return SHADER_PRIM_QUADS;
+   }
+}
+
+static unsigned int
+vertices_for_prim(enum shader_prim prim)
+{
+   switch (prim) {
+   case SHADER_PRIM_POINTS:
+      return 1;
+   case SHADER_PRIM_LINES:
+   case SHADER_PRIM_LINE_LOOP:
+   case SHADER_PRIM_LINES_ADJACENCY:
+   case SHADER_PRIM_LINE_STRIP_ADJACENCY:
+   case SHADER_PRIM_LINE_STRIP:
+      return 2;
+   case SHADER_PRIM_TRIANGLES:
+   case SHADER_PRIM_TRIANGLE_STRIP:
+   case SHADER_PRIM_TRIANGLE_FAN:
+   case SHADER_PRIM_TRIANGLES_ADJACENCY:
+   case SHADER_PRIM_TRIANGLE_STRIP_ADJACENCY:
+   case SHADER_PRIM_POLYGON:
+      return 3;
+   case SHADER_PRIM_QUADS:
+   case SHADER_PRIM_QUAD_STRIP:
+      return 4;
+   case SHADER_PRIM_PATCHES:
+   default:
+      unreachable("unsupported primitive for gs input");
+   }
+}
+
+static unsigned int
+array_size_for_prim(enum shader_prim prim)
+{
+   switch (prim) {
+   case SHADER_PRIM_POINTS:
+      return 1;
+   case SHADER_PRIM_LINES:
+   case SHADER_PRIM_LINE_LOOP:
+   case SHADER_PRIM_LINE_STRIP:
+      return 2;
+   case SHADER_PRIM_LINES_ADJACENCY:
+   case SHADER_PRIM_LINE_STRIP_ADJACENCY:
+      return 4;
+   case SHADER_PRIM_TRIANGLES:
+   case SHADER_PRIM_TRIANGLE_STRIP:
+   case SHADER_PRIM_TRIANGLE_FAN:
+   case SHADER_PRIM_POLYGON:
+      return 3;
+   case SHADER_PRIM_TRIANGLES_ADJACENCY:
+   case SHADER_PRIM_TRIANGLE_STRIP_ADJACENCY:
+      return 6;
+   case SHADER_PRIM_QUADS:
+   case SHADER_PRIM_QUAD_STRIP:
+      return 4;
+   case SHADER_PRIM_PATCHES:
+   default:
+      unreachable("unsupported primitive for gs input");
+   }
+}
+
 /*
  * A helper to create a passthrough GS shader for drivers that needs to lower
  * some rendering tasks to the GS.
 nir_shader *
 nir_create_passthrough_gs(const nir_shader_compiler_options *options,
                           const nir_shader *prev_stage,
-                          enum shader_prim primitive_type,
-                          unsigned vertices)
+                          enum shader_prim primitive_type)
 {
+   unsigned int vertices_out = vertices_for_prim(primitive_type);
    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_GEOMETRY,
                                                   options,
                                                   "gs passthrough");
 
    nir_shader *nir = b.shader;
-   nir->info.gs.input_primitive = primitive_type;
-   nir->info.gs.output_primitive = primitive_type;
-   nir->info.gs.vertices_in = vertices;
-   nir->info.gs.vertices_out = vertices;
+   nir->info.gs.input_primitive = gs_in_prim_for_topology(primitive_type);
+   nir->info.gs.output_primitive = gs_out_prim_for_topology(primitive_type);
+   nir->info.gs.vertices_in = vertices_out;
+   nir->info.gs.vertices_out = vertices_out;
    nir->info.gs.invocations = 1;
    nir->info.gs.active_stream_mask = 1;
 
@@ -63,7 +159,7 @@ nir_create_passthrough_gs(const nir_shader_compiler_options *options,
 
       nir_variable *in = nir_variable_create(nir, nir_var_shader_in,
                                              glsl_array_type(var->type,
-                                                             vertices,
+                                                             array_size_for_prim(primitive_type),
                                                              false),
                                              name);
       in->data.location = var->data.location;
@@ -98,7 +194,7 @@ nir_create_passthrough_gs(const nir_shader_compiler_options *options,
    }
 
    unsigned int start_vert = 0;
-   unsigned int end_vert = vertices;
+   unsigned int end_vert = vertices_out;
    unsigned int vert_step = 1;
    switch (primitive_type) {
    case PIPE_PRIM_LINES_ADJACENCY:
index 9d4ae72..1037b94 100644 (file)
@@ -2219,8 +2219,7 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx)
             nir_shader *nir = nir_create_passthrough_gs(
                &screen->nir_options,
                ctx->gfx_stages[prev_vertex_stage]->nir,
-               (lower_line_stipple || lower_line_smooth) ? SHADER_PRIM_LINE_STRIP :  SHADER_PRIM_POINTS,
-               (lower_line_stipple || lower_line_smooth) ? 2 : 1);
+               (lower_line_stipple || lower_line_smooth) ? SHADER_PRIM_LINE_STRIP :  SHADER_PRIM_POINTS);
 
             struct zink_shader *shader = zink_shader_create(screen, nir, NULL);
             ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode] = shader;