nir: add a filter cb to lower_io_to_scalar
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Fri, 21 Jul 2023 16:53:49 +0000 (12:53 -0400)
committerMarge Bot <emma+marge@anholt.net>
Fri, 11 Aug 2023 09:02:53 +0000 (09:02 +0000)
this is useful for drivers that want to do selective scalarization
of io

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24565>

13 files changed:
src/amd/vulkan/radv_pipeline.c
src/asahi/compiler/agx_compile.c
src/broadcom/compiler/vir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_io_to_scalar.c
src/freedreno/ir3/ir3_nir.c
src/gallium/drivers/lima/lima_program.c
src/gallium/drivers/radeonsi/si_shader.c
src/gallium/drivers/radeonsi/si_shader_nir.c
src/gallium/drivers/vc4/vc4_program.c
src/gallium/drivers/zink/zink_compiler.c
src/imagination/rogue/rogue_nir.c
src/microsoft/compiler/nir_to_dxil.c

index 0e145be..26a89e4 100644 (file)
@@ -697,7 +697,7 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_pipeline_key
             });
 
    if (radv_use_llvm_for_stage(device, stage->stage))
-      NIR_PASS_V(stage->nir, nir_lower_io_to_scalar, nir_var_mem_global);
+      NIR_PASS_V(stage->nir, nir_lower_io_to_scalar, nir_var_mem_global, NULL, NULL);
 
    NIR_PASS(_, stage->nir, ac_nir_lower_global_access);
    NIR_PASS_V(stage->nir, ac_nir_lower_intrinsics_to_args, gfx_level, radv_select_hw_stage(&stage->info, gfx_level),
index c680555..83d830b 100644 (file)
@@ -2685,7 +2685,7 @@ agx_compile_shader_nir(nir_shader *nir, struct agx_shader_key *key,
     * transform feedback programs will use vector output.
     */
    if (nir->info.stage == MESA_SHADER_VERTEX)
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
    out->push_count = key->reserved_preamble;
    agx_optimize_nir(nir, &out->push_count);
index dba3271..36ba386 100644 (file)
@@ -1040,7 +1040,7 @@ v3d_nir_lower_gs_late(struct v3d_compile *c)
         }
 
         /* Note: GS output scalarizing must happen after nir_lower_clip_gs. */
-        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }
 
 static void
@@ -1050,11 +1050,11 @@ v3d_nir_lower_vs_late(struct v3d_compile *c)
                 NIR_PASS(_, c->s, nir_lower_clip_vs, c->key->ucp_enables,
                          false, false, NULL);
                 NIR_PASS_V(c->s, nir_lower_io_to_scalar,
-                           nir_var_shader_out);
+                           nir_var_shader_out, NULL, NULL);
         }
 
         /* Note: VS output scalarizing must happen after nir_lower_clip_vs. */
-        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }
 
 static void
@@ -1070,7 +1070,7 @@ v3d_nir_lower_fs_late(struct v3d_compile *c)
         if (c->key->ucp_enables)
                 NIR_PASS(_, c->s, nir_lower_clip_fs, c->key->ucp_enables, true);
 
-        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in);
+        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
 }
 
 static uint32_t
index 15997fc..f027545 100644 (file)
@@ -5242,7 +5242,7 @@ bool nir_lower_phis_to_scalar(nir_shader *shader, bool lower_all);
 void nir_lower_io_arrays_to_elements(nir_shader *producer, nir_shader *consumer);
 void nir_lower_io_arrays_to_elements_no_indirects(nir_shader *shader,
                                                   bool outputs_only);
-bool nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask);
+bool nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask, nir_instr_filter_cb filter, void *filter_data);
 bool nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask);
 bool nir_lower_io_to_vector(nir_shader *shader, nir_variable_mode mask);
 bool nir_vectorize_tess_levels(nir_shader *shader);
index a777e73..b36cfcf 100644 (file)
@@ -231,10 +231,16 @@ lower_store_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
    nir_instr_remove(&intr->instr);
 }
 
+struct scalarize_state {
+   nir_variable_mode mask;
+   nir_instr_filter_cb filter;
+   void *filter_data;
+};
+
 static bool
 nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
 {
-   nir_variable_mode mask = *(nir_variable_mode *)data;
+   struct scalarize_state *state = data;
 
    if (instr->type != nir_instr_type_intrinsic)
       return false;
@@ -247,36 +253,41 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
    if ((intr->intrinsic == nir_intrinsic_load_input ||
         intr->intrinsic == nir_intrinsic_load_per_vertex_input ||
         intr->intrinsic == nir_intrinsic_load_interpolated_input) &&
-       (mask & nir_var_shader_in)) {
+       (state->mask & nir_var_shader_in) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_input_to_scalar(b, intr);
       return true;
    }
 
    if ((intr->intrinsic == nir_intrinsic_load_output ||
         intr->intrinsic == nir_intrinsic_load_per_vertex_output) &&
-      (mask & nir_var_shader_out)) {
+      (state->mask & nir_var_shader_out) &&
+      (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_input_to_scalar(b, intr);
       return true;
    }
 
-   if ((intr->intrinsic == nir_intrinsic_load_ubo && (mask & nir_var_mem_ubo)) ||
-       (intr->intrinsic == nir_intrinsic_load_ssbo && (mask & nir_var_mem_ssbo)) ||
-       (intr->intrinsic == nir_intrinsic_load_global && (mask & nir_var_mem_global)) ||
-       (intr->intrinsic == nir_intrinsic_load_shared && (mask & nir_var_mem_shared))) {
+   if (((intr->intrinsic == nir_intrinsic_load_ubo && (state->mask & nir_var_mem_ubo)) ||
+        (intr->intrinsic == nir_intrinsic_load_ssbo && (state->mask & nir_var_mem_ssbo)) ||
+        (intr->intrinsic == nir_intrinsic_load_global && (state->mask & nir_var_mem_global)) ||
+        (intr->intrinsic == nir_intrinsic_load_shared && (state->mask & nir_var_mem_shared))) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_to_scalar(b, intr);
       return true;
    }
 
    if ((intr->intrinsic == nir_intrinsic_store_output ||
         intr->intrinsic == nir_intrinsic_store_per_vertex_output) &&
-       mask & nir_var_shader_out) {
+       state->mask & nir_var_shader_out &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_store_output_to_scalar(b, intr);
       return true;
    }
 
-   if ((intr->intrinsic == nir_intrinsic_store_ssbo && (mask & nir_var_mem_ssbo)) ||
-       (intr->intrinsic == nir_intrinsic_store_global && (mask & nir_var_mem_global)) ||
-       (intr->intrinsic == nir_intrinsic_store_shared && (mask & nir_var_mem_shared))) {
+   if (((intr->intrinsic == nir_intrinsic_store_ssbo && (state->mask & nir_var_mem_ssbo)) ||
+        (intr->intrinsic == nir_intrinsic_store_global && (state->mask & nir_var_mem_global)) ||
+        (intr->intrinsic == nir_intrinsic_store_shared && (state->mask & nir_var_mem_shared))) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_store_to_scalar(b, intr);
       return true;
    }
@@ -285,13 +296,18 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
 }
 
 bool
-nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
+nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask, nir_instr_filter_cb filter, void *filter_data)
 {
+   struct scalarize_state state = {
+      mask,
+      filter,
+      filter_data
+   };
    return nir_shader_instructions_pass(shader,
                                        nir_lower_io_to_scalar_instr,
                                        nir_metadata_block_index |
                                        nir_metadata_dominance,
-                                       &mask);
+                                       &state);
 }
 
 static nir_variable **
index 60b31aa..d2193e1 100644 (file)
@@ -647,7 +647,7 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so, nir_shader *s)
 
    bool progress = false;
 
-   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_mem_ssbo);
+   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_mem_ssbo, NULL, NULL);
 
    if (so->key.has_gs || so->key.tessellation) {
       switch (so->type) {
@@ -658,7 +658,7 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so, nir_shader *s)
          break;
       case MESA_SHADER_TESS_CTRL:
          NIR_PASS_V(s, nir_lower_io_to_scalar,
-                     nir_var_shader_in | nir_var_shader_out);
+                     nir_var_shader_in | nir_var_shader_out, NULL, NULL);
          NIR_PASS_V(s, ir3_nir_lower_tess_ctrl, so, so->key.tessellation);
          NIR_PASS_V(s, ir3_nir_lower_to_explicit_input, so);
          progress = true;
index 5768aa1..ef5e3a2 100644 (file)
@@ -122,7 +122,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
    NIR_PASS_V(s, nir_lower_load_const_to_scalar);
    NIR_PASS_V(s, lima_nir_lower_uniform_to_scalar);
    NIR_PASS_V(s, nir_lower_io_to_scalar,
-              nir_var_shader_in|nir_var_shader_out);
+              nir_var_shader_in|nir_var_shader_out, NULL, NULL);
 
    do {
       progress = false;
index f284b50..8951a59 100644 (file)
@@ -1774,7 +1774,7 @@ static void si_lower_ngg(struct si_shader *shader, nir_shader *nir)
    NIR_PASS_V(nir, nir_lower_subgroups, &si_nir_subgroups_options);
 
    /* may generate some vector output store */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }
 
 struct nir_shader *si_deserialize_shader(struct si_shader_selector *sel)
index ce41c59..70fa288 100644 (file)
@@ -305,7 +305,7 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir)
    if (nir->info.stage == MESA_SHADER_VERTEX ||
        nir->info.stage == MESA_SHADER_TESS_EVAL ||
        nir->info.stage == MESA_SHADER_GEOMETRY)
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
    if (nir->info.stage == MESA_SHADER_GEOMETRY) {
       unsigned flags = nir_lower_gs_intrinsics_per_stream;
index af74b2c..dab38c6 100644 (file)
@@ -2293,7 +2293,7 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
                         NIR_PASS_V(c->s, nir_lower_clip_vs,
                                    c->key->ucp_enables, false, false, NULL);
                         NIR_PASS_V(c->s, nir_lower_io_to_scalar,
-                                   nir_var_shader_out);
+                                   nir_var_shader_out, NULL, NULL);
                 }
         }
 
@@ -2302,9 +2302,9 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
          * scalarizing must happen after nir_lower_clip_vs.
          */
         if (c->stage == QSTAGE_FRAG)
-                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in);
+                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
         else
-                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
         NIR_PASS_V(c->s, vc4_nir_lower_io, c);
         NIR_PASS_V(c->s, vc4_nir_lower_txf_ms, c);
index 08c21f9..636a071 100644 (file)
@@ -3706,7 +3706,7 @@ zink_shader_compile(struct zink_screen *screen, bool can_shobj, struct zink_shad
       }
    }
    if (screen->driconf.inline_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, zs);
       need_optimize = true;
@@ -3761,7 +3761,7 @@ zink_shader_compile_separate(struct zink_screen *screen, struct zink_shader *zs)
       }
    }
    if (screen->driconf.inline_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, zs);
    }
@@ -4913,7 +4913,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    NIR_PASS_V(nir, unbreak_bos, ret, needs_size);
    /* run in compile if there could be inlined uniforms */
    if (!screen->driconf.inline_uniforms && !nir->info.num_inlinable_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, ret);
    }
index 277fe2c..ecbf454 100644 (file)
@@ -89,7 +89,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
 
    /* Load inputs to scalars (single registers later). */
    /* TODO: Fitrp can process multiple frag inputs at once, scalarise I/O. */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
 
    /* Optimize GL access qualifiers. */
    const nir_opt_access_options opt_access_options = {
@@ -102,7 +102,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
       NIR_PASS_V(nir, rogue_nir_pfo);
 
    /* Load outputs to scalars (single registers later). */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
    /* Lower ALU operations to scalars. */
    NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
@@ -115,7 +115,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
               nir_lower_explicit_io,
               nir_var_mem_ubo,
               spirv_options.ubo_addr_format);
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo, NULL, NULL);
    NIR_PASS_V(nir, rogue_nir_lower_io);
 
    /* Algebraic opts. */
index ca41b21..68ccd7d 100644 (file)
@@ -6518,7 +6518,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
    NIR_PASS_V(s, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4, nir_lower_io_lower_64bit_to_32);
    NIR_PASS_V(s, dxil_nir_ensure_position_writes);
    NIR_PASS_V(s, dxil_nir_lower_system_values);
-   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_system_value | nir_var_shader_out);
+   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_system_value | nir_var_shader_out, NULL, NULL);
 
    /* Do a round of optimization to try to vectorize loads/stores. Otherwise the addresses used for loads
     * might be too opaque for the pass to see that they're next to each other. */