From e9a5da2f4bdfd4ad4ee09bc3f6c9640e4acada13 Mon Sep 17 00:00:00 2001
From: Mike Blumenkrantz
Date: Fri, 21 Jul 2023 12:53:49 -0400
Subject: [PATCH] nir: add a filter cb to lower_io_to_scalar
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

This is useful for drivers that want to do selective scalarization of
I/O: when a filter callback is supplied, an intrinsic is scalarized only
if the callback returns true for it, and passing NULL preserves the
previous lower-everything behavior.

Reviewed-by: Timur Kristóf
Part-of:
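As an illustration, a driver could restrict scalarization to 64-bit
memory access along these lines (a hypothetical sketch, not part of this
patch; the name scalarize_only_64bit and the 64-bit heuristic are
invented for the example):

   /* The pass only invokes the filter on intrinsic instructions whose
    * kind and variable mode already matched, so the cast below is safe.
    * Scalarize only 64-bit SSBO/global/shared access and leave 32-bit
    * vector access intact. */
   static bool
   scalarize_only_64bit(const nir_instr *instr, const void *data)
   {
      const nir_intrinsic_instr *intr =
         nir_instr_as_intrinsic((nir_instr *)instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_ssbo:
      case nir_intrinsic_load_global:
      case nir_intrinsic_load_shared:
         return nir_dest_bit_size(intr->dest) == 64;
      case nir_intrinsic_store_ssbo:
      case nir_intrinsic_store_global:
      case nir_intrinsic_store_shared:
         return nir_src_bit_size(intr->src[0]) == 64;
      default:
         return true;
      }
   }

   NIR_PASS_V(nir, nir_lower_io_to_scalar,
              nir_var_mem_ssbo | nir_var_mem_global | nir_var_mem_shared,
              scalarize_only_64bit, NULL);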
---
 src/amd/vulkan/radv_pipeline.c               |  2 +-
 src/asahi/compiler/agx_compile.c             |  2 +-
 src/broadcom/compiler/vir.c                  |  8 +++---
 src/compiler/nir/nir.h                       |  2 +-
 src/compiler/nir/nir_lower_io_to_scalar.c    | 42 +++++++++++++++++++---------
 src/freedreno/ir3/ir3_nir.c                  |  4 +--
 src/gallium/drivers/lima/lima_program.c      |  2 +-
 src/gallium/drivers/radeonsi/si_shader.c     |  2 +-
 src/gallium/drivers/radeonsi/si_shader_nir.c |  2 +-
 src/gallium/drivers/vc4/vc4_program.c        |  6 ++--
 src/gallium/drivers/zink/zink_compiler.c     |  6 ++--
 src/imagination/rogue/rogue_nir.c            |  6 ++--
 src/microsoft/compiler/nir_to_dxil.c         |  2 +-
 13 files changed, 51 insertions(+), 35 deletions(-)

diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index 0e145be..26a89e4 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -697,7 +697,7 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_pipeline_key
    });
 
    if (radv_use_llvm_for_stage(device, stage->stage))
-      NIR_PASS_V(stage->nir, nir_lower_io_to_scalar, nir_var_mem_global);
+      NIR_PASS_V(stage->nir, nir_lower_io_to_scalar, nir_var_mem_global, NULL, NULL);
 
    NIR_PASS(_, stage->nir, ac_nir_lower_global_access);
    NIR_PASS_V(stage->nir, ac_nir_lower_intrinsics_to_args, gfx_level,
               radv_select_hw_stage(&stage->info, gfx_level),
diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index c680555..83d830b 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -2685,7 +2685,7 @@ agx_compile_shader_nir(nir_shader *nir, struct agx_shader_key *key,
     * transform feedback programs will use vector output.
     */
    if (nir->info.stage == MESA_SHADER_VERTEX)
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
    out->push_count = key->reserved_preamble;
    agx_optimize_nir(nir, &out->push_count);
diff --git a/src/broadcom/compiler/vir.c b/src/broadcom/compiler/vir.c
index dba3271..36ba386 100644
--- a/src/broadcom/compiler/vir.c
+++ b/src/broadcom/compiler/vir.c
@@ -1040,7 +1040,7 @@ v3d_nir_lower_gs_late(struct v3d_compile *c)
         }
 
         /* Note: GS output scalarizing must happen after nir_lower_clip_gs. */
-        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }
 
 static void
@@ -1050,11 +1050,11 @@ v3d_nir_lower_vs_late(struct v3d_compile *c)
                 NIR_PASS(_, c->s, nir_lower_clip_vs, c->key->ucp_enables,
                          false, false, NULL);
                 NIR_PASS_V(c->s, nir_lower_io_to_scalar,
-                           nir_var_shader_out);
+                           nir_var_shader_out, NULL, NULL);
         }
 
         /* Note: VS output scalarizing must happen after nir_lower_clip_vs. */
-        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }
 
 static void
@@ -1070,7 +1070,7 @@ v3d_nir_lower_fs_late(struct v3d_compile *c)
         if (c->key->ucp_enables)
                 NIR_PASS(_, c->s, nir_lower_clip_fs, c->key->ucp_enables, true);
 
-        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in);
+        NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
 }
 
 static uint32_t
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 15997fc..f027545 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -5242,7 +5242,7 @@ bool nir_lower_phis_to_scalar(nir_shader *shader, bool lower_all);
 void nir_lower_io_arrays_to_elements(nir_shader *producer, nir_shader *consumer);
 void nir_lower_io_arrays_to_elements_no_indirects(nir_shader *shader,
                                                   bool outputs_only);
-bool nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask);
+bool nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask, nir_instr_filter_cb filter, void *filter_data);
 bool nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask);
 bool nir_lower_io_to_vector(nir_shader *shader, nir_variable_mode mask);
 bool nir_vectorize_tess_levels(nir_shader *shader);
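For context, nir_instr_filter_cb is not new here: it is the instruction
filter type nir.h already uses for passes such as nir_lower_alu_to_scalar
(note the existing NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL)
call in the rogue hunk below), declared along the lines of:

   typedef bool (*nir_instr_filter_cb)(const nir_instr *instr,
                                       const void *data);

Reusing the same callback shape keeps the filtering convention uniform
across NIR's scalarization passes.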
diff --git a/src/compiler/nir/nir_lower_io_to_scalar.c b/src/compiler/nir/nir_lower_io_to_scalar.c
index a777e73..b36cfcf 100644
--- a/src/compiler/nir/nir_lower_io_to_scalar.c
+++ b/src/compiler/nir/nir_lower_io_to_scalar.c
@@ -231,10 +231,16 @@ lower_store_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
    nir_instr_remove(&intr->instr);
 }
 
+struct scalarize_state {
+   nir_variable_mode mask;
+   nir_instr_filter_cb filter;
+   void *filter_data;
+};
+
 static bool
 nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
 {
-   nir_variable_mode mask = *(nir_variable_mode *)data;
+   struct scalarize_state *state = data;
 
    if (instr->type != nir_instr_type_intrinsic)
       return false;
@@ -247,36 +253,41 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
    if ((intr->intrinsic == nir_intrinsic_load_input ||
         intr->intrinsic == nir_intrinsic_load_per_vertex_input ||
         intr->intrinsic == nir_intrinsic_load_interpolated_input) &&
-       (mask & nir_var_shader_in)) {
+       (state->mask & nir_var_shader_in) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_input_to_scalar(b, intr);
       return true;
    }
 
    if ((intr->intrinsic == nir_intrinsic_load_output ||
        intr->intrinsic == nir_intrinsic_load_per_vertex_output) &&
-       (mask & nir_var_shader_out)) {
+       (state->mask & nir_var_shader_out) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_input_to_scalar(b, intr);
       return true;
    }
 
-   if ((intr->intrinsic == nir_intrinsic_load_ubo && (mask & nir_var_mem_ubo)) ||
-       (intr->intrinsic == nir_intrinsic_load_ssbo && (mask & nir_var_mem_ssbo)) ||
-       (intr->intrinsic == nir_intrinsic_load_global && (mask & nir_var_mem_global)) ||
-       (intr->intrinsic == nir_intrinsic_load_shared && (mask & nir_var_mem_shared))) {
+   if (((intr->intrinsic == nir_intrinsic_load_ubo && (state->mask & nir_var_mem_ubo)) ||
+        (intr->intrinsic == nir_intrinsic_load_ssbo && (state->mask & nir_var_mem_ssbo)) ||
+        (intr->intrinsic == nir_intrinsic_load_global && (state->mask & nir_var_mem_global)) ||
+        (intr->intrinsic == nir_intrinsic_load_shared && (state->mask & nir_var_mem_shared))) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_to_scalar(b, intr);
       return true;
    }
 
    if ((intr->intrinsic == nir_intrinsic_store_output ||
        intr->intrinsic == nir_intrinsic_store_per_vertex_output) &&
-       mask & nir_var_shader_out) {
+       state->mask & nir_var_shader_out &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_store_output_to_scalar(b, intr);
       return true;
    }
 
-   if ((intr->intrinsic == nir_intrinsic_store_ssbo && (mask & nir_var_mem_ssbo)) ||
-       (intr->intrinsic == nir_intrinsic_store_global && (mask & nir_var_mem_global)) ||
-       (intr->intrinsic == nir_intrinsic_store_shared && (mask & nir_var_mem_shared))) {
+   if (((intr->intrinsic == nir_intrinsic_store_ssbo && (state->mask & nir_var_mem_ssbo)) ||
+        (intr->intrinsic == nir_intrinsic_store_global && (state->mask & nir_var_mem_global)) ||
+        (intr->intrinsic == nir_intrinsic_store_shared && (state->mask & nir_var_mem_shared))) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_store_to_scalar(b, intr);
       return true;
    }
@@ -285,13 +296,18 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
 }
 
 bool
-nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
+nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask, nir_instr_filter_cb filter, void *filter_data)
 {
+   struct scalarize_state state = {
+      mask,
+      filter,
+      filter_data
+   };
    return nir_shader_instructions_pass(shader,
                                        nir_lower_io_to_scalar_instr,
                                        nir_metadata_block_index |
                                        nir_metadata_dominance,
-                                       &mask);
+                                       &state);
 }
 
 static nir_variable **
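Note the pattern in each branch above: the filter is consulted only after
the intrinsic kind and the variable-mode mask have matched, and a NULL
filter short-circuits to true. Existing callers that pass NULL therefore
get exactly the old behavior, while instructions rejected by a filter are
simply left in vector form. The remaining hunks are the mechanical caller
updates.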
diff --git a/src/freedreno/ir3/ir3_nir.c b/src/freedreno/ir3/ir3_nir.c
index 60b31aa..d2193e1 100644
--- a/src/freedreno/ir3/ir3_nir.c
+++ b/src/freedreno/ir3/ir3_nir.c
@@ -647,7 +647,7 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so, nir_shader *s)
 
    bool progress = false;
 
-   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_mem_ssbo);
+   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_mem_ssbo, NULL, NULL);
 
    if (so->key.has_gs || so->key.tessellation) {
       switch (so->type) {
@@ -658,7 +658,7 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so, nir_shader *s)
          break;
       case MESA_SHADER_TESS_CTRL:
          NIR_PASS_V(s, nir_lower_io_to_scalar,
-                    nir_var_shader_in | nir_var_shader_out);
+                    nir_var_shader_in | nir_var_shader_out, NULL, NULL);
          NIR_PASS_V(s, ir3_nir_lower_tess_ctrl, so, so->key.tessellation);
          NIR_PASS_V(s, ir3_nir_lower_to_explicit_input, so);
          progress = true;
diff --git a/src/gallium/drivers/lima/lima_program.c b/src/gallium/drivers/lima/lima_program.c
index 5768aa1..ef5e3a2 100644
--- a/src/gallium/drivers/lima/lima_program.c
+++ b/src/gallium/drivers/lima/lima_program.c
@@ -122,7 +122,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
    NIR_PASS_V(s, nir_lower_load_const_to_scalar);
    NIR_PASS_V(s, lima_nir_lower_uniform_to_scalar);
    NIR_PASS_V(s, nir_lower_io_to_scalar,
-              nir_var_shader_in|nir_var_shader_out);
+              nir_var_shader_in|nir_var_shader_out, NULL, NULL);
 
    do {
       progress = false;
diff --git a/src/gallium/drivers/radeonsi/si_shader.c b/src/gallium/drivers/radeonsi/si_shader.c
index f284b50..8951a59 100644
--- a/src/gallium/drivers/radeonsi/si_shader.c
+++ b/src/gallium/drivers/radeonsi/si_shader.c
@@ -1774,7 +1774,7 @@ static void si_lower_ngg(struct si_shader *shader, nir_shader *nir)
    NIR_PASS_V(nir, nir_lower_subgroups, &si_nir_subgroups_options);
 
    /* may generate some vector output store */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }
 
 struct nir_shader *si_deserialize_shader(struct si_shader_selector *sel)
diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c
index ce41c59..70fa288 100644
--- a/src/gallium/drivers/radeonsi/si_shader_nir.c
+++ b/src/gallium/drivers/radeonsi/si_shader_nir.c
@@ -305,7 +305,7 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir)
    if (nir->info.stage == MESA_SHADER_VERTEX ||
        nir->info.stage == MESA_SHADER_TESS_EVAL ||
        nir->info.stage == MESA_SHADER_GEOMETRY)
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
    if (nir->info.stage == MESA_SHADER_GEOMETRY) {
       unsigned flags = nir_lower_gs_intrinsics_per_stream;
diff --git a/src/gallium/drivers/vc4/vc4_program.c b/src/gallium/drivers/vc4/vc4_program.c
index af74b2c..dab38c6 100644
--- a/src/gallium/drivers/vc4/vc4_program.c
+++ b/src/gallium/drivers/vc4/vc4_program.c
@@ -2293,7 +2293,7 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
                 NIR_PASS_V(c->s, nir_lower_clip_vs, c->key->ucp_enables,
                            false, false, NULL);
                 NIR_PASS_V(c->s, nir_lower_io_to_scalar,
-                           nir_var_shader_out);
+                           nir_var_shader_out, NULL, NULL);
         }
 }
 
@@ -2302,9 +2302,9 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
          * scalarizing must happen after nir_lower_clip_vs.
          */
         if (c->stage == QSTAGE_FRAG)
-                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in);
+                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
         else
-                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+                NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
         NIR_PASS_V(c->s, vc4_nir_lower_io, c);
         NIR_PASS_V(c->s, vc4_nir_lower_txf_ms, c);
diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index 08c21f9..636a071 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -3706,7 +3706,7 @@ zink_shader_compile(struct zink_screen *screen, bool can_shobj, struct zink_shad
       }
    }
    if (screen->driconf.inline_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, zs);
       need_optimize = true;
@@ -3761,7 +3761,7 @@ zink_shader_compile_separate(struct zink_screen *screen, struct zink_shader *zs)
       }
    }
    if (screen->driconf.inline_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, zs);
    }
@@ -4913,7 +4913,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    NIR_PASS_V(nir, unbreak_bos, ret, needs_size);
    /* run in compile if there could be inlined uniforms */
    if (!screen->driconf.inline_uniforms && !nir->info.num_inlinable_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, ret);
    }
diff --git a/src/imagination/rogue/rogue_nir.c b/src/imagination/rogue/rogue_nir.c
index 277fe2c..ecbf454 100644
--- a/src/imagination/rogue/rogue_nir.c
+++ b/src/imagination/rogue/rogue_nir.c
@@ -89,7 +89,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
 
    /* Load inputs to scalars (single registers later). */
    /* TODO: Fitrp can process multiple frag inputs at once, scalarise I/O. */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
 
    /* Optimize GL access qualifiers. */
    const nir_opt_access_options opt_access_options = {
@@ -102,7 +102,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
    NIR_PASS_V(nir, rogue_nir_pfo);
 
    /* Load outputs to scalars (single registers later). */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 
    /* Lower ALU operations to scalars. */
    NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
@@ -115,7 +115,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
               nir_lower_explicit_io,
               nir_var_mem_ubo,
               spirv_options.ubo_addr_format);
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo, NULL, NULL);
    NIR_PASS_V(nir, rogue_nir_lower_io);
 
    /* Algebraic opts. */
diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c
index ca41b21..68ccd7d 100644
--- a/src/microsoft/compiler/nir_to_dxil.c
+++ b/src/microsoft/compiler/nir_to_dxil.c
@@ -6518,7 +6518,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
    NIR_PASS_V(s, nir_lower_io, nir_var_shader_in | nir_var_shader_out,
               type_size_vec4, nir_lower_io_lower_64bit_to_32);
    NIR_PASS_V(s, dxil_nir_ensure_position_writes);
    NIR_PASS_V(s, dxil_nir_lower_system_values);
-   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_system_value | nir_var_shader_out);
+   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_system_value | nir_var_shader_out, NULL, NULL);
 
    /* Do a round of optimization to try to vectorize loads/stores. Otherwise the addresses used for loads
    * might be too opaque for the pass to see that they're next to each other. */
-- 
2.7.4