nir/zink: handle provoking vertex mode in `nir_create_passthrough_gs`
author antonino <antonino.maniscalco@collabora.com>
Mon, 20 Feb 2023 18:49:25 +0000 (19:49 +0100)
committer Marge Bot <emma+marge@anholt.net>
Wed, 29 Mar 2023 19:18:40 +0000 (19:18 +0000)
Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Reviewed-by: Erik Faye-Lund <erik.faye-lund@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21238>

src/compiler/nir/nir.h
src/compiler/nir/nir_passthrough_gs.c
src/gallium/drivers/zink/zink_program.c

index b2eede8..d0c0175 100644 (file)
@@ -5088,6 +5088,7 @@ nir_shader * nir_create_passthrough_gs(const nir_shader_compiler_options *option
                                        const nir_shader *prev_stage,
                                        enum shader_prim primitive_type,
                                        int flat_interp_mask_offset,
+                                       int last_pv_vert_offset,
                                        bool emulate_edgeflags,
                                        bool force_line_strip_out);
 
index 6ffdfe0..fbe2ed6 100644 (file)
@@ -130,6 +130,7 @@ nir_create_passthrough_gs(const nir_shader_compiler_options *options,
                           const nir_shader *prev_stage,
                           enum shader_prim primitive_type,
                           int flat_interp_mask_offset,
+                          int last_pv_vert_offset,
                           bool emulate_edgeflags,
                           bool force_line_strip_out)
 {
@@ -221,6 +222,13 @@ nir_create_passthrough_gs(const nir_shader_compiler_options *options,
    nir_ssa_def *flat_interp_mask_def = nir_load_ubo(&b, 1, 32,
                                                     nir_imm_int(&b, 0), nir_imm_int(&b, flat_interp_mask_offset),
                                                     .align_mul = 4, .align_offset = 0, .range_base = 0, .range = ~0);
+   nir_ssa_def *last_pv_vert_def = nir_load_ubo(&b, 1, 32,
+                                                nir_imm_int(&b, 0), nir_imm_int(&b, last_pv_vert_offset),
+                                                .align_mul = 4, .align_offset = 0, .range_base = 0, .range = ~0);
+   last_pv_vert_def = nir_ine_imm(&b, last_pv_vert_def, 0);
+   nir_ssa_def *start_vert_index = nir_imm_int(&b, start_vert);
+   nir_ssa_def *end_vert_index = nir_imm_int(&b, end_vert - 1);
+   nir_ssa_def *pv_vert_index = nir_bcsel(&b, last_pv_vert_def, end_vert_index, start_vert_index);
    for (unsigned i = start_vert; i < end_vert || needs_closing; i += vert_step) {
       int idx = i < end_vert ? i : start_vert;
       /* Copy inputs to outputs. */
@@ -234,7 +242,7 @@ nir_create_passthrough_gs(const nir_shader_compiler_options *options,
             index = nir_imm_int(&b, idx);
          else {
             unsigned mask = 1u << (of++);
-            index = nir_bcsel(&b, nir_ieq_imm(&b, nir_iand_imm(&b, flat_interp_mask_def, mask), 0), nir_imm_int(&b, idx), nir_imm_int(&b, start_vert));
+            index = nir_bcsel(&b, nir_ieq_imm(&b, nir_iand_imm(&b, flat_interp_mask_def, mask), 0), nir_imm_int(&b, idx), pv_vert_index);
          }
          nir_ssa_def *value = nir_load_array_var(&b, in_vars[j], index);
          nir_store_var(&b, out_vars[oj], value,
index b11a01b..4e8cd89 100644 (file)
@@ -2315,11 +2315,13 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx)
                   ctx->gfx_stages[prev_vertex_stage]->nir,
                   ctx->gfx_pipeline_state.gfx_prim_mode,
                   ZINK_INLINE_VAL_FLAT_MASK * sizeof(uint32_t),
+                  ZINK_INLINE_VAL_PV_LAST_VERT * sizeof(uint32_t),
                   lower_edge_flags,
                   lower_line_stipple || lower_quad_prim);
             }
 
             zink_add_inline_uniform(nir, ZINK_INLINE_VAL_FLAT_MASK);
+            zink_add_inline_uniform(nir, ZINK_INLINE_VAL_PV_LAST_VERT);
             struct zink_shader *shader = zink_shader_create(screen, nir, &ctx->gfx_stages[prev_vertex_stage]->sinfo.so_info);
             shader->needs_inlining = true;
             ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode][zink_prim_type] = shader;
@@ -2332,8 +2334,9 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx)
          ctx->is_generated_gs_bound = true;
       }
 
-      ctx->base.set_inlinable_constants(&ctx->base, MESA_SHADER_GEOMETRY, 1,
-                                        (uint32_t []){zink_flat_flags(ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir)});
+      ctx->base.set_inlinable_constants(&ctx->base, MESA_SHADER_GEOMETRY, 2,
+                                        (uint32_t []){zink_flat_flags(ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir),
+                                                      ctx->gfx_pipeline_state.dyn_state3.pv_last});
    } else if (ctx->gfx_stages[MESA_SHADER_GEOMETRY] &&
               ctx->gfx_stages[MESA_SHADER_GEOMETRY]->non_fs.is_generated)
          bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, NULL);
@@ -2375,11 +2378,13 @@ zink_create_primitive_emulation_gs(struct zink_context *ctx)
                   ctx->gfx_stages[prev_vertex_stage]->nir,
                   ctx->gfx_pipeline_state.gfx_prim_mode,
                   ZINK_INLINE_VAL_FLAT_MASK * 4,
+                  ZINK_INLINE_VAL_PV_LAST_VERT * 4,
                   lower_edge_flags,
                   lower_quad_prim);
             }
 
             zink_add_inline_uniform(nir, ZINK_INLINE_VAL_FLAT_MASK);
+            zink_add_inline_uniform(nir, ZINK_INLINE_VAL_PV_LAST_VERT);
             struct zink_shader *shader = zink_shader_create(screen, nir, &ctx->gfx_stages[prev_vertex_stage]->sinfo.so_info);
             shader->needs_inlining = true;
             ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode][zink_prim_type] = shader;
@@ -2392,8 +2397,9 @@ zink_create_primitive_emulation_gs(struct zink_context *ctx)
          ctx->is_generated_gs_bound = true;
       }
 
-      ctx->base.set_inlinable_constants(&ctx->base, MESA_SHADER_GEOMETRY, 1,
-                                        (uint32_t []){zink_flat_flags(ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir)});
+      ctx->base.set_inlinable_constants(&ctx->base, MESA_SHADER_GEOMETRY, 2,
+                                        (uint32_t []){zink_flat_flags(ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir),
+                                                      ctx->gfx_pipeline_state.dyn_state3.pv_last});
    } else if (ctx->gfx_stages[MESA_SHADER_GEOMETRY] &&
               ctx->gfx_stages[MESA_SHADER_GEOMETRY]->non_fs.is_generated)
          bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, NULL);