gallivm: Switch to reg intrinsics
author Alyssa Rosenzweig <alyssa@rosenzweig.io>
Fri, 9 Jun 2023 15:03:09 +0000 (11:03 -0400)
committer Marge Bot <emma+marge@anholt.net>
Wed, 12 Jul 2023 01:34:27 +0000 (01:34 +0000)
This is pretty straightforward, since we don't try to "coalesce" register access
the way a GPU backend would. In the old path, we generated register load/store
instructions internally when hitting register sources/destinations. In the new
path, we just translate the register load/store intrinsics to the LLVM
loads/stores and we're back where we started. It's a bit more code, but it's
more straightforward.

Notably, although this continues to use registers, this does NOT use the chasing
helpers.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Dave Airlie <airlied@redhat.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23089>

src/gallium/auxiliary/gallivm/lp_bld_nir.c
src/gallium/auxiliary/gallivm/lp_bld_nir.h
src/gallium/auxiliary/gallivm/lp_bld_nir_aos.c
src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c

index e276410..33311bf 100644 (file)
@@ -36,6 +36,7 @@
 #include "lp_bld_struct.h"
 #include "lp_bld_debug.h"
 #include "lp_bld_printf.h"
+#include "nir.h"
 #include "nir_deref.h"
 #include "nir_search_helpers.h"
 
@@ -150,36 +151,10 @@ glsl_sampler_to_pipe(int sampler_dim, bool is_array)
 }
 
 
-static LLVMValueRef get_ssa_src(struct lp_build_nir_context *bld_base, nir_ssa_def *ssa)
-{
-   return bld_base->ssa_defs[ssa->index];
-}
-
-
-static LLVMValueRef
-get_src(struct lp_build_nir_context *bld_base, nir_src src);
-
-
-static LLVMValueRef
-get_reg_src(struct lp_build_nir_context *bld_base, nir_register_src src)
-{
-   struct hash_entry *entry = _mesa_hash_table_search(bld_base->regs, src.reg);
-   LLVMValueRef reg_storage = (LLVMValueRef)entry->data;
-   struct lp_build_context *reg_bld = get_int_bld(bld_base, true, src.reg->bit_size);
-   LLVMValueRef indir_src = NULL;
-   if (src.indirect)
-      indir_src = get_src(bld_base, *src.indirect);
-   return bld_base->load_reg(bld_base, reg_bld, &src, indir_src, reg_storage);
-}
-
-
 static LLVMValueRef
 get_src(struct lp_build_nir_context *bld_base, nir_src src)
 {
-   if (src.is_ssa)
-      return get_ssa_src(bld_base, src.ssa);
-   else
-      return get_reg_src(bld_base, src.reg);
+   return bld_base->ssa_defs[src.ssa->index];
 }
 
 
@@ -204,46 +179,6 @@ assign_ssa_dest(struct lp_build_nir_context *bld_base, const nir_ssa_def *ssa,
 }
 
 
-static void
-assign_reg(struct lp_build_nir_context *bld_base, const nir_register_dest *reg,
-           unsigned write_mask,
-           LLVMValueRef vals[NIR_MAX_VEC_COMPONENTS])
-{
-   assert(write_mask != 0x0);
-   struct hash_entry *entry = _mesa_hash_table_search(bld_base->regs, reg->reg);
-   LLVMValueRef reg_storage = (LLVMValueRef)entry->data;
-   struct lp_build_context *reg_bld = get_int_bld(bld_base, true, reg->reg->bit_size);
-   LLVMValueRef indir_src = NULL;
-   if (reg->indirect)
-      indir_src = get_src(bld_base, *reg->indirect);
-   bld_base->store_reg(bld_base, reg_bld, reg, write_mask,
-                       indir_src, reg_storage, vals);
-}
-
-
-static void
-assign_dest(struct lp_build_nir_context *bld_base,
-            const nir_dest *dest,
-            LLVMValueRef vals[NIR_MAX_VEC_COMPONENTS])
-{
-   if (dest->is_ssa)
-      assign_ssa_dest(bld_base, &dest->ssa, vals);
-   else
-      assign_reg(bld_base, &dest->reg, 0xf, vals);
-}
-
-
-static void
-assign_alu_dest(struct lp_build_nir_context *bld_base,
-                const nir_alu_dest *dest,
-                LLVMValueRef vals[NIR_MAX_VEC_COMPONENTS])
-{
-   if (dest->dest.is_ssa)
-      assign_ssa_dest(bld_base, &dest->dest.ssa, vals);
-   else
-      assign_reg(bld_base, &dest->dest.reg, dest->write_mask, vals);
-}
-
 static LLVMValueRef
 fcmp32(struct lp_build_nir_context *bld_base,
        enum pipe_compare_func compare,
@@ -1218,17 +1153,6 @@ visit_alu(struct lp_build_nir_context *bld_base,
       src_bit_size[i] = nir_src_bit_size(instr->src[i].src);
    }
 
-   if (instr->op == nir_op_mov &&
-       is_aos(bld_base) &&
-       !instr->dest.dest.is_ssa) {
-      for (unsigned i = 0; i < 4; i++) {
-         if (instr->dest.write_mask & (1 << i)) {
-            assign_reg(bld_base, &instr->dest.dest.reg, (1 << i), src);
-         }
-      }
-      return;
-   }
-
    LLVMValueRef result[NIR_MAX_VEC_COMPONENTS];
    if (instr->op == nir_op_vec4 ||
        instr->op == nir_op_vec3 ||
@@ -1278,7 +1202,7 @@ visit_alu(struct lp_build_nir_context *bld_base,
                                nir_dest_bit_size(instr->dest.dest));
       }
    }
-   assign_alu_dest(bld_base, &instr->dest, result);
+   assign_ssa_dest(bld_base, &instr->dest.dest.ssa, result);
 }
 
 
@@ -1416,6 +1340,79 @@ visit_store_output(struct lp_build_nir_context *bld_base,
                        bit_size, &var, mask, NULL, 0, indir_index, src);
 }
 
+
+static void
+visit_load_reg(struct lp_build_nir_context *bld_base,
+               nir_intrinsic_instr *instr,
+               LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+
+   nir_intrinsic_instr *decl = nir_reg_get_decl(instr->src[0].ssa);
+   unsigned base = nir_intrinsic_base(instr);
+
+   struct hash_entry *entry = _mesa_hash_table_search(bld_base->regs, decl);
+   LLVMValueRef reg_storage = (LLVMValueRef)entry->data;
+
+   unsigned bit_size = nir_intrinsic_bit_size(decl);
+   struct lp_build_context *reg_bld = get_int_bld(bld_base, true, bit_size);
+
+   LLVMValueRef indir_src = NULL;
+   if (instr->intrinsic == nir_intrinsic_load_reg_indirect) {
+      indir_src = cast_type(bld_base, get_src(bld_base, instr->src[1]),
+                            nir_type_uint, 32);
+   }
+
+   LLVMValueRef val = bld_base->load_reg(bld_base, reg_bld, decl, base, indir_src, reg_storage);
+
+   if (!is_aos(bld_base) && nir_dest_num_components(instr->dest) > 1) {
+      for (unsigned i = 0; i < nir_dest_num_components(instr->dest); i++)
+         result[i] = LLVMBuildExtractValue(builder, val, i, "");
+   } else {
+      result[0] = val;
+   }
+}
+
+
+static void
+visit_store_reg(struct lp_build_nir_context *bld_base,
+                nir_intrinsic_instr *instr)
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+
+   nir_intrinsic_instr *decl = nir_reg_get_decl(instr->src[1].ssa);
+   unsigned base = nir_intrinsic_base(instr);
+   unsigned write_mask = nir_intrinsic_write_mask(instr);
+   assert(write_mask != 0x0);
+
+   LLVMValueRef val = get_src(bld_base, instr->src[0]);
+   LLVMValueRef vals[NIR_MAX_VEC_COMPONENTS] = { NULL };
+   if (!is_aos(bld_base) && nir_src_num_components(instr->src[0]) > 1) {
+      for (unsigned i = 0; i < nir_src_num_components(instr->src[0]); i++)
+         vals[i] = LLVMBuildExtractValue(builder, val, i, "");
+   } else {
+      vals[0] = val;
+   }
+
+   struct hash_entry *entry = _mesa_hash_table_search(bld_base->regs, decl);
+   LLVMValueRef reg_storage = (LLVMValueRef)entry->data;
+
+   unsigned bit_size = nir_intrinsic_bit_size(decl);
+   struct lp_build_context *reg_bld = get_int_bld(bld_base, true, bit_size);
+
+   LLVMValueRef indir_src = NULL;
+   if (instr->intrinsic == nir_intrinsic_store_reg_indirect) {
+      indir_src = cast_type(bld_base, get_src(bld_base, instr->src[2]),
+                            nir_type_uint, 32);
+   }
+
+   bld_base->store_reg(bld_base, reg_bld, decl, write_mask, base,
+                       indir_src, reg_storage, vals);
+}
+
+
 static bool
 compact_array_index_oob(struct lp_build_nir_context *bld_base, nir_variable *var, const uint32_t index)
 {
@@ -2137,6 +2134,17 @@ visit_intrinsic(struct lp_build_nir_context *bld_base,
 {
    LLVMValueRef result[NIR_MAX_VEC_COMPONENTS] = {0};
    switch (instr->intrinsic) {
+   case nir_intrinsic_decl_reg:
+      /* already handled */
+      break;
+   case nir_intrinsic_load_reg:
+   case nir_intrinsic_load_reg_indirect:
+      visit_load_reg(bld_base, instr, result);
+      break;
+   case nir_intrinsic_store_reg:
+   case nir_intrinsic_store_reg_indirect:
+      visit_store_reg(bld_base, instr);
+      break;
    case nir_intrinsic_load_input:
       visit_load_input(bld_base, instr, result);
       break;
@@ -2333,7 +2341,8 @@ visit_intrinsic(struct lp_build_nir_context *bld_base,
       break;
    }
    if (result[0]) {
-      assign_dest(bld_base, &instr->dest, result);
+      assert(instr->dest.is_ssa);
+      assign_ssa_dest(bld_base, &instr->dest.ssa, result);
    }
 }
 
@@ -2379,8 +2388,9 @@ visit_txs(struct lp_build_nir_context *bld_base, nir_tex_instr *instr)
    params.resource = resource;
 
    bld_base->tex_size(bld_base, &params);
-   assign_dest(bld_base, &instr->dest,
-               &sizes_out[instr->op == nir_texop_query_levels ? 3 : 0]);
+   assert(instr->dest.is_ssa);
+   assign_ssa_dest(bld_base, &instr->dest.ssa,
+                   &sizes_out[instr->op == nir_texop_query_levels ? 3 : 0]);
 }
 
 
@@ -2686,7 +2696,8 @@ visit_tex(struct lp_build_nir_context *bld_base, nir_tex_instr *instr)
       }
    }
 
-   assign_dest(bld_base, &instr->dest, texel);
+   assert(instr->dest.is_ssa);
+   assign_ssa_dest(bld_base, &instr->dest.ssa, texel);
 }
 
 
@@ -2855,19 +2866,23 @@ handle_shader_output_decl(struct lp_build_nir_context *bld_base,
 */
 static LLVMTypeRef
 get_register_type(struct lp_build_nir_context *bld_base,
-                  nir_register *reg)
+                  nir_intrinsic_instr *reg)
 {
    if (is_aos(bld_base))
       return bld_base->base.int_vec_type;
 
+   unsigned num_array_elems = nir_intrinsic_num_array_elems(reg);
+   unsigned bit_size = nir_intrinsic_bit_size(reg);
+   unsigned num_components = nir_intrinsic_num_components(reg);
+
    struct lp_build_context *int_bld =
-      get_int_bld(bld_base, true, reg->bit_size == 1 ? 32 : reg->bit_size);
+      get_int_bld(bld_base, true, bit_size == 1 ? 32 : bit_size);
 
    LLVMTypeRef type = int_bld->vec_type;
-   if (reg->num_components > 1)
-      type = LLVMArrayType(type, reg->num_components);
-   if (reg->num_array_elems)
-      type = LLVMArrayType(type, reg->num_array_elems);
+   if (num_components > 1)
+      type = LLVMArrayType(type, num_components);
+   if (num_array_elems)
+      type = LLVMArrayType(type, num_array_elems);
 
    return type;
 }
@@ -2878,14 +2893,14 @@ bool lp_build_nir_llvm(struct lp_build_nir_context *bld_base,
 {
    struct nir_function *func;
 
-   NIR_PASS_V(nir, nir_convert_from_ssa, true, false);
-   NIR_PASS_V(nir, nir_lower_locals_to_regs, 32);
+   NIR_PASS_V(nir, nir_convert_from_ssa, true, true);
+   NIR_PASS_V(nir, nir_lower_locals_to_reg_intrinsics, 32);
    NIR_PASS_V(nir, nir_remove_dead_derefs);
    NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL);
 
    if (is_aos(bld_base)) {
       NIR_PASS_V(nir, nir_move_vec_src_uses_to_dest);
-      NIR_PASS_V(nir, nir_lower_vec_to_movs, NULL, NULL);
+      NIR_PASS_V(nir, nir_lower_vec_to_regs, NULL, NULL);
    }
 
    nir_foreach_shader_out_variable(variable, nir)
@@ -2915,7 +2930,7 @@ bool lp_build_nir_llvm(struct lp_build_nir_context *bld_base,
 
    func = (struct nir_function *)exec_list_get_head(&nir->functions);
 
-   nir_foreach_register(reg, &func->impl->registers) {
+   nir_foreach_reg_decl(reg, func->impl) {
       LLVMTypeRef type = get_register_type(bld_base, reg);
       LLVMValueRef reg_alloc = lp_build_alloca(bld_base->base.gallivm,
                                                type, "reg");
index 8de9305..e78b940 100644 (file)
@@ -161,13 +161,15 @@ struct lp_build_nir_context
 
    LLVMValueRef (*load_reg)(struct lp_build_nir_context *bld_base,
                             struct lp_build_context *reg_bld,
-                            const nir_register_src *reg,
+                            const nir_intrinsic_instr *decl,
+                            unsigned base,
                             LLVMValueRef indir_src,
                             LLVMValueRef reg_storage);
    void (*store_reg)(struct lp_build_nir_context *bld_base,
                      struct lp_build_context *reg_bld,
-                     const nir_register_dest *reg,
+                     const nir_intrinsic_instr *decl,
                      unsigned writemask,
+                     unsigned base,
                      LLVMValueRef indir_src,
                      LLVMValueRef reg_storage,
                      LLVMValueRef dst[NIR_MAX_VEC_COMPONENTS]);
index 50c4144..5216131 100644 (file)
@@ -162,11 +162,13 @@ emit_store_var(struct lp_build_nir_context *bld_base,
 static LLVMValueRef
 emit_load_reg(struct lp_build_nir_context *bld_base,
               struct lp_build_context *reg_bld,
-              const nir_register_src *reg,
+              const nir_intrinsic_instr *decl,
+              unsigned base,
               LLVMValueRef indir_src,
               LLVMValueRef reg_storage)
 {
    struct gallivm_state *gallivm = bld_base->base.gallivm;
+   assert(indir_src == NULL && "no indirects with linear path");
    return LLVMBuildLoad2(gallivm->builder, reg_bld->vec_type, reg_storage, "");
 }
 
@@ -205,14 +207,16 @@ swizzle_writemask(struct lp_build_nir_aos_context *bld,
 static void
 emit_store_reg(struct lp_build_nir_context *bld_base,
                struct lp_build_context *reg_bld,
-               const nir_register_dest *reg,
+               const nir_intrinsic_instr *decl,
                unsigned writemask,
+               unsigned base,
                LLVMValueRef indir_src,
                LLVMValueRef reg_storage,
                LLVMValueRef vals[NIR_MAX_VEC_COMPONENTS])
 {
    struct lp_build_nir_aos_context *bld = lp_nir_aos_context(bld_base);
    struct gallivm_state *gallivm = bld_base->base.gallivm;
+   assert(indir_src == NULL && "no indirects with linear path");
 
    if (writemask == 0xf) {
       LLVMBuildStore(gallivm->builder, vals[0], reg_storage);
index 1ece059..b725939 100644 (file)
@@ -827,19 +827,20 @@ static void emit_store_var(struct lp_build_nir_context *bld_base,
  */
 static LLVMValueRef reg_chan_pointer(struct lp_build_nir_context *bld_base,
                                            struct lp_build_context *reg_bld,
-                                           const nir_register *reg,
+                                           const nir_intrinsic_instr *decl,
                                            LLVMValueRef reg_storage,
                                            int array_index, int chan)
 {
    struct gallivm_state *gallivm = bld_base->base.gallivm;
-   int nc = reg->num_components;
+   int nc = nir_intrinsic_num_components(decl);
+   int num_array_elems = nir_intrinsic_num_array_elems(decl);
 
    LLVMTypeRef chan_type = reg_bld->vec_type;
    if (nc > 1)
       chan_type = LLVMArrayType(chan_type, nc);
 
-   if (reg->num_array_elems > 0) {
-      LLVMTypeRef array_type = LLVMArrayType(chan_type, reg->num_array_elems);
+   if (num_array_elems > 0) {
+      LLVMTypeRef array_type = LLVMArrayType(chan_type, num_array_elems);
       reg_storage = lp_build_array_get_ptr2(gallivm, array_type, reg_storage,
                                             lp_build_const_int32(gallivm, array_index));
    }
@@ -853,18 +854,20 @@ static LLVMValueRef reg_chan_pointer(struct lp_build_nir_context *bld_base,
 
 static LLVMValueRef emit_load_reg(struct lp_build_nir_context *bld_base,
                                   struct lp_build_context *reg_bld,
-                                  const nir_register_src *reg,
+                                  const nir_intrinsic_instr *decl,
+                                  unsigned base,
                                   LLVMValueRef indir_src,
                                   LLVMValueRef reg_storage)
 {
    struct gallivm_state *gallivm = bld_base->base.gallivm;
    LLVMBuilderRef builder = gallivm->builder;
-   int nc = reg->reg->num_components;
+   int nc = nir_intrinsic_num_components(decl);
+   int num_array_elems = nir_intrinsic_num_array_elems(decl);
    LLVMValueRef vals[NIR_MAX_VEC_COMPONENTS] = { NULL };
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
-   if (reg->indirect) {
-      LLVMValueRef indirect_val = lp_build_const_int_vec(gallivm, uint_bld->type, reg->base_offset);
-      LLVMValueRef max_index = lp_build_const_int_vec(gallivm, uint_bld->type, reg->reg->num_array_elems - 1);
+   if (indir_src != NULL) {
+      LLVMValueRef indirect_val = lp_build_const_int_vec(gallivm, uint_bld->type, base);
+      LLVMValueRef max_index = lp_build_const_int_vec(gallivm, uint_bld->type, num_array_elems - 1);
       indirect_val = LLVMBuildAdd(builder, indirect_val, indir_src, "");
       indirect_val = lp_build_min(uint_bld, indirect_val, max_index);
       reg_storage = LLVMBuildBitCast(builder, reg_storage, LLVMPointerType(reg_bld->elem_type, 0), "");
@@ -875,8 +878,8 @@ static LLVMValueRef emit_load_reg(struct lp_build_nir_context *bld_base,
    } else {
       for (unsigned i = 0; i < nc; i++) {
          vals[i] = LLVMBuildLoad2(builder, reg_bld->vec_type,
-                                  reg_chan_pointer(bld_base, reg_bld, reg->reg, reg_storage,
-                                                   reg->base_offset, i), "");
+                                  reg_chan_pointer(bld_base, reg_bld, decl, reg_storage,
+                                                   base, i), "");
       }
    }
    return nc == 1 ? vals[0] : lp_nir_array_build_gather_values(builder, vals, nc);
@@ -884,8 +887,9 @@ static LLVMValueRef emit_load_reg(struct lp_build_nir_context *bld_base,
 
 static void emit_store_reg(struct lp_build_nir_context *bld_base,
                            struct lp_build_context *reg_bld,
-                           const nir_register_dest *reg,
+                           const nir_intrinsic_instr *decl,
                            unsigned writemask,
+                           unsigned base,
                            LLVMValueRef indir_src,
                            LLVMValueRef reg_storage,
                            LLVMValueRef dst[NIR_MAX_VEC_COMPONENTS])
@@ -894,10 +898,11 @@ static void emit_store_reg(struct lp_build_nir_context *bld_base,
    struct gallivm_state *gallivm = bld_base->base.gallivm;
    LLVMBuilderRef builder = gallivm->builder;
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
-   int nc = reg->reg->num_components;
-   if (reg->indirect) {
-      LLVMValueRef indirect_val = lp_build_const_int_vec(gallivm, uint_bld->type, reg->base_offset);
-      LLVMValueRef max_index = lp_build_const_int_vec(gallivm, uint_bld->type, reg->reg->num_array_elems - 1);
+   int nc = nir_intrinsic_num_components(decl);
+   int num_array_elems = nir_intrinsic_num_array_elems(decl);
+   if (indir_src != NULL) {
+      LLVMValueRef indirect_val = lp_build_const_int_vec(gallivm, uint_bld->type, base);
+      LLVMValueRef max_index = lp_build_const_int_vec(gallivm, uint_bld->type, num_array_elems - 1);
       indirect_val = LLVMBuildAdd(builder, indirect_val, indir_src, "");
       indirect_val = lp_build_min(uint_bld, indirect_val, max_index);
       reg_storage = LLVMBuildBitCast(builder, reg_storage, LLVMPointerType(reg_bld->elem_type, 0), "");
@@ -916,8 +921,8 @@ static void emit_store_reg(struct lp_build_nir_context *bld_base,
          continue;
       dst[i] = LLVMBuildBitCast(builder, dst[i], reg_bld->vec_type, "");
       lp_exec_mask_store(&bld->exec_mask, reg_bld, dst[i],
-                         reg_chan_pointer(bld_base, reg_bld, reg->reg, reg_storage,
-                                          reg->base_offset, i));
+                         reg_chan_pointer(bld_base, reg_bld, decl, reg_storage,
+                                          base, i));
    }
 }