zink: unbind generated gs in `bind_last_vertex_stage`
authorantonino <antonino.maniscalco@collabora.com>
Thu, 9 Mar 2023 15:24:54 +0000 (16:24 +0100)
committerMarge Bot <emma+marge@anholt.net>
Wed, 29 Mar 2023 19:18:40 +0000 (19:18 +0000)
Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Reviewed-by: Erik Faye-Lund <erik.faye-lund@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21238>

src/gallium/drivers/zink/zink_program.c

index f6f6f8f..4901f13 100644 (file)
@@ -1571,24 +1571,6 @@ zink_get_compute_pipeline(struct zink_screen *screen,
 }
 
 static void
-bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader);
-
-static void
-unbind_generated_gs(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader)
-{
-   if (ctx->gfx_stages[stage]->non_fs.is_generated)
-      ctx->inlinable_uniforms_valid_mask &= ~BITFIELD_BIT(MESA_SHADER_GEOMETRY);
-
-   if (ctx->gfx_stages[MESA_SHADER_GEOMETRY] &&
-       ctx->gfx_stages[MESA_SHADER_GEOMETRY]->non_fs.parent ==
-       ctx->gfx_stages[stage]) {
-      ctx->base.bind_gs_state(&ctx->base, NULL);
-      ctx->is_generated_gs_bound = false;
-      ctx->inlinable_uniforms_valid_mask &= ~BITFIELD_BIT(MESA_SHADER_GEOMETRY);
-   }
-}
-
-static void
 bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader)
 {
    if (shader && shader->nir->info.num_inlinable_uniforms)
@@ -1596,11 +1578,12 @@ bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shad
    else
       ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage);
 
-   if (ctx->gfx_stages[stage]) {
+   if (ctx->gfx_stages[stage])
       ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash;
 
-      if (stage != MESA_SHADER_FRAGMENT)
-         unbind_generated_gs(ctx, stage, shader);
+   if (!shader && stage == MESA_SHADER_GEOMETRY) {
+      ctx->inlinable_uniforms_valid_mask &= ~BITFIELD64_BIT(MESA_SHADER_GEOMETRY);
+      ctx->is_generated_gs_bound = false;
    }
 
    ctx->gfx_stages[stage] = shader;
@@ -1669,8 +1652,24 @@ update_rast_prim(struct zink_shader *shader)
 }
 
 static void
-bind_last_vertex_stage(struct zink_context *ctx)
+unbind_generated_gs(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *prev_shader)
+{
+   if (prev_shader->non_fs.is_generated)
+      ctx->inlinable_uniforms_valid_mask &= ~BITFIELD64_BIT(MESA_SHADER_GEOMETRY);
+
+   if (ctx->gfx_stages[MESA_SHADER_GEOMETRY] &&
+       ctx->gfx_stages[MESA_SHADER_GEOMETRY]->non_fs.parent ==
+       prev_shader) {
+      bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, NULL);
+   }
+}
+
+static void
+bind_last_vertex_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *prev_shader)
 {
+   if (prev_shader && stage < MESA_SHADER_GEOMETRY)
+      unbind_generated_gs(ctx, stage, prev_shader);
+
    gl_shader_stage old = ctx->last_vertex_stage ? ctx->last_vertex_stage->nir->info.stage : MESA_SHADER_STAGES;
    if (ctx->gfx_stages[MESA_SHADER_GEOMETRY])
       ctx->last_vertex_stage = ctx->gfx_stages[MESA_SHADER_GEOMETRY];
@@ -1724,8 +1723,9 @@ zink_bind_vs_state(struct pipe_context *pctx,
    struct zink_context *ctx = zink_context(pctx);
    if (!cso && !ctx->gfx_stages[MESA_SHADER_VERTEX])
       return;
+   struct zink_shader *prev_shader = ctx->gfx_stages[MESA_SHADER_VERTEX];
    bind_gfx_stage(ctx, MESA_SHADER_VERTEX, cso);
-   bind_last_vertex_stage(ctx);
+   bind_last_vertex_stage(ctx, MESA_SHADER_VERTEX, prev_shader);
    if (cso) {
       struct zink_shader *zs = cso;
       ctx->shader_reads_drawid = BITSET_TEST(zs->nir->info.system_values_read, SYSTEM_VALUE_DRAW_ID);
@@ -1802,7 +1802,7 @@ zink_bind_gs_state(struct pipe_context *pctx,
    if (!cso && !ctx->gfx_stages[MESA_SHADER_GEOMETRY])
       return;
    bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, cso);
-   bind_last_vertex_stage(ctx);
+   bind_last_vertex_stage(ctx, MESA_SHADER_GEOMETRY, NULL);
 }
 
 static void
@@ -1826,8 +1826,9 @@ zink_bind_tes_state(struct pipe_context *pctx,
             ctx->gfx_stages[MESA_SHADER_TESS_CTRL] = NULL;
       }
    }
+   struct zink_shader *prev_shader = ctx->gfx_stages[MESA_SHADER_TESS_EVAL];
    bind_gfx_stage(ctx, MESA_SHADER_TESS_EVAL, cso);
-   bind_last_vertex_stage(ctx);
+   bind_last_vertex_stage(ctx, MESA_SHADER_TESS_EVAL, prev_shader);
 }
 
 static void *