zink: store nir as serialized on zink_shader structs
author Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Tue, 28 Mar 2023 22:52:31 +0000 (18:52 -0400)
committer Marge Bot <emma+marge@anholt.net>
Tue, 4 Apr 2023 01:37:41 +0000 (01:37 +0000)
nir_shader objects are hefty, and they really add up when there are a lot
of them. There's also not much use in keeping them around: any time
they're used, they're always cloned first, and deserializing isn't
likely to be any slower than a clone.

Cuts driver memory usage by ~40% for Tomb Raider.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22266>

src/gallium/drivers/zink/zink_compiler.c
src/gallium/drivers/zink/zink_compiler.h
src/gallium/drivers/zink/zink_program.c
src/gallium/drivers/zink/zink_types.h

index ae4b25a..442cbb2 100644 (file)
@@ -35,6 +35,7 @@
 #include "nir_xfb_info.h"
 #include "nir/nir_draw_helpers.h"
 #include "compiler/nir/nir_builder.h"
+#include "compiler/nir/nir_serialize.h"
 #include "compiler/nir/nir_builtin_builder.h"
 
 #include "nir/tgsi_to_nir.h"
@@ -3615,7 +3616,7 @@ zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs,
 VkShaderModule
 zink_shader_compile_separate(struct zink_screen *screen, struct zink_shader *zs)
 {
-   nir_shader *nir = nir_shader_clone(NULL, zs->nir);
+   nir_shader *nir = zink_shader_deserialize(screen, zs);
    int set = nir->info.stage == MESA_SHADER_FRAGMENT;
    unsigned offsets[4];
    zink_descriptor_shader_get_binding_offsets(zs, offsets);
@@ -4892,7 +4893,6 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    if (nir->info.stage != MESA_SHADER_KERNEL)
       NIR_PASS_V(nir, match_tex_dests, ret);
 
-   ret->nir = nir;
    if (!nir->info.internal)
       nir_foreach_shader_out_variable(var, nir)
          var->data.explicit_xfb_buffer = 0;
@@ -4915,6 +4915,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
          NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_shader_temp, NULL);
       }
    }
+   zink_shader_serialize_blob(nir, &ret->blob);
    memcpy(&ret->info, &nir->info, sizeof(nir->info));
 
    ret->can_inline = true;
@@ -5039,7 +5040,7 @@ zink_shader_free(struct zink_screen *screen, struct zink_shader *shader)
       VKSCR(DestroyShaderModule)(screen->dev, shader->precompile.mod, NULL);
    if (shader->precompile.gpl)
       VKSCR(DestroyPipeline)(screen->dev, shader->precompile.gpl, NULL);
-   ralloc_free(shader->nir);
+   blob_finish(&shader->blob);
    ralloc_free(shader->spirv);
    free(shader->precompile.bindings);
    ralloc_free(shader);
@@ -5156,8 +5157,8 @@ zink_shader_tcs_create(struct zink_screen *screen, nir_shader *vs, unsigned vert
    NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL);
    NIR_PASS_V(nir, nir_convert_from_ssa, true);
 
-   ret->nir = nir;
    *nir_ret = nir;
+   zink_shader_serialize_blob(nir, &ret->blob);
    memcpy(&ret->info, &nir->info, sizeof(nir->info));
    ret->non_fs.is_generated = true;
    return ret;
@@ -5173,3 +5174,23 @@ zink_shader_has_cubes(nir_shader *nir)
    }
    return false;
 }
+/*
+ * Rebuild a nir_shader from the serialized copy kept in zs->blob.
+ *
+ * A blob_reader is set up over zs->blob.data / zs->blob.size and handed to
+ * nir_deserialize() together with the screen's nir_options. Each call
+ * returns the result of a new nir_deserialize() run; the stored blob is only
+ * read here.
+ *
+ * The callers visible in this change either keep the result (prog->nir[i])
+ * or release it with ralloc_free() once they are done with it.
+ *
+ * NOTE(review): the NULL first argument to nir_deserialize() is presumably
+ * the ralloc parent context (i.e. no parent) -- confirm against
+ * nir_serialize.h.
+ */
+nir_shader *
+zink_shader_deserialize(struct zink_screen *screen, struct zink_shader *zs)
+{
+   struct blob_reader blob;
+   blob_reader_init(&blob, zs->blob.data, zs->blob.size); /* reader over the stored bytes */
+   return nir_deserialize(NULL, &screen->nir_options, &blob);
+}
+/*
+ * Serialize nir into blob.
+ *
+ * blob is initialized here with blob_init(), so the caller passes storage
+ * rather than an already-initialized blob. For the copy stored on a
+ * zink_shader, zink_shader_free() releases it with blob_finish().
+ *
+ * The last argument passed to nir_serialize() is the strip flag:
+ *  - builds without NDEBUG: strip unless ZINK_DEBUG_NIR, ZINK_DEBUG_SPIRV
+ *    or ZINK_DEBUG_TGSI is set in zink_debug
+ *  - builds with NDEBUG: never strip
+ * Presumably stripping drops names/debug info that is only useful when
+ * dumping shaders -- confirm against nir_serialize.h.
+ *
+ * NOTE(review): NDEBUG builds always keeping unstripped blobs looks at odds
+ * with the memory-saving goal of storing serialized NIR; confirm this is
+ * intended.
+ */
+void
+zink_shader_serialize_blob(nir_shader *nir, struct blob *blob)
+{
+   blob_init(blob);
+#ifndef NDEBUG
+   bool strip = !(zink_debug & (ZINK_DEBUG_NIR | ZINK_DEBUG_SPIRV | ZINK_DEBUG_TGSI)); /* keep debug info when shader dumps are requested */
+#else
+   bool strip = false;
+#endif
+   nir_serialize(blob, nir, strip);
+}
index 0464677..574db1b 100644 (file)
@@ -96,4 +96,8 @@ zink_shader_descriptor_is_buffer(struct zink_shader *zs, enum zink_descriptor_ty
 
 bool
 zink_shader_has_cubes(nir_shader *nir);
+nir_shader *
+zink_shader_deserialize(struct zink_screen *screen, struct zink_shader *zs);
+void
+zink_shader_serialize_blob(nir_shader *nir, struct blob *blob);
 #endif
index 87c78f3..3c2237c 100644 (file)
@@ -1041,15 +1041,13 @@ zink_create_gfx_program(struct zink_context *ctx,
          prog->stages_present |= BITFIELD_BIT(i);
          prog->optimal_keys &= !prog->shaders[i]->non_fs.is_generated;
          prog->needs_inlining |= prog->shaders[i]->needs_inlining;
-         prog->nir[i] = nir_shader_clone(NULL, stages[i]->nir);
+         prog->nir[i] = zink_shader_deserialize(screen, stages[i]);
       }
    }
    if (stages[MESA_SHADER_TESS_EVAL] && !stages[MESA_SHADER_TESS_CTRL]) {
-      nir_shader *nir;
       prog->shaders[MESA_SHADER_TESS_EVAL]->non_fs.generated_tcs =
       prog->shaders[MESA_SHADER_TESS_CTRL] =
-        zink_shader_tcs_create(screen, prog->nir[MESA_SHADER_VERTEX], vertices_per_patch, &nir);
-      prog->nir[MESA_SHADER_TESS_CTRL] = nir_shader_clone(NULL, nir);
+        zink_shader_tcs_create(screen, prog->nir[MESA_SHADER_VERTEX], vertices_per_patch, &prog->nir[MESA_SHADER_TESS_CTRL]);
       prog->stages_present |= BITFIELD_BIT(MESA_SHADER_TESS_CTRL);
    }
    prog->stages_remaining = prog->stages_present;
@@ -1259,20 +1257,15 @@ precompile_compute_job(void *data, void *gdata, int thread_index)
    comp->shader = zink_shader_create(screen, comp->nir, NULL);
    comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module);
    assert(comp->module);
-   comp->module->shader = zink_shader_compile(screen, comp->shader, comp->shader->nir, NULL, NULL);
+   comp->module->shader = zink_shader_compile(screen, comp->shader, comp->nir, NULL, NULL);
    assert(comp->module->shader);
    util_dynarray_init(&comp->shader_cache[0], comp);
    util_dynarray_init(&comp->shader_cache[1], comp);
 
-   struct blob blob = {0};
-   blob_init(&blob);
-   nir_serialize(&blob, comp->shader->nir, true);
-
    struct mesa_sha1 sha1_ctx;
    _mesa_sha1_init(&sha1_ctx);
-   _mesa_sha1_update(&sha1_ctx, blob.data, blob.size);
+   _mesa_sha1_update(&sha1_ctx, comp->shader->blob.data, comp->shader->blob.size);
    _mesa_sha1_final(&sha1_ctx, comp->base.sha1);
-   blob_finish(&blob);
 
    zink_descriptor_program_init(comp->base.ctx, &comp->base);
 
@@ -1483,7 +1476,7 @@ zink_destroy_compute_program(struct zink_screen *screen,
    assert(!comp->shader->spirv);
 
    _mesa_set_destroy(comp->shader->programs, NULL);
-   ralloc_free(comp->shader->nir);
+   ralloc_free(comp->nir);
    ralloc_free(comp->shader);
 
    destroy_shader_cache(screen, &comp->shader_cache[0]);
@@ -1897,7 +1890,9 @@ zink_create_gfx_shader_state(struct pipe_context *pctx, const struct pipe_shader
    if (nir->info.uses_bindless)
       zink_descriptors_init_bindless(zink_context(pctx));
 
-   return zink_shader_create(zink_screen(pctx->screen), nir, &shader->stream_output);
+   void *ret = zink_shader_create(zink_screen(pctx->screen), nir, &shader->stream_output);
+   ralloc_free(nir);
+   return ret;
 }
 
 static void
@@ -2305,16 +2300,17 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx)
           (ctx->gfx_stages[MESA_SHADER_GEOMETRY]->info.gs.input_primitive != ctx->gfx_pipeline_state.gfx_prim_mode)) {
 
          if (!ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode][zink_prim_type]) {
+            nir_shader *prev_stage = zink_shader_deserialize(screen, ctx->gfx_stages[prev_vertex_stage]);
             nir_shader *nir;
             if (lower_filled_quad) {
                nir = zink_create_quads_emulation_gs(
                   &screen->nir_options,
-                  ctx->gfx_stages[prev_vertex_stage]->nir,
+                  prev_stage,
                   ZINK_INLINE_VAL_PV_LAST_VERT * 4);
             } else {
                nir = nir_create_passthrough_gs(
                   &screen->nir_options,
-                  ctx->gfx_stages[prev_vertex_stage]->nir,
+                  prev_stage,
                   ctx->gfx_pipeline_state.gfx_prim_mode,
                   ZINK_INLINE_VAL_FLAT_MASK * sizeof(uint32_t),
                   ZINK_INLINE_VAL_PV_LAST_VERT * sizeof(uint32_t),
@@ -2324,6 +2320,7 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx)
 
             zink_add_inline_uniform(nir, ZINK_INLINE_VAL_FLAT_MASK);
             zink_add_inline_uniform(nir, ZINK_INLINE_VAL_PV_LAST_VERT);
+            ralloc_free(prev_stage);
             struct zink_shader *shader = zink_shader_create(screen, nir, &ctx->gfx_stages[prev_vertex_stage]->sinfo.so_info);
             shader->needs_inlining = true;
             ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode][zink_prim_type] = shader;
index 4057f77..1e90537 100644 (file)
@@ -733,7 +733,7 @@ enum zink_rast_prim {
 struct zink_shader {
    struct util_live_shader base;
    uint32_t hash;
-   struct nir_shader *nir;
+   struct blob blob;
    struct shader_info info;
 
    struct zink_shader_info sinfo;