microsoft/compiler: Un-lower shared/scratch to derefs
author     Jesse Natalie <jenatali@microsoft.com>
Fri, 19 May 2023 17:28:08 +0000 (10:28 -0700)
committer  Marge Bot <emma+marge@anholt.net>
Tue, 13 Jun 2023 00:43:36 +0000 (00:43 +0000)
Derefs have index-based access semantics, which means we don't need
custom intrinsics to encode an index instead of a byte offset.
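
A minimal sketch of what that index-based access looks like with the generic
deref helpers (assuming a nir_builder *b, the lowered uint-array variable
"var", and a 32-bit byte offset "byte_offset"; the helper name is
illustrative, not from the commit):

   /* Hypothetical helper: read one 32-bit slot of the lowered uint[] variable
    * by array index rather than byte offset. */
   static nir_ssa_def *
   load_lowered_word(nir_builder *b, nir_variable *var, nir_ssa_def *byte_offset)
   {
      /* Byte offset -> i32 array index. */
      nir_ssa_def *index = nir_ushr(b, byte_offset, nir_imm_int(b, 2));
      /* Emits a plain load_deref of var[index]; no custom intrinsic needed. */
      return nir_load_array_var(b, var, index);
   }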

Remove the "masked" store intrinsics and just emit the pair of atomics
directly. This massively reduces duplication between scratch, shared,
and constant, while also moving more things into nir so more optimizations
can be done.
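
The pair of atomics mirrors the lowering added in dxil_nir.c: an atomic AND
with the inverted mask clears the bits being written, then an atomic OR
merges in the new (already shifted) value. Sketch only; "var", "index",
"value" and "mask" correspond to the operands of lower_masked_store_vec32:

   static void
   store_masked_shared_word(nir_builder *b, nir_variable *var, nir_ssa_def *index,
                            nir_ssa_def *value, nir_ssa_def *mask)
   {
      nir_deref_instr *deref =
         nir_build_deref_array(b, nir_build_deref_var(b, var), index);
      /* Clear the bits covered by the mask... */
      nir_build_deref_atomic(b, 32, &deref->dest.ssa, nir_inot(b, mask),
                             .atomic_op = nir_atomic_op_iand);
      /* ...then merge in the new bits. */
      nir_build_deref_atomic(b, 32, &deref->dest.ssa, value,
                             .atomic_op = nir_atomic_op_ior);
   }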

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23173>

src/compiler/nir/nir_intrinsics.py
src/gallium/drivers/d3d12/d3d12_compiler.cpp
src/microsoft/clc/clc_compiler.c
src/microsoft/compiler/dxil_module.c
src/microsoft/compiler/dxil_module.h
src/microsoft/compiler/dxil_nir.c
src/microsoft/compiler/dxil_nir.h
src/microsoft/compiler/nir_to_dxil.c
src/microsoft/spirv_to_dxil/dxil_spirv_nir.c

diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index 9a30da6..9591a1a 100644
@@ -1259,25 +1259,9 @@ intrinsic("copy_ubo_to_uniform_ir3", [1, 1], indices=[BASE, RANGE])
 # DXIL specific intrinsics
 # src[] = { value, mask, index, offset }.
 intrinsic("store_ssbo_masked_dxil", [1, 1, 1, 1])
-# src[] = { value, index }.
-intrinsic("store_shared_dxil", [1, 1])
-# src[] = { value, mask, index }.
-intrinsic("store_shared_masked_dxil", [1, 1, 1])
-# src[] = { value, index }.
-intrinsic("store_scratch_dxil", [1, 1])
-# src[] = { index }.
-load("shared_dxil", [1], [], [CAN_ELIMINATE])
-# src[] = { index }.
-load("scratch_dxil", [1], [], [CAN_ELIMINATE])
 # src[] = { index, 16-byte-based-offset }
 load("ubo_dxil", [1, 1], [], [CAN_ELIMINATE, CAN_REORDER])
 
-# DXIL Shared atomic intrinsics
-#
-# src0 is the index in the i32 array for by the shared memory region
-intrinsic("shared_atomic_dxil",  src_comp=[1, 1], dest_comp=1, indices=[ATOMIC_OP])
-intrinsic("shared_atomic_swap_dxil", src_comp=[1, 1, 1], dest_comp=1, indices=[ATOMIC_OP])
-
 # Intrinsics used by the Midgard/Bifrost blend pipeline. These are defined
 # within a blend shader to read/write the raw value from the tile buffer,
 # without applying any format conversion in the process. If the shader needs
diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
index befe646..34d650e 100644
@@ -144,7 +144,6 @@ compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel,
    NIR_PASS_V(nir, d3d12_lower_state_vars, shader);
    const struct dxil_nir_lower_loads_stores_options loads_stores_options = {};
    NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
-   NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
    NIR_PASS_V(nir, dxil_nir_lower_double_math);
 
    if (key->stage == PIPE_SHADER_FRAGMENT && key->fs.multisample_disabled)
diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c
index f29f6cd..fc73b2e 100644
@@ -976,30 +976,8 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
    const struct dxil_nir_lower_loads_stores_options loads_stores_options = {
       .use_16bit_ssbo = false,
    };
-   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
-   NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
-   NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
-   NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
-   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);
-
-   // Convert pack to pack_split
-   NIR_PASS_V(nir, nir_lower_pack);
-   // Lower pack_split to bit math
-   NIR_PASS_V(nir, nir_opt_algebraic);
-
-   NIR_PASS_V(nir, nir_opt_dce);
-
-   nir_validate_shader(nir, "Validate before feeding NIR to the DXIL compiler");
-   struct nir_to_dxil_options opts = {
-      .interpolate_at_vertex = false,
-      .lower_int16 = (conf && (conf->lower_bit_size & 16) != 0),
-      .disable_math_refactoring = true,
-      .num_kernel_globals = num_global_inputs,
-      .environment = DXIL_ENVIRONMENT_CL,
-      .shader_model_max = conf && conf->max_shader_model ? conf->max_shader_model : SHADER_MODEL_6_2,
-      .validator_version_max = conf ? conf->validator_version : DXIL_VALIDATOR_1_4,
-   };
-
+   
+   /* Now that function-declared local vars have been sized, append args */
    for (unsigned i = 0; i < out_dxil->kernel->num_args; i++) {
       if (out_dxil->kernel->args[i].address_qualifier != CLC_KERNEL_ARG_ADDRESS_LOCAL)
          continue;
@@ -1026,6 +1004,29 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
       nir->info.shared_size += size;
    }
 
+   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
+   NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
+   NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
+   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);
+
+   // Convert pack to pack_split
+   NIR_PASS_V(nir, nir_lower_pack);
+   // Lower pack_split to bit math
+   NIR_PASS_V(nir, nir_opt_algebraic);
+
+   NIR_PASS_V(nir, nir_opt_dce);
+
+   nir_validate_shader(nir, "Validate before feeding NIR to the DXIL compiler");
+   struct nir_to_dxil_options opts = {
+      .interpolate_at_vertex = false,
+      .lower_int16 = (conf && (conf->lower_bit_size & 16) != 0),
+      .disable_math_refactoring = true,
+      .num_kernel_globals = num_global_inputs,
+      .environment = DXIL_ENVIRONMENT_CL,
+      .shader_model_max = conf && conf->max_shader_model ? conf->max_shader_model : SHADER_MODEL_6_2,
+      .validator_version_max = conf ? conf->validator_version : DXIL_VALIDATOR_1_4,
+   };
+
    metadata->local_mem_size = nir->info.shared_size;
    metadata->priv_mem_size = nir->scratch_size;
 
diff --git a/src/microsoft/compiler/dxil_module.c b/src/microsoft/compiler/dxil_module.c
index 39849b0..8b13c04 100644
@@ -3372,11 +3372,10 @@ dxil_emit_extractval(struct dxil_module *m, const struct dxil_value *src,
 
 const struct dxil_value *
 dxil_emit_alloca(struct dxil_module *m, const struct dxil_type *alloc_type,
-                 const struct dxil_type *size_type,
                  const struct dxil_value *size,
                  unsigned int align)
 {
-   assert(size_type && size_type->type == TYPE_INTEGER);
+   assert(size->type->type == TYPE_INTEGER);
 
    const struct dxil_type *return_type =
       dxil_module_get_pointer_type(m, alloc_type);
@@ -3388,7 +3387,7 @@ dxil_emit_alloca(struct dxil_module *m, const struct dxil_type *alloc_type,
       return NULL;
 
    instr->alloca.alloc_type = alloc_type;
-   instr->alloca.size_type = size_type;
+   instr->alloca.size_type = size->type;
    instr->alloca.size = size;
    instr->alloca.align = util_logbase2(align) + 1;
    assert(instr->alloca.align < (1 << 5));
diff --git a/src/microsoft/compiler/dxil_module.h b/src/microsoft/compiler/dxil_module.h
index 7d34458..f07603d 100644
@@ -542,7 +542,6 @@ dxil_emit_ret_void(struct dxil_module *m);
 
 const struct dxil_value *
 dxil_emit_alloca(struct dxil_module *m, const struct dxil_type *alloc_type,
-                 const struct dxil_type *size_type,
                  const struct dxil_value *size,
                  unsigned int align);
 
diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c
index 36e0a3d..454b80a 100644
@@ -332,23 +332,8 @@ lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr, unsigned min_bit_siz
    return true;
 }
 
-static void
-lower_load_vec32(nir_builder *b, nir_ssa_def *index, unsigned num_comps, nir_ssa_def **comps, nir_intrinsic_op op)
-{
-   for (unsigned i = 0; i < num_comps; i++) {
-      nir_intrinsic_instr *load =
-         nir_intrinsic_instr_create(b->shader, op);
-
-      load->num_components = 1;
-      load->src[0] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
-      nir_ssa_dest_init(&load->instr, &load->dest, 1, 32);
-      nir_builder_instr_insert(b, &load->instr);
-      comps[i] = &load->dest.ssa;
-   }
-}
-
 static bool
-lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
+lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
 {
    assert(intr->dest.is_ssa);
    unsigned bit_size = nir_dest_bit_size(intr->dest);
@@ -356,17 +341,13 @@ lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
    unsigned num_bits = num_components * bit_size;
 
    b->cursor = nir_before_instr(&intr->instr);
-   nir_intrinsic_op op = intr->intrinsic;
 
    assert(intr->src[0].is_ssa);
    nir_ssa_def *offset = intr->src[0].ssa;
-   if (op == nir_intrinsic_load_shared) {
+   if (intr->intrinsic == nir_intrinsic_load_shared)
       offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
-      op = nir_intrinsic_load_shared_dxil;
-   } else {
+   else
       offset = nir_u2u32(b, offset);
-      op = nir_intrinsic_load_scratch_dxil;
-   }
    nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
    nir_ssa_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
@@ -375,7 +356,8 @@ lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
     * is an i32 array and DXIL does not support type casts.
     */
    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
-   lower_load_vec32(b, index, num_32bit_comps, comps_32bit, op);
+   for (unsigned i = 0; i < num_32bit_comps; i++)
+      comps_32bit[i] = nir_load_array_var(b, var, nir_iadd_imm(b, index, i));
    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
 
    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
@@ -408,22 +390,8 @@ lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
 }
 
 static void
-lower_store_vec32(nir_builder *b, nir_ssa_def *index, nir_ssa_def *vec32, nir_intrinsic_op op)
-{
-   for (unsigned i = 0; i < vec32->num_components; i++) {
-      nir_intrinsic_instr *store =
-         nir_intrinsic_instr_create(b->shader, op);
-
-      store->src[0] = nir_src_for_ssa(nir_channel(b, vec32, i));
-      store->src[1] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
-      store->num_components = 1;
-      nir_builder_instr_insert(b, &store->instr);
-   }
-}
-
-static void
 lower_masked_store_vec32(nir_builder *b, nir_ssa_def *offset, nir_ssa_def *index,
-                         nir_ssa_def *vec32, unsigned num_bits, nir_intrinsic_op op, unsigned alignment)
+                         nir_ssa_def *vec32, unsigned num_bits, nir_variable *var, unsigned alignment)
 {
    nir_ssa_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
 
@@ -436,24 +404,26 @@ lower_masked_store_vec32(nir_builder *b, nir_ssa_def *offset, nir_ssa_def *index
       mask = nir_ishl(b, mask, shift);
    }
 
-   if (op == nir_intrinsic_store_shared_dxil) {
+   if (var->data.mode == nir_var_mem_shared) {
       /* Use the dedicated masked intrinsic */
-      nir_store_shared_masked_dxil(b, vec32, nir_inot(b, mask), index);
+      nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
+      nir_build_deref_atomic(b, 32, &deref->dest.ssa, nir_inot(b, mask), .atomic_op = nir_atomic_op_iand);
+      nir_build_deref_atomic(b, 32, &deref->dest.ssa, vec32, .atomic_op = nir_atomic_op_ior);
    } else {
       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
-      nir_ssa_def *load = nir_load_scratch_dxil(b, 1, 32, index);
+      nir_ssa_def *load = nir_load_array_var(b, var, index);
 
       nir_ssa_def *new_val = nir_ior(b, vec32,
                                      nir_iand(b,
                                               nir_inot(b, mask),
                                               load));
 
-      lower_store_vec32(b, index, new_val, op);
+      nir_store_array_var(b, var, index, new_val, 1);
    }
 }
 
 static bool
-lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
+lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
 {
    assert(intr->src[0].is_ssa);
    unsigned num_components = nir_src_num_components(intr->src[0]);
@@ -461,16 +431,12 @@ lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
    unsigned num_bits = num_components * bit_size;
 
    b->cursor = nir_before_instr(&intr->instr);
-   nir_intrinsic_op op = intr->intrinsic;
 
    nir_ssa_def *offset = intr->src[1].ssa;
-   if (op == nir_intrinsic_store_shared) {
+   if (intr->intrinsic == nir_intrinsic_store_shared)
       offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
-      op = nir_intrinsic_store_shared_dxil;
-   } else {
+   else
       offset = nir_u2u32(b, offset);
-      op = nir_intrinsic_store_scratch_dxil;
-   }
    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
 
    unsigned comp_idx = 0;
@@ -489,9 +455,10 @@ lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
       /* For anything less than 32bits we need to use the masked version of the
        * intrinsic to preserve data living in the same 32bit slot. */
       if (substore_num_bits < 32) {
-         lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, op, nir_intrinsic_align(intr));
+         lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, var, nir_intrinsic_align(intr));
       } else {
-         lower_store_vec32(b, index, vec32, op);
+         for (unsigned i = 0; i < vec32->num_components; ++i)
+            nir_store_array_var(b, var, nir_iadd_imm(b, index, i), nir_channel(b, vec32, i), 1);
       }
 
       comp_idx += substore_num_bits / bit_size;
@@ -941,20 +908,60 @@ lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intr)
    return true;
 }
 
+static bool
+lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
+{
+   b->cursor = nir_before_instr(&intr->instr);
+
+   assert(intr->src[0].is_ssa);
+   nir_ssa_def *offset =
+      nir_iadd(b, intr->src[0].ssa, nir_imm_int(b, nir_intrinsic_base(intr)));
+   nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
+
+   nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
+   nir_ssa_def *result;
+   if (intr->intrinsic == nir_intrinsic_shared_atomic_swap)
+      result = nir_build_deref_atomic_swap(b, 32, &deref->dest.ssa, intr->src[1].ssa, intr->src[2].ssa,
+                                           .atomic_op = nir_intrinsic_atomic_op(intr));
+   else
+      result = nir_build_deref_atomic(b, 32, &deref->dest.ssa, intr->src[1].ssa,
+                                      .atomic_op = nir_intrinsic_atomic_op(intr));
+
+   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
+   nir_instr_remove(&intr->instr);
+   return true;
+}
+
 bool
 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
                                     const struct dxil_nir_lower_loads_stores_options *options)
 {
-   bool progress = false;
+   bool progress = nir_remove_dead_variables(nir, nir_var_function_temp | nir_var_mem_shared, NULL);
+   nir_variable *shared_var = NULL;
+   if (nir->info.shared_size) {
+      shared_var = nir_variable_create(nir, nir_var_mem_shared,
+                                       glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->info.shared_size, 4), 4),
+                                       "lowered_shared_mem");
+   }
 
-   foreach_list_typed(nir_function, func, node, &nir->functions) {
-      if (!func->is_entrypoint)
+   unsigned ptr_size = nir->info.cs.ptr_size;
+   if (nir->info.stage == MESA_SHADER_KERNEL) {
+      /* All the derefs created here will be used as GEP indices so force 32-bit */
+      nir->info.cs.ptr_size = 32;
+   }
+   nir_foreach_function(func, nir) {
+      if (!func->impl)
          continue;
-      assert(func->impl);
 
       nir_builder b;
       nir_builder_init(&b, func->impl);
 
+      nir_variable *scratch_var = NULL;
+      if (nir->scratch_size) {
+         const struct glsl_type *scratch_type = glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->scratch_size, 4), 4);
+         scratch_var = nir_local_variable_create(func->impl, scratch_type, "lowered_scratch_mem");
+      }
+
       nir_foreach_block(block, func->impl) {
          nir_foreach_instr_safe(instr, block) {
             if (instr->type != nir_instr_type_intrinsic)
@@ -963,8 +970,10 @@ dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
 
             switch (intr->intrinsic) {
             case nir_intrinsic_load_shared:
+               progress |= lower_32b_offset_load(&b, intr, shared_var);
+               break;
             case nir_intrinsic_load_scratch:
-               progress |= lower_32b_offset_load(&b, intr);
+               progress |= lower_32b_offset_load(&b, intr, scratch_var);
                break;
             case nir_intrinsic_load_ssbo:
                progress |= lower_load_ssbo(&b, intr, options->use_16bit_ssbo ? 16 : 32);
@@ -973,75 +982,17 @@ dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
                progress |= lower_load_ubo(&b, intr);
                break;
             case nir_intrinsic_store_shared:
+               progress |= lower_32b_offset_store(&b, intr, shared_var);
+               break;
             case nir_intrinsic_store_scratch:
-               progress |= lower_32b_offset_store(&b, intr);
+               progress |= lower_32b_offset_store(&b, intr, scratch_var);
                break;
             case nir_intrinsic_store_ssbo:
                progress |= lower_store_ssbo(&b, intr, options->use_16bit_ssbo ? 16 : 32);
                break;
-            default:
-               break;
-            }
-         }
-      }
-   }
-
-   return progress;
-}
-
-static bool
-lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr)
-{
-   b->cursor = nir_before_instr(&intr->instr);
-
-   assert(intr->src[0].is_ssa);
-   nir_ssa_def *offset =
-      nir_iadd(b, intr->src[0].ssa, nir_imm_int(b, nir_intrinsic_base(intr)));
-   nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
-
-   nir_intrinsic_op dxil_op = intr->intrinsic == nir_intrinsic_shared_atomic_swap ?
-      nir_intrinsic_shared_atomic_swap_dxil : nir_intrinsic_shared_atomic_dxil;
-   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b->shader, dxil_op);
-   atomic->src[0] = nir_src_for_ssa(index);
-   assert(intr->src[1].is_ssa);
-   atomic->src[1] = nir_src_for_ssa(intr->src[1].ssa);
-   if (dxil_op == nir_intrinsic_shared_atomic_swap_dxil) {
-      assert(intr->src[2].is_ssa);
-      atomic->src[2] = nir_src_for_ssa(intr->src[2].ssa);
-   }
-   atomic->num_components = 0;
-   nir_ssa_dest_init(&atomic->instr, &atomic->dest, 1, 32);
-   nir_intrinsic_set_atomic_op(atomic, nir_intrinsic_atomic_op(intr));
-
-   nir_builder_instr_insert(b, &atomic->instr);
-   nir_ssa_def_rewrite_uses(&intr->dest.ssa, &atomic->dest.ssa);
-   nir_instr_remove(&intr->instr);
-   return true;
-}
-
-bool
-dxil_nir_lower_atomics_to_dxil(nir_shader *nir)
-{
-   bool progress = false;
-
-   foreach_list_typed(nir_function, func, node, &nir->functions) {
-      if (!func->is_entrypoint)
-         continue;
-      assert(func->impl);
-
-      nir_builder b;
-      nir_builder_init(&b, func->impl);
-
-      nir_foreach_block(block, func->impl) {
-         nir_foreach_instr_safe(instr, block) {
-            if (instr->type != nir_instr_type_intrinsic)
-               continue;
-            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
-
-            switch (intr->intrinsic) {
             case nir_intrinsic_shared_atomic:
             case nir_intrinsic_shared_atomic_swap:
-               progress |= lower_shared_atomic(&b, intr);
+               progress |= lower_shared_atomic(&b, intr, shared_var);
                break;
             default:
                break;
@@ -1049,6 +1000,9 @@ dxil_nir_lower_atomics_to_dxil(nir_shader *nir)
          }
       }
    }
+   if (nir->info.stage == MESA_SHADER_KERNEL) {
+      nir->info.cs.ptr_size = ptr_size;
+   }
 
    return progress;
 }
diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h
index 44789fd..03651db 100644
@@ -45,7 +45,6 @@ struct dxil_nir_lower_loads_stores_options {
 };
 bool dxil_nir_lower_loads_stores_to_dxil(nir_shader *shader,
                                          const struct dxil_nir_lower_loads_stores_options *options);
-bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader);
 bool dxil_nir_lower_deref_ssbo(nir_shader *shader);
 bool dxil_nir_opt_alu_deref_srcs(nir_shader *shader);
 bool dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size);
diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c
index ea92c6f..204e403 100644
@@ -153,7 +153,7 @@ nir_options = {
       nir_lower_dceil |
       nir_lower_dround_even,
    .max_unroll_iterations = 32, /* arbitrary */
-   .force_indirect_unrolling = (nir_var_shader_in | nir_var_shader_out | nir_var_function_temp),
+   .force_indirect_unrolling = (nir_var_shader_in | nir_var_shader_out),
    .lower_device_index_to_zero = true,
    .linker_ignore_precision = true,
    .support_16bit_alu = true,
@@ -582,8 +582,8 @@ struct ntd_context {
    unsigned num_defs;
    struct hash_table *phis;
 
-   const struct dxil_value *sharedvars;
-   const struct dxil_value *scratchvars;
+   const struct dxil_value **sharedvars;
+   const struct dxil_value **scratchvars;
    const struct dxil_value **consts;
 
    nir_variable *ps_front_face;
@@ -1530,6 +1530,31 @@ emit_global_consts(struct ntd_context *ctx)
 }
 
 static bool
+emit_shared_vars(struct ntd_context *ctx)
+{
+   uint32_t index = 0;
+   nir_foreach_variable_with_modes(var, ctx->shader, nir_var_mem_shared)
+      var->data.driver_location = index++;
+
+   ctx->sharedvars = ralloc_array(ctx->ralloc_ctx, const struct dxil_value *, index);
+
+   nir_foreach_variable_with_modes(var, ctx->shader, nir_var_mem_shared) {
+      if (!var->name)
+         var->name = ralloc_asprintf(var, "shared_%d", var->data.driver_location);
+      const struct dxil_value *gvar = dxil_add_global_ptr_var(&ctx->mod, var->name,
+                                                              get_type_for_glsl_type(&ctx->mod, var->type),
+                                                              DXIL_AS_GROUPSHARED, 16,
+                                                              NULL);
+      if (!gvar)
+         return false;
+
+      ctx->sharedvars[var->data.driver_location] = gvar;
+   }
+
+   return true;
+}
+
+static bool
 emit_cbv(struct ntd_context *ctx, unsigned binding, unsigned space,
          unsigned size, unsigned count, char *name)
 {
@@ -3532,92 +3557,6 @@ emit_store_ssbo_masked(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 }
 
 static bool
-emit_store_shared(struct ntd_context *ctx, nir_intrinsic_instr *intr)
-{
-   const struct dxil_value *zero, *index;
-
-   /* All shared mem accesses should have been lowered to scalar 32bit
-    * accesses.
-    */
-   assert(nir_src_bit_size(intr->src[0]) == 32);
-   assert(nir_src_num_components(intr->src[0]) == 1);
-
-   zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return false;
-
-   if (intr->intrinsic == nir_intrinsic_store_shared_dxil)
-      index = get_src(ctx, &intr->src[1], 0, nir_type_uint);
-   else
-      index = get_src(ctx, &intr->src[2], 0, nir_type_uint);
-   if (!index)
-      return false;
-
-   const struct dxil_value *ops[] = { ctx->sharedvars, zero, index };
-   const struct dxil_value *ptr, *value;
-
-   ptr = dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
-   if (!ptr)
-      return false;
-
-   value = get_src(ctx, &intr->src[0], 0, nir_type_uint);
-   if (!value)
-      return false;
-
-   if (intr->intrinsic == nir_intrinsic_store_shared_dxil)
-      return dxil_emit_store(&ctx->mod, value, ptr, 4, false);
-
-   const struct dxil_value *mask = get_src(ctx, &intr->src[1], 0, nir_type_uint);
-   if (!mask)
-      return false;
-
-   if (!dxil_emit_atomicrmw(&ctx->mod, mask, ptr, DXIL_RMWOP_AND, false,
-                            DXIL_ATOMIC_ORDERING_ACQREL,
-                            DXIL_SYNC_SCOPE_CROSSTHREAD))
-      return false;
-
-   if (!dxil_emit_atomicrmw(&ctx->mod, value, ptr, DXIL_RMWOP_OR, false,
-                            DXIL_ATOMIC_ORDERING_ACQREL,
-                            DXIL_SYNC_SCOPE_CROSSTHREAD))
-      return false;
-
-   return true;
-}
-
-static bool
-emit_store_scratch(struct ntd_context *ctx, nir_intrinsic_instr *intr)
-{
-   const struct dxil_value *zero, *index;
-
-   /* All scratch mem accesses should have been lowered to scalar 32bit
-    * accesses.
-    */
-   assert(nir_src_bit_size(intr->src[0]) == 32);
-   assert(nir_src_num_components(intr->src[0]) == 1);
-
-   zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return false;
-
-   index = get_src(ctx, &intr->src[1], 0, nir_type_uint);
-   if (!index)
-      return false;
-
-   const struct dxil_value *ops[] = { ctx->scratchvars, zero, index };
-   const struct dxil_value *ptr, *value;
-
-   ptr = dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
-   if (!ptr)
-      return false;
-
-   value = get_src(ctx, &intr->src[0], 0, nir_type_uint);
-   if (!value)
-      return false;
-
-   return dxil_emit_store(&ctx->mod, value, ptr, 4, false);
-}
-
-static bool
 emit_load_ubo(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 {
    const struct dxil_value* handle = get_resource_handle(ctx, &intr->src[0], DXIL_RESOURCE_CLASS_CBV, DXIL_RESOURCE_KIND_CBUFFER);
@@ -4020,7 +3959,14 @@ deref_to_gep(struct ntd_context *ctx, nir_deref_instr *deref)
                                                        const struct dxil_value *,
                                                        count + 1);
    nir_variable *var = path.path[0]->var;
-   gep_indices[0] = ctx->consts[var->data.driver_location];
+   const struct dxil_value **var_array;
+   switch (deref->modes) {
+   case nir_var_mem_constant: var_array = ctx->consts; break;
+   case nir_var_mem_shared: var_array = ctx->sharedvars; break;
+   case nir_var_function_temp: var_array = ctx->scratchvars; break;
+   default: unreachable("Invalid deref mode");
+   }
+   gep_indices[0] = var_array[var->data.driver_location];
 
    for (uint32_t i = 0; i < count; ++i)
       gep_indices[i + 1] = get_src_ssa(ctx, &path.path[i]->dest.ssa, 0);
@@ -4045,34 +3991,32 @@ emit_load_deref(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 }
 
 static bool
-emit_load_shared(struct ntd_context *ctx, nir_intrinsic_instr *intr)
+emit_store_deref(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 {
-   const struct dxil_value *zero, *index;
-   unsigned bit_size = nir_dest_bit_size(intr->dest);
-   unsigned align = bit_size / 8;
-
-   /* All shared mem accesses should have been lowered to scalar 32bit
-    * accesses.
-    */
-   assert(bit_size == 32);
-   assert(nir_dest_num_components(intr->dest) == 1);
-
-   zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return false;
-
-   index = get_src(ctx, &intr->src[0], 0, nir_type_uint);
-   if (!index)
+   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
+   const struct dxil_value *ptr = deref_to_gep(ctx, deref);
+   if (!ptr)
       return false;
 
-   const struct dxil_value *ops[] = { ctx->sharedvars, zero, index };
-   const struct dxil_value *ptr, *retval;
+   const struct dxil_value *value = get_src(ctx, &intr->src[1], 0, nir_get_nir_type_for_glsl_type(deref->type));
+   return dxil_emit_store(&ctx->mod, value, ptr, nir_src_bit_size(intr->src[1]) / 8, false);
+}
 
-   ptr = dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
+static bool
+emit_atomic_deref(struct ntd_context *ctx, nir_intrinsic_instr *intr)
+{
+   const struct dxil_value *ptr = deref_to_gep(ctx, nir_src_as_deref(intr->src[0]));
    if (!ptr)
       return false;
+   
+   const struct dxil_value *value = get_src(ctx, &intr->src[1], 0, nir_type_uint);
+   if (!value)
+      return false;
 
-   retval = dxil_emit_load(&ctx->mod, ptr, align, false);
+   enum dxil_rmw_op dxil_op = nir_atomic_to_dxil_rmw(nir_intrinsic_atomic_op(intr));
+   const struct dxil_value *retval = dxil_emit_atomicrmw(&ctx->mod, value, ptr, dxil_op, false,
+                                                         DXIL_ATOMIC_ORDERING_ACQREL,
+                                                         DXIL_SYNC_SCOPE_CROSSTHREAD);
    if (!retval)
       return false;
 
@@ -4081,34 +4025,20 @@ emit_load_shared(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 }
 
 static bool
-emit_load_scratch(struct ntd_context *ctx, nir_intrinsic_instr *intr)
+emit_atomic_deref_swap(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 {
-   const struct dxil_value *zero, *index;
-   unsigned bit_size = nir_dest_bit_size(intr->dest);
-   unsigned align = bit_size / 8;
-
-   /* All scratch mem accesses should have been lowered to scalar 32bit
-    * accesses.
-    */
-   assert(bit_size == 32);
-   assert(nir_dest_num_components(intr->dest) == 1);
-
-   zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return false;
-
-   index = get_src(ctx, &intr->src[0], 0, nir_type_uint);
-   if (!index)
+   const struct dxil_value *ptr = deref_to_gep(ctx, nir_src_as_deref(intr->src[0]));
+   if (!ptr)
       return false;
 
-   const struct dxil_value *ops[] = { ctx->scratchvars, zero, index };
-   const struct dxil_value *ptr, *retval;
-
-   ptr = dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
-   if (!ptr)
+   const struct dxil_value *cmp = get_src(ctx, &intr->src[1], 0, nir_type_uint);
+   const struct dxil_value *value = get_src(ctx, &intr->src[2], 0, nir_type_uint);
+   if (!value)
       return false;
 
-   retval = dxil_emit_load(&ctx->mod, ptr, align, false);
+   const struct dxil_value *retval = dxil_emit_cmpxchg(&ctx->mod, cmp, value, ptr, false,
+                                                       DXIL_ATOMIC_ORDERING_ACQREL,
+                                                       DXIL_SYNC_SCOPE_CROSSTHREAD);
    if (!retval)
       return false;
 
@@ -4572,82 +4502,6 @@ emit_ssbo_atomic_comp_swap(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 }
 
 static bool
-emit_shared_atomic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
-{
-   const struct dxil_value *zero, *index;
-
-   assert(nir_src_bit_size(intr->src[1]) == 32);
-
-   zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return false;
-
-   index = get_src(ctx, &intr->src[0], 0, nir_type_uint);
-   if (!index)
-      return false;
-
-   const struct dxil_value *ops[] = { ctx->sharedvars, zero, index };
-   const struct dxil_value *ptr, *value, *retval;
-
-   ptr = dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
-   if (!ptr)
-      return false;
-
-   nir_atomic_op nir_op = nir_intrinsic_atomic_op(intr);
-   enum dxil_rmw_op dxil_op = nir_atomic_to_dxil_rmw(nir_op);
-   nir_alu_type type = nir_atomic_op_type(nir_op);
-   value = get_src(ctx, &intr->src[1], 0, type);
-   if (!value)
-      return false;
-
-   retval = dxil_emit_atomicrmw(&ctx->mod, value, ptr, dxil_op, false,
-                                DXIL_ATOMIC_ORDERING_ACQREL,
-                                DXIL_SYNC_SCOPE_CROSSTHREAD);
-   if (!retval)
-      return false;
-
-   store_dest(ctx, &intr->dest, 0, retval);
-   return true;
-}
-
-static bool
-emit_shared_atomic_comp_swap(struct ntd_context *ctx, nir_intrinsic_instr *intr)
-{
-   const struct dxil_value *zero, *index;
-
-   assert(nir_src_bit_size(intr->src[1]) == 32);
-
-   zero = dxil_module_get_int32_const(&ctx->mod, 0);
-   if (!zero)
-      return false;
-
-   index = get_src(ctx, &intr->src[0], 0, nir_type_uint);
-   if (!index)
-      return false;
-
-   const struct dxil_value *ops[] = { ctx->sharedvars, zero, index };
-   const struct dxil_value *ptr, *cmpval, *newval, *retval;
-
-   ptr = dxil_emit_gep_inbounds(&ctx->mod, ops, ARRAY_SIZE(ops));
-   if (!ptr)
-      return false;
-
-   cmpval = get_src(ctx, &intr->src[1], 0, nir_type_uint);
-   newval = get_src(ctx, &intr->src[2], 0, nir_type_uint);
-   if (!cmpval || !newval)
-      return false;
-
-   retval = dxil_emit_cmpxchg(&ctx->mod, cmpval, newval, ptr, false,
-                              DXIL_ATOMIC_ORDERING_ACQREL,
-                              DXIL_SYNC_SCOPE_CROSSTHREAD);
-   if (!retval)
-      return false;
-
-   store_dest(ctx, &intr->dest, 0, retval);
-   return true;
-}
-
-static bool
 emit_vulkan_resource_index(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 {
    unsigned int binding = nir_intrinsic_binding(intr);
@@ -5016,13 +4870,14 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
       return emit_store_ssbo(ctx, intr);
    case nir_intrinsic_store_ssbo_masked_dxil:
       return emit_store_ssbo_masked(ctx, intr);
-   case nir_intrinsic_store_shared_dxil:
-   case nir_intrinsic_store_shared_masked_dxil:
-      return emit_store_shared(ctx, intr);
-   case nir_intrinsic_store_scratch_dxil:
-      return emit_store_scratch(ctx, intr);
    case nir_intrinsic_load_deref:
       return emit_load_deref(ctx, intr);
+   case nir_intrinsic_store_deref:
+      return emit_store_deref(ctx, intr);
+   case nir_intrinsic_deref_atomic:
+      return emit_atomic_deref(ctx, intr);
+   case nir_intrinsic_deref_atomic_swap:
+      return emit_atomic_deref_swap(ctx, intr);
    case nir_intrinsic_load_ubo:
       return emit_load_ubo(ctx, intr);
    case nir_intrinsic_load_ubo_dxil:
@@ -5052,10 +4907,6 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
       return emit_load_sample_mask_in(ctx, intr);
    case nir_intrinsic_load_tess_coord:
       return emit_load_tess_coord(ctx, intr);
-   case nir_intrinsic_load_shared_dxil:
-      return emit_load_shared(ctx, intr);
-   case nir_intrinsic_load_scratch_dxil:
-      return emit_load_scratch(ctx, intr);
    case nir_intrinsic_discard_if:
    case nir_intrinsic_demote_if:
       return emit_discard_if(ctx, intr);
@@ -5072,10 +4923,6 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
       return emit_ssbo_atomic(ctx, intr);
    case nir_intrinsic_ssbo_atomic_swap:
       return emit_ssbo_atomic_comp_swap(ctx, intr);
-   case nir_intrinsic_shared_atomic_dxil:
-      return emit_shared_atomic(ctx, intr);
-   case nir_intrinsic_shared_atomic_swap_dxil:
-      return emit_shared_atomic_comp_swap(ctx, intr);
    case nir_intrinsic_image_deref_atomic:
    case nir_intrinsic_image_atomic:
    case nir_intrinsic_bindless_image_atomic:
@@ -6080,60 +5927,30 @@ emit_cbvs(struct ntd_context *ctx)
 }
 
 static bool
-emit_scratch(struct ntd_context *ctx)
-{
-   if (ctx->shader->scratch_size) {
-      /*
-       * We always allocate an u32 array, no matter the actual variable types.
-       * According to the DXIL spec, the minimum load/store granularity is
-       * 32-bit, anything smaller requires using a read-extract/read-write-modify
-       * approach.
-       */
-      unsigned size = ALIGN_POT(ctx->shader->scratch_size, sizeof(uint32_t));
-      const struct dxil_type *int32 = dxil_module_get_int_type(&ctx->mod, 32);
-      const struct dxil_value *array_length = dxil_module_get_int32_const(&ctx->mod, size / sizeof(uint32_t));
-      if (!int32 || !array_length)
-         return false;
+emit_scratch(struct ntd_context *ctx, nir_function_impl *impl)
+{
+   uint32_t index = 0;
+   nir_foreach_function_temp_variable(var, impl)
+      var->data.driver_location = index++;
 
-      const struct dxil_type *type = dxil_module_get_array_type(
-         &ctx->mod, int32, size / sizeof(uint32_t));
-      if (!type)
-         return false;
+   if (ctx->scratchvars)
+      ralloc_free((void *)ctx->scratchvars);
+
+   ctx->scratchvars = ralloc_array(ctx->ralloc_ctx, const struct dxil_value *, index);
 
-      ctx->scratchvars = dxil_emit_alloca(&ctx->mod, type, int32, array_length, 4);
-      if (!ctx->scratchvars)
+   nir_foreach_function_temp_variable(var, impl) {
+      const struct dxil_type *type = get_type_for_glsl_type(&ctx->mod, var->type);
+      const struct dxil_value *length = dxil_module_get_int32_const(&ctx->mod, 1);
+      const struct dxil_value *ptr = dxil_emit_alloca(&ctx->mod, type, length, 16);
+      if (!ptr)
          return false;
+
+      ctx->scratchvars[var->data.driver_location] = ptr;
    }
 
    return true;
 }
 
-/* The validator complains if we don't have ops that reference a global variable. */
-static bool
-shader_has_shared_ops(struct nir_shader *s)
-{
-   nir_foreach_function(func, s) {
-      if (!func->impl)
-         continue;
-      nir_foreach_block(block, func->impl) {
-         nir_foreach_instr(instr, block) {
-            if (instr->type != nir_instr_type_intrinsic)
-               continue;
-            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-            switch (intrin->intrinsic) {
-            case nir_intrinsic_load_shared_dxil:
-            case nir_intrinsic_store_shared_dxil:
-            case nir_intrinsic_shared_atomic_dxil:
-            case nir_intrinsic_shared_atomic_swap_dxil:
-               return true;
-            default: break;
-            }
-         }
-      }
-   }
-   return false;
-}
-
 static bool
 emit_function(struct ntd_context *ctx, nir_function *func)
 {
@@ -6178,7 +5995,7 @@ emit_function(struct ntd_context *ctx, nir_function *func)
 
    nir_gather_ssa_types(impl, ctx->float_types, ctx->int_types);
 
-   if (!emit_scratch(ctx))
+   if (!emit_scratch(ctx, impl))
       return false;
 
    if (!emit_static_indexing_handles(ctx))
@@ -6243,37 +6060,16 @@ emit_module(struct ntd_context *ctx, const struct nir_to_dxil_options *opts)
       }
    }
 
-   if (ctx->shader->info.shared_size && shader_has_shared_ops(ctx->shader)) {
-      const struct dxil_type *type;
-      unsigned size;
-
-     /*
-      * We always allocate an u32 array, no matter the actual variable types.
-      * According to the DXIL spec, the minimum load/store granularity is
-      * 32-bit, anything smaller requires using a read-extract/read-write-modify
-      * approach. Non-atomic 64-bit accesses are allowed, but the
-      * GEP(cast(gvar, u64[] *), offset) and cast(GEP(gvar, offset), u64 *))
-      * sequences don't seem to be accepted by the DXIL validator when the
-      * pointer is in the groupshared address space, making the 32-bit -> 64-bit
-      * pointer cast impossible.
-      */
-      size = ALIGN_POT(ctx->shader->info.shared_size, sizeof(uint32_t));
-      type = dxil_module_get_array_type(&ctx->mod,
-                                        dxil_module_get_int_type(&ctx->mod, 32),
-                                        size / sizeof(uint32_t));
-      ctx->sharedvars = dxil_add_global_ptr_var(&ctx->mod, "shared", type,
-                                                DXIL_AS_GROUPSHARED,
-                                                ffs(sizeof(uint64_t)),
-                                                NULL);
-   }
+   if (!emit_shared_vars(ctx))
+      return false;
+   if (!emit_global_consts(ctx))
+      return false;
 
    /* UAVs */
    if (ctx->shader->info.stage == MESA_SHADER_KERNEL) {
       if (!emit_globals(ctx, opts->num_kernel_globals))
          return false;
 
-      if (!emit_global_consts(ctx))
-         return false;
    } else if (ctx->opts->environment == DXIL_ENVIRONMENT_VULKAN) {
       /* Handle read/write SSBOs as UAVs */
       nir_foreach_variable_with_modes(var, ctx->shader, nir_var_mem_ssbo) {
@@ -6435,7 +6231,7 @@ optimize_nir(struct nir_shader *s, const struct nir_to_dxil_options *opts)
    do {
       progress = false;
       NIR_PASS_V(s, nir_lower_vars_to_ssa);
-      NIR_PASS(progress, s, nir_lower_indirect_derefs, nir_var_function_temp, UINT32_MAX);
+      NIR_PASS(progress, s, nir_lower_indirect_derefs, nir_var_function_temp, 4);
       NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
       NIR_PASS(progress, s, nir_copy_prop);
       NIR_PASS(progress, s, nir_opt_copy_prop_vars);
@@ -6751,7 +6547,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
    optimize_nir(s, opts);
 
    NIR_PASS_V(s, nir_remove_dead_variables,
-              nir_var_function_temp | nir_var_mem_constant, NULL);
+              nir_var_function_temp | nir_var_mem_constant | nir_var_mem_shared, NULL);
 
    if (!allocate_sysvalues(ctx))
       return false;
diff --git a/src/microsoft/spirv_to_dxil/dxil_spirv_nir.c b/src/microsoft/spirv_to_dxil/dxil_spirv_nir.c
index 0f9beaa..90762a6 100644
@@ -1075,7 +1075,6 @@ dxil_spirv_nir_passes(nir_shader *nir,
    NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_shared,
       nir_address_format_32bit_offset);
 
-   NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
    NIR_PASS_V(nir, dxil_nir_lower_int_cubemaps, false);
 
    NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);