microsoft/compiler: Split hull (tess ctrl) shaders into main and patch constant funcs
authorJesse Natalie <jenatali@microsoft.com>
Sat, 1 Jan 2022 21:14:05 +0000 (13:14 -0800)
committerMarge Bot <emma+marge@anholt.net>
Wed, 26 Jan 2022 01:31:35 +0000 (01:31 +0000)
Reviewed-by: Boris Brezillon <boris.brezillon@collabora.com>
Reviewed-by: Bill Kristiansen <billkris@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14399>

src/microsoft/compiler/dxil_nir.h
src/microsoft/compiler/dxil_nir_tess.c [new file with mode: 0644]
src/microsoft/compiler/meson.build
src/microsoft/compiler/nir_to_dxil.c

index b032d7a..db7dae9 100644 (file)
@@ -68,6 +68,8 @@ uint64_t
 dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
    uint64_t other_stage_mask);
 
+void dxil_nir_split_tess_ctrl(nir_shader *nir, nir_function **patch_const_func); /* Splits a tess-ctrl shader's single function in two: the entrypoint keeps
+                                                                                   * running but loses its store_output/barrier intrinsics, and a new function
+                                                                                   * named "PatchConstantFunc" is created in the same shader and returned
+                                                                                   * through *patch_const_func. */
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/microsoft/compiler/dxil_nir_tess.c b/src/microsoft/compiler/dxil_nir_tess.c
new file mode 100644 (file)
index 0000000..e326475
--- /dev/null
@@ -0,0 +1,272 @@
+/*
+ * Copyright © Microsoft Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "nir.h"
+#include "nir_builder.h"
+#include "nir_control_flow.h"
+
+#include "dxil_nir.h"
+
+/* Deletes every store_output, memory_barrier_tcs_patch and control_barrier
+ * intrinsic from impl. All other instructions are left untouched. Since only
+ * instructions are removed, block indices and dominance stay valid.
+ */
+static void
+remove_hs_intrinsics(nir_function_impl *impl)
+{
+   nir_foreach_block(blk, impl) {
+      nir_foreach_instr_safe(cur, blk) {
+         if (cur->type != nir_instr_type_intrinsic)
+            continue;
+
+         switch (nir_instr_as_intrinsic(cur)->intrinsic) {
+         case nir_intrinsic_store_output:
+         case nir_intrinsic_memory_barrier_tcs_patch:
+         case nir_intrinsic_control_barrier:
+            nir_instr_remove(cur);
+            break;
+         default:
+            break;
+         }
+      }
+   }
+   nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
+}
+
+/* Forward declaration: the two helpers below are mutually recursive. */
+static void
+add_instr_and_srcs_to_set(struct set *instr_set, nir_instr *instr);
+
+/* nir_foreach_src callback. state is the instruction set being built; the
+ * instruction that produces src is recorded in it together with everything
+ * that instruction depends on. Always returns true so the walk over the
+ * remaining sources continues.
+ */
+static bool
+add_srcs_to_set(nir_src *src, void *state)
+{
+   struct set *instr_set = state;
+
+   assert(src->is_ssa);
+   nir_instr *producer = src->ssa->parent_instr;
+   add_instr_and_srcs_to_set(instr_set, producer);
+   return true;
+}
+
+/* Records instr in instr_set. If it was not there before, also records
+ * (recursively) the instructions producing each of its sources. An
+ * instruction that is already present is not walked again.
+ */
+static void
+add_instr_and_srcs_to_set(struct set *instr_set, nir_instr *instr)
+{
+   bool seen = false;
+
+   _mesa_set_search_or_add(instr_set, instr, &seen);
+   if (seen)
+      return;
+
+   nir_foreach_src(instr, add_srcs_to_set, instr_set);
+}
+
+/* Reduces impl to the instructions that matter for the patch constant
+ * function. Kept are:
+ *  - store_output and memory_barrier_tcs_patch intrinsics,
+ *  - jump instructions,
+ *  - the producers of every if-condition,
+ *  - everything those instructions depend on through their sources.
+ * Every other instruction is removed.
+ */
+static void
+prune_patch_function_to_intrinsic_and_srcs(nir_function_impl *impl)
+{
+   struct set *keep = _mesa_pointer_set_create(NULL);
+
+   /* Pass 1: gather the roots listed above and their transitive sources. */
+   nir_foreach_block(blk, impl) {
+      nir_if *next_if = nir_block_get_following_if(blk);
+      if (next_if) {
+         assert(next_if->condition.is_ssa);
+         add_instr_and_srcs_to_set(keep, next_if->condition.ssa->parent_instr);
+      }
+
+      nir_foreach_instr_safe(cur, blk) {
+         bool is_root;
+
+         if (cur->type == nir_instr_type_intrinsic) {
+            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(cur);
+            is_root = intr->intrinsic == nir_intrinsic_store_output ||
+                      intr->intrinsic == nir_intrinsic_memory_barrier_tcs_patch;
+         } else {
+            is_root = cur->type == nir_instr_type_jump;
+         }
+
+         if (is_root)
+            add_instr_and_srcs_to_set(keep, cur);
+      }
+   }
+
+   /* Pass 2: walk backwards and drop whatever was not gathered. */
+   nir_foreach_block_reverse(blk, impl) {
+      nir_foreach_instr_reverse_safe(cur, blk) {
+         if (_mesa_set_search(keep, cur) == NULL)
+            nir_instr_remove(cur);
+      }
+   }
+
+   _mesa_set_destroy(keep, NULL);
+}
+
+/* Returns a cursor at function-level control flow that lies ahead of instr.
+ * If instr's block is directly in the function body, that is the point just
+ * before instr. Otherwise the enclosing control-flow nodes are climbed, each
+ * time stepping to the block preceding the enclosing node, until a block
+ * directly in the function body is reached; the cursor is then the end of
+ * that block (ahead of its jump, if it has one).
+ */
+static nir_cursor
+get_cursor_for_instr_without_cf(nir_instr *instr)
+{
+   nir_block *cur = instr->block;
+
+   if (cur->cf_node.parent->type == nir_cf_node_function)
+      return nir_before_instr(instr);
+
+   while (cur->cf_node.parent->type != nir_cf_node_function) {
+      nir_cf_node *enclosing = cur->cf_node.parent;
+      cur = nir_cf_node_as_block(nir_cf_node_prev(enclosing));
+   }
+   return nir_after_block_before_jump(cur);
+}
+
+struct tcs_patch_loop_state { /* Bookkeeping for the invocation loop currently being built; an all-zero
+                               * value (loop == NULL) means no loop is open. */
+   nir_ssa_def *deref, *count; /* count: counter value loaded at the top of the loop, substituted for
+                                * load_invocation_id results. deref: not referenced by the code visible
+                                * in this file. */
+   nir_cursor begin_cursor, end_cursor, insert_cursor; /* [begin_cursor, end_cursor) is the control flow that
+                                                        * end_tcs_loop extracts; insert_cursor is the point
+                                                        * inside the loop body where it is reinserted. */
+   nir_loop *loop; /* the loop created by start_tcs_loop */
+};
+
+/* Emits, at b->cursor, an empty counted loop driven by the variable behind
+ * loop_var_deref:
+ *
+ *    var = 0;
+ *    loop {
+ *       count = var;
+ *       if (count >= tcs_vertices_out) break;
+ *       <insert_cursor>
+ *       var = count + 1;
+ *    }
+ *    <begin_cursor>
+ *
+ * state->loop, state->count, state->insert_cursor and state->begin_cursor
+ * are filled in; state->end_cursor is left for the caller. Does nothing when
+ * loop_var_deref is NULL.
+ */
+static void
+start_tcs_loop(nir_builder *b, struct tcs_patch_loop_state *state, nir_deref_instr *loop_var_deref)
+{
+   if (!loop_var_deref)
+      return;
+
+   nir_store_deref(b, loop_var_deref, nir_imm_int(b, 0), 1);
+   state->loop = nir_push_loop(b);
+   state->count = nir_load_deref(b, loop_var_deref);
+   nir_push_if(b, nir_ige(b, state->count, nir_imm_int(b, b->impl->function->shader->info.tess.tcs_vertices_out)));
+   nir_jump(b, nir_jump_break);
+   nir_pop_if(b, NULL);
+   /* Body code is later moved here: after the exit test, before the increment. */
+   state->insert_cursor = b->cursor;
+   nir_store_deref(b, loop_var_deref, nir_iadd_imm(b, state->count, 1), 1);
+   nir_pop_loop(b, state->loop);
+
+   /* The code to be moved into the loop starts right after the loop. A
+    * begin_cursor computed by the caller before the loop existed may be
+    * block-relative (end of the preceding block); after the insertions above
+    * such a cursor lies ahead of the loop, so extracting from it would pull
+    * the loop into its own body. Re-anchor it on the loop itself. When the
+    * caller's cursor was nir_before_instr() this is the same position.
+    */
+   state->begin_cursor = nir_after_cf_node(&state->loop->cf_node);
+}
+
+/* Closes the loop described by state, if one is open: the control flow
+ * between state->begin_cursor and state->end_cursor is extracted and
+ * reinserted at state->insert_cursor, then state is reset to all zeroes.
+ * Does nothing when state->loop is NULL.
+ */
+static void
+end_tcs_loop(nir_builder *b, struct tcs_patch_loop_state *state)
+{
+   if (state->loop) {
+      nir_cf_list body;
+
+      nir_cf_extract(&body, state->begin_cursor, state->end_cursor);
+      nir_cf_reinsert(&body, state->insert_cursor);
+
+      *state = (struct tcs_patch_loop_state){ 0 };
+   }
+}
+
+/* In HLSL/DXIL, the hull (tessellation control) shader is split into two:
+ * 1. The main hull shader, which runs once per output control point.
+ * 2. A patch constant function, which runs once overall.
+ * In GLSL/NIR, these are combined. Each invocation must write to the output
+ * array with a constant gl_InvocationID, which is (apparently) lowered to an
+ * if/else ladder in nir. Each invocation must write the same value to patch
+ * constants - or else undefined behavior strikes. NIR uses store_output to
+ * write the patch constants, and store_per_vertex_output to write the control
+ * point values.
+ *
+ * We clone the NIR function to produce 2: one with the store_output intrinsics
+ * removed, which becomes the main shader (only writes control points), and one
+ * with everything that doesn't contribute to store_output removed, which becomes
+ * the patch constant function.
+ *
+ * For the patch constant function, if the expressions rely on gl_InvocationID,
+ * then we need to run the resulting logic in a loop, using the loop counter to
+ * replace gl_InvocationID. This loop can be terminated when a barrier is hit. If
+ * gl_InvocationID is used again after the barrier, then another loop needs to begin.
+ *
+ * nir must be a tessellation control shader containing exactly one function
+ * (both are asserted). The new function is created in nir under the name
+ * "PatchConstantFunc" and returned through *patch_const_func.
+ */
+void
+dxil_nir_split_tess_ctrl(nir_shader *nir, nir_function **patch_const_func)
+{
+   assert(nir->info.stage == MESA_SHADER_TESS_CTRL);
+   assert(exec_list_length(&nir->functions) == 1);
+   nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
+
+   *patch_const_func = nir_function_create(nir, "PatchConstantFunc");
+   nir_function_impl *patch_const_func_impl = nir_function_impl_clone(nir, entrypoint);
+   (*patch_const_func)->impl = patch_const_func_impl; /* link the function and the cloned impl in both directions */
+   patch_const_func_impl->function = *patch_const_func;
+
+   remove_hs_intrinsics(entrypoint);
+   prune_patch_function_to_intrinsic_and_srcs(patch_const_func_impl);
+
+   /* Kill dead references to the invocation ID from the patch const func so we don't
+    * insert unnecessary loops. The passes are given the whole shader, so they see
+    * both functions. The bitwise | (rather than ||) makes both passes run on every
+    * iteration; the loop repeats until both return false.
+    */
+   while (nir_opt_dead_cf(nir) | nir_opt_dce(nir));
+
+   /* Now, the patch constant function needs to be split into blocks and loops.
+    * The series of instructions up to the first block containing a load_invocation_id
+    * will run sequentially. Then a loop is inserted so load_invocation_id will load the
+    * loop counter. This loop continues until a barrier is reached, when the loop
+    * is closed and the process begins again.
+    *
+    * First, sink load_invocation_id so that it's present on both sides of barriers.
+    * Each use gets a unique load of the invocation ID. Loads with at most one use
+    * (counting both instruction uses and if-condition uses) are left as they are.
+    */
+   nir_builder b;
+   nir_builder_init(&b, patch_const_func_impl);
+   nir_foreach_block(block, patch_const_func_impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+         if (intr->intrinsic != nir_intrinsic_load_invocation_id ||
+             list_length(&intr->dest.ssa.uses) +
+             list_length(&intr->dest.ssa.if_uses) <= 1)
+            continue;
+         nir_foreach_use_safe(src, &intr->dest.ssa) {
+            b.cursor = nir_before_src(src, false);
+            nir_instr_rewrite_src_ssa(src->parent_instr, src, nir_load_invocation_id(&b));
+         }
+         nir_foreach_if_use_safe(src, &intr->dest.ssa) {
+            b.cursor = nir_before_src(src, true);
+            nir_if_rewrite_condition_ssa(src->parent_if, src, nir_load_invocation_id(&b));
+         }
+         nir_instr_remove(instr); /* every use was just redirected to its own fresh load */
+      }
+   }
+
+   /* Now replace those invocation ID loads with loads of a local variable that's used as a loop counter.
+    * The variable and its deref are created lazily, on the first load_invocation_id encountered,
+    * with the deref placed at the very start of the function body.
+    */
+   nir_variable *loop_var = NULL;
+   nir_deref_instr *loop_var_deref = NULL;
+   struct tcs_patch_loop_state state = { 0 };
+   nir_foreach_block_safe(block, patch_const_func_impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+         switch (intr->intrinsic) {
+         case nir_intrinsic_load_invocation_id: {
+            if (!loop_var) {
+               loop_var = nir_local_variable_create(patch_const_func_impl, glsl_int_type(), "PatchConstInvocId");
+               b.cursor = nir_before_cf_list(&patch_const_func_impl->body);
+               loop_var_deref = nir_build_deref_var(&b, loop_var);
+            }
+            if (!state.loop) { /* no loop open: start one at function-level control flow ahead of this load */
+               b.cursor = state.begin_cursor = get_cursor_for_instr_without_cf(instr);
+               start_tcs_loop(&b, &state, loop_var_deref);
+            }
+            nir_ssa_def_rewrite_uses(&intr->dest.ssa, state.count); /* the load itself is not removed here; it is left without uses */
+            break;
+         }
+         case nir_intrinsic_memory_barrier_tcs_patch:
+            /* The GL tessellation spec says:
+             * The barrier() function may only be called inside the main entry point of the tessellation control shader
+             * and may not be called in potentially divergent flow control.  In particular, barrier() may not be called
+             * inside a switch statement, in either sub-statement of an if statement, inside a do, for, or while loop,
+             * or at any point after a return statement in the function main().
+             *
+             * Therefore, we should be at function-level control flow.
+             *
+             * The barrier ends the currently open loop (if any): everything between the loop's begin
+             * cursor and the barrier is moved into the loop body, and the barrier itself is dropped.
+             */
+            assert(nir_cursors_equal(nir_before_instr(instr), get_cursor_for_instr_without_cf(instr)));
+            state.end_cursor = nir_before_instr(instr);
+            end_tcs_loop(&b, &state);
+            nir_instr_remove(instr);
+            break;
+         default:
+            break;
+         }
+      }
+   }
+   state.end_cursor = nir_after_block_before_jump(nir_impl_last_block(patch_const_func_impl)); /* a loop still open here runs to the end of the function */
+   end_tcs_loop(&b, &state); /* does nothing when no loop is open */
+}
index 822694c..78e4792 100644 (file)
@@ -29,6 +29,7 @@ files_libdxil_compiler = files(
   'dxil_nir.c',
   'dxil_nir_lower_int_samplers.c',
   'dxil_signature.c',
+  'dxil_nir_tess.c',
   'nir_to_dxil.c',
 )
 
index 6ba3390..45b9e2c 100644 (file)
@@ -463,6 +463,8 @@ struct ntd_context {
    nir_variable *ps_front_face;
    nir_variable *system_value[SYSTEM_VALUE_MAX];
 
+   nir_function *tess_ctrl_patch_constant_func;
+
    struct dxil_func_def *main_func_def;
 };
 
@@ -5318,6 +5320,9 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
    NIR_PASS_V(s, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4, (nir_lower_io_options)0);
    NIR_PASS_V(s, dxil_nir_lower_system_values);
 
+   if (ctx->mod.shader_kind == DXIL_HULL_SHADER)
+      NIR_PASS_V(s, dxil_nir_split_tess_ctrl, &ctx->tess_ctrl_patch_constant_func);
+
    optimize_nir(s, opts);
 
    NIR_PASS_V(s, nir_remove_dead_variables,