nir/lower_shader_calls: add ability to force remat of instructions
author    Lionel Landwerlin <lionel.g.landwerlin@intel.com>
          Thu, 19 Jan 2023 09:54:10 +0000 (11:54 +0200)
committer Marge Bot <emma+marge@anholt.net>
          Tue, 30 May 2023 06:36:36 +0000 (06:36 +0000)
Some instructions we would like to keep around because they carry
additional information in their indices.

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21645>

src/compiler/nir/nir.h
src/compiler/nir/nir_lower_shader_calls.c
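
To make the motivation concrete, here is a sketch of the kind of callback a
driver could pass. The choice of nir_intrinsic_vulkan_resource_index as the
instruction worth keeping is an illustrative assumption, not something this
commit prescribes:

static bool
should_remat_cb(nir_instr *instr, void *data)
{
   /* Illustrative: descriptor-index intrinsics carry const_index data
    * (set, binding, ...) that a spill/fill of the SSA value would lose,
    * so ask the pass to rematerialize the instruction instead.
    */
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   return nir_instr_as_intrinsic(instr)->intrinsic ==
          nir_intrinsic_vulkan_resource_index;
}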

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 877fb1f..0bb04ab 100644
@@ -5092,6 +5092,8 @@ typedef struct {
 
 bool nir_opt_load_store_vectorize(nir_shader *shader, const nir_load_store_vectorize_options *options);
 
+typedef bool (*nir_lower_shader_calls_should_remat_func)(nir_instr *instr, void *data);
+
 typedef struct nir_lower_shader_calls_options {
    /* Address format used for load/store operations on the call stack. */
    nir_address_format address_format;
@@ -5112,6 +5114,17 @@ typedef struct nir_lower_shader_calls_options {
 
    /* Data passed to vectorizer_callback */
    void *vectorizer_data;
+
+   /* If this function pointer is not NULL, lower_shader_calls will call this
+    * function on instructions that require spill/fill/rematerialization of
+    * their value. If this function returns true, lower_shader_calls will
+    * ensure that the instruction is rematerialized, adding the sources of the
+    * instruction to be spilled/filled.
+    */
+   nir_lower_shader_calls_should_remat_func should_remat_callback;
+
+   /* Data passed to should_remat_callback */
+   void *should_remat_data;
 } nir_lower_shader_calls_options;
 
 bool
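
With the two new fields in place, wiring the callback into a caller of
nir_lower_shader_calls looks roughly like the sketch below; every field
other than the two added by this commit is shown with placeholder values:

const nir_lower_shader_calls_options opts = {
   .address_format = nir_address_format_64bit_global, /* placeholder */
   .stack_alignment = 16,                             /* placeholder */
   .should_remat_callback = should_remat_cb,
   .should_remat_data = NULL, /* passed back as the callback's data arg */
};

nir_shader **resume_shaders = NULL;
uint32_t num_resume_shaders = 0;
nir_lower_shader_calls(shader, &opts,
                       &resume_shaders, &num_resume_shaders, mem_ctx);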
diff --git a/src/compiler/nir/nir_lower_shader_calls.c b/src/compiler/nir/nir_lower_shader_calls.c
index e93a50c..c30104f 100644
@@ -420,9 +420,19 @@ spill_fill(nir_builder *before, nir_builder *after, nir_ssa_def *def,
                          .align_mul = MIN2(comp_size, stack_alignment));
 }
 
+static bool
+add_src_to_call_live_bitset(nir_src *src, void *state)
+{
+   BITSET_WORD *call_live = state;
+
+   assert(src->is_ssa);
+   BITSET_SET(call_live, src->ssa->index);
+   return true;
+}
+
 static void
 spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
-                                      unsigned stack_alignment)
+                                      const nir_lower_shader_calls_options *options)
 {
    /* TODO: If a SSA def is filled more than once, we probably want to just
     *       spill it at the LCM of the fill sites so we avoid unnecessary
@@ -505,6 +515,47 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
       }
    }
 
+   /* If a should_remat_callback is given, call it on each of the live values
+    * for each call site. If it returns true we need to rematerialize that
+    * instruction (instead of spill/fill). Therefore we need to add the
+    * sources as live values so that we can rematerialize on top of those
+    * spilled/filled sources.
+    */
+   if (options->should_remat_callback) {
+      BITSET_WORD **updated_call_live =
+         rzalloc_array(mem_ctx, BITSET_WORD *, num_calls);
+
+      nir_foreach_block(block, impl) {
+         nir_foreach_instr(instr, block) {
+            nir_ssa_def *def = nir_instr_ssa_def(instr);
+            if (def == NULL)
+               continue;
+
+            for (unsigned c = 0; c < num_calls; c++) {
+               if (!BITSET_TEST(call_live[c], def->index))
+                  continue;
+
+               if (!options->should_remat_callback(def->parent_instr,
+                                                   options->should_remat_data))
+                  continue;
+
+               if (updated_call_live[c] == NULL) {
+                  const unsigned bitset_words = BITSET_WORDS(impl->ssa_alloc);
+                  updated_call_live[c] = ralloc_array(mem_ctx, BITSET_WORD, bitset_words);
+                  memcpy(updated_call_live[c], call_live[c], bitset_words * sizeof(BITSET_WORD));
+               }
+
+               nir_foreach_src(instr, add_src_to_call_live_bitset, updated_call_live[c]);
+            }
+         }
+      }
+
+      for (unsigned c = 0; c < num_calls; c++) {
+         if (updated_call_live[c] != NULL)
+            call_live[c] = updated_call_live[c];
+      }
+   }
+
    nir_builder before, after;
    nir_builder_init(&before, impl);
    nir_builder_init(&after, impl);
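
The copy-on-write treatment of the per-call live sets above can be modeled in
isolation. The following toy program (plain C, with a uint32_t standing in
for the BITSET_WORD array, all indices made up) shows how only the call sites
that actually see a rematerialized value get their own copy of the live set:

#include <stdint.h>
#include <stdio.h>

int main(void)
{
   /* def 5 is live across call 0, def 7 across call 1. */
   uint32_t call_live[2] = { 1u << 5, 1u << 7 };
   uint32_t *updated[2] = { NULL, NULL };
   uint32_t storage[2];

   /* Pretend the callback flagged def 5, whose sources are defs 2 and 3. */
   const unsigned remat_def = 5;
   const unsigned remat_srcs[2] = { 2, 3 };

   for (unsigned c = 0; c < 2; c++) {
      if (!(call_live[c] & (1u << remat_def)))
         continue;

      /* First remat hit for this call: make the lazy copy. */
      if (updated[c] == NULL) {
         storage[c] = call_live[c];
         updated[c] = &storage[c];
      }

      /* Mark the remat'ed instruction's sources live instead. */
      for (unsigned i = 0; i < 2; i++)
         *updated[c] |= 1u << remat_srcs[i];
   }

   printf("call 0 live: 0x%x (sources 2 and 3 added)\n",
          updated[0] ? *updated[0] : call_live[0]);
   printf("call 1 live: 0x%x (unchanged, still shared)\n",
          updated[1] ? *updated[1] : call_live[1]);
   return 0;
}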
@@ -583,7 +634,7 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
 
                   new_def = spill_fill(&before, &after, def,
                                        index, call_idx,
-                                       offset, stack_alignment);
+                                       offset, options->stack_alignment);
 
                   if (is_bool)
                      new_def = nir_b2b1(&after, new_def);
@@ -611,7 +662,7 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
 
          nir_builder *b = &before;
 
-         offset = ALIGN(offset, stack_alignment);
+         offset = ALIGN(offset, options->stack_alignment);
          max_scratch_size = MAX2(max_scratch_size, offset);
 
          /* First thing on the called shader's stack is the resume address
@@ -1945,7 +1996,7 @@ nir_lower_shader_calls(nir_shader *shader,
    unsigned start_call_scratch = shader->scratch_size;
 
    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
-              num_calls, options->stack_alignment);
+              num_calls, options);
 
    NIR_PASS_V(shader, nir_opt_remove_phis);