microsoft/compiler: Add lowering passes for basic subgroup vars
author Jesse Natalie <jenatali@microsoft.com>
Wed, 18 Jan 2023 22:13:41 +0000 (14:13 -0800)
committer Marge Bot <emma+marge@anholt.net>
Fri, 20 Jan 2023 18:50:57 +0000 (18:50 +0000)
DXIL doesn't have a "subgroup ID" or "num subgroups" construct,
so add lowering to construct them. Subgroup ID is done using
once-per-subgroup atomics on a workgroup-shared variable, and
then broadcasting that (using read_first_invocation) to the other
threads. Num subgroups is just the workgroup size divided by the subgroup size, rounded up.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20777>

src/microsoft/compiler/dxil_nir.c
src/microsoft/compiler/dxil_nir.h

index 27a4ce4..364ab3a 100644 (file)
@@ -2074,3 +2074,76 @@ dxil_nir_lower_sample_pos(nir_shader *s)
 {
    return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
 }
+
+static bool /* Pass callback: rewrites load_subgroup_id uses to a value built once at the start of the impl. */
+lower_subgroup_id(nir_builder *b, nir_instr *instr, void *data) /* data is a nir_ssa_def ** that caches the built subgroup ID between calls. */
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+   if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
+      return false; /* Only load_subgroup_id is handled; everything else is left alone. */
+
+   nir_ssa_def **subgroup_id = (nir_ssa_def **)data; /* NOTE(review): one cache for the whole pass; presumably the shader has a single impl by now - confirm. */
+   if (*subgroup_id == NULL) { /* First matching load: emit the ID computation; later loads reuse *subgroup_id. */
+      nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter"); /* Shared-memory counter. */
+      nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local"); /* Function-local holder for the ID. */
+      b->cursor = nir_before_block(nir_start_block(b->impl)); /* Emit everything below at the top of the start block. */
+      nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1); /* Default of 0 for invocations that skip the elect branch below. */
+
+      nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
+      nir_ssa_def *tid = nir_load_local_invocation_index(b);
+      nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0)); /* Only the invocation with local index 0 resets the counter. */
+      nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
+      nir_pop_if(b, nif);
+
+      nir_scoped_memory_barrier(b, NIR_SCOPE_WORKGROUP, NIR_MEMORY_ACQ_REL, nir_var_mem_shared); /* NOTE(review): memory barrier only; confirm whether an execution barrier is also needed before the atomics. */
+
+      nif = nir_push_if(b, nir_elect(b, 1)); /* Branch taken by the elected invocation(s); the commit message describes this as once per subgroup. */
+      nir_ssa_def *subgroup_id_first_thread = nir_deref_atomic_add(b, 32, &counter_deref->dest.ssa, nir_imm_int(b, 1)); /* Atomic add of 1 on the shared counter; its result is used as the ID. */
+      nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
+      nir_pop_if(b, nif);
+
+      nir_ssa_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
+      *subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded); /* Broadcast the ID to the other threads and cache the result. */
+   }
+   nir_ssa_def_rewrite_uses(&intr->dest.ssa, *subgroup_id); /* Uses are redirected; the original intrinsic is not removed here. */
+   return true;
+}
+
+bool /* Runs lower_subgroup_id over every instruction of shader s and returns the pass result. */
+dxil_nir_lower_subgroup_id(nir_shader *s)
+{
+   nir_ssa_def *subgroup_id = NULL; /* Cache handed to the callback; NULL until the first load_subgroup_id is lowered. */
+   return nir_shader_instructions_pass(s, lower_subgroup_id, nir_metadata_none, &subgroup_id); /* nir_metadata_none is passed as the metadata argument. */
+}
+
+static bool /* Pass callback: replaces load_num_subgroups with workgroup size divided by subgroup size, rounded up. */
+lower_num_subgroups(nir_builder *b, nir_instr *instr, void *data) /* data is not used by this callback. */
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+   if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
+      return false; /* Only load_num_subgroups is handled. */
+
+   b->cursor = nir_before_instr(instr); /* Emit the replacement right before the intrinsic being lowered. */
+   nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
+   nir_ssa_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1); /* subgroup_size - 1, the bias that makes the division round up. */
+   nir_ssa_def *workgroup_size_vec = nir_load_workgroup_size(b);
+   nir_ssa_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0), /* Product of channels 0, 1 and 2 of the workgroup size vector. */
+                                             nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
+                                                         nir_channel(b, workgroup_size_vec, 2)));
+   nir_ssa_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size); /* (workgroup_size + subgroup_size - 1) / subgroup_size. */
+   nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); /* Uses are redirected; the original intrinsic is not removed here. */
+   return true;
+}
+
+bool /* Runs lower_num_subgroups over every instruction of shader s and returns the pass result. */
+dxil_nir_lower_num_subgroups(nir_shader *s)
+{
+   return nir_shader_instructions_pass(s, lower_num_subgroups, /* The callback emits only straight-line ALU code, no new control flow. */
+                                       nir_metadata_block_index |
+                                       nir_metadata_dominance |
+                                       nir_metadata_loop_analysis, NULL); /* Metadata flags passed to the pass; no callback data is needed. */
+}
index 57baba3..adeea6d 100644 (file)
@@ -76,6 +76,8 @@ bool dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mas
 bool dxil_nir_lower_discard_and_terminate(nir_shader* s);
 bool dxil_nir_ensure_position_writes(nir_shader *s);
 bool dxil_nir_lower_sample_pos(nir_shader *s);
+bool dxil_nir_lower_subgroup_id(nir_shader *s); /* Lowers load_subgroup_id intrinsics in s; presumably returns true when something was lowered. */
+bool dxil_nir_lower_num_subgroups(nir_shader *s); /* Lowers load_num_subgroups intrinsics in s; presumably returns true when something was lowered. */
 
 #ifdef __cplusplus
 }