ac/nir: implement sparse image/texture loads
authorRhys Perry <pendingchaos02@gmail.com>
Mon, 23 Nov 2020 15:02:28 +0000 (15:02 +0000)
committerMarge Bot <eric+marge@anholt.net>
Fri, 8 Jan 2021 14:27:07 +0000 (14:27 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7775>

src/amd/llvm/ac_llvm_build.c
src/amd/llvm/ac_llvm_build.h
src/amd/llvm/ac_nir_to_llvm.c
src/gallium/drivers/radeonsi/si_compute_prim_discard.c
src/gallium/drivers/radeonsi/si_shader_llvm_vs.c

index 93d408b..4d8c5da 100644 (file)
@@ -325,8 +325,27 @@ void ac_build_type_name_for_intr(LLVMTypeRef type, char *buf, unsigned bufsize)
 {
    LLVMTypeRef elem_type = type;
 
-   assert(bufsize >= 8);
+   if (LLVMGetTypeKind(type) == LLVMStructTypeKind) {
+      unsigned count = LLVMCountStructElementTypes(type);
+      int ret = snprintf(buf, bufsize, "sl_");
+      buf += ret;
+      bufsize -= ret;
+
+      LLVMTypeRef *elems = alloca(count * sizeof(LLVMTypeRef));
+      LLVMGetStructElementTypes(type, elems);
 
+      for (unsigned i = 0; i < count; i++) {
+         ac_build_type_name_for_intr(elems[i], buf, bufsize);
+         ret = strlen(buf);
+         buf += ret;
+         bufsize -= ret;
+      }
+
+      snprintf(buf, bufsize, "s");
+      return;
+   }
+
+   assert(bufsize >= 8);
    if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
       int ret = snprintf(buf, bufsize, "v%u", LLVMGetVectorSize(type));
       if (ret < 0) {
@@ -566,11 +585,25 @@ LLVMValueRef ac_build_gather_values(struct ac_llvm_context *ctx, LLVMValueRef *v
    return ac_build_gather_values_extended(ctx, values, value_count, 1, false, false);
 }
 
+LLVMValueRef ac_build_concat(struct ac_llvm_context *ctx, LLVMValueRef a, LLVMValueRef b)
+{
+   unsigned a_size = ac_get_llvm_num_components(a);
+   unsigned b_size = ac_get_llvm_num_components(b);
+
+   LLVMValueRef *elems = alloca((a_size + b_size) * sizeof(LLVMValueRef));
+   for (unsigned i = 0; i < a_size; i++)
+      elems[i] = ac_llvm_extract_elem(ctx, a, i);
+   for (unsigned i = 0; i < b_size; i++)
+      elems[a_size + i] = ac_llvm_extract_elem(ctx, b, i);
+
+   return ac_build_gather_values(ctx, elems, a_size + b_size);
+}
+
 /* Expand a scalar or vector to <dst_channels x type> by filling the remaining
  * channels with undef. Extract at most src_channels components from the input.
  */
-static LLVMValueRef ac_build_expand(struct ac_llvm_context *ctx, LLVMValueRef value,
-                                    unsigned src_channels, unsigned dst_channels)
+LLVMValueRef ac_build_expand(struct ac_llvm_context *ctx, LLVMValueRef value,
+                             unsigned src_channels, unsigned dst_channels)
 {
    LLVMTypeRef elemtype;
    LLVMValueRef *const chan = alloca(dst_channels * sizeof(LLVMValueRef));
@@ -1231,8 +1264,42 @@ LLVMValueRef ac_build_buffer_load(struct ac_llvm_context *ctx, LLVMValueRef rsrc
 LLVMValueRef ac_build_buffer_load_format(struct ac_llvm_context *ctx, LLVMValueRef rsrc,
                                          LLVMValueRef vindex, LLVMValueRef voffset,
                                          unsigned num_channels, unsigned cache_policy,
-                                         bool can_speculate, bool d16)
+                                         bool can_speculate, bool d16, bool tfe)
 {
+   if (tfe) {
+      assert(!d16);
+
+      char code[256];
+      /* The definition in the assembly and the one in the constraint string
+       * differs because of an assembler bug.
+       */
+      snprintf(code, sizeof(code),
+               "v_mov_b32 v0, 0\n"
+               "v_mov_b32 v1, 0\n"
+               "v_mov_b32 v2, 0\n"
+               "v_mov_b32 v3, 0\n"
+               "v_mov_b32 v4, 0\n"
+               "buffer_load_format_xyzw v[0:3], $1, $2, 0, idxen offen %s %s tfe %s\n"
+               "s_waitcnt vmcnt(0)",
+               cache_policy & ac_glc ? "glc" : "",
+               cache_policy & ac_slc ? "slc" : "",
+               cache_policy & ac_dlc ? "dlc" : "");
+
+      LLVMTypeRef param_types[] = {ctx->v2i32, ctx->v4i32};
+      LLVMTypeRef calltype = LLVMFunctionType(LLVMVectorType(ctx->f32, 5), param_types, 2, false);
+      LLVMValueRef inlineasm = LLVMConstInlineAsm(calltype, code, "=&{v[0:4]},v,s", false, false);
+
+      LLVMValueRef addr_comp[2] = {vindex ? vindex : ctx->i32_0,
+                                   voffset ? voffset : ctx->i32_0};
+
+      LLVMValueRef args[] = {ac_build_gather_values(ctx, addr_comp, 2),
+                             LLVMBuildBitCast(ctx->builder, rsrc, ctx->v4i32, "")};
+      LLVMValueRef res = LLVMBuildCall(ctx->builder, inlineasm, args, 2, "");
+
+      return ac_build_concat(ctx, ac_trim_vector(ctx, res, num_channels),
+                             ac_llvm_extract_elem(ctx, res, 4));
+   }
+
    return ac_build_buffer_load_common(ctx, rsrc, vindex, voffset, ctx->i32_0, num_channels,
                                       d16 ? ctx->f16 : ctx->f32, cache_policy, can_speculate, true,
                                       true);
@@ -2120,7 +2187,7 @@ LLVMValueRef ac_build_image_opcode(struct ac_llvm_context *ctx, struct ac_image_
    LLVMTypeRef coord_type = sample ? ctx->f32 : ctx->i32;
    uint8_t dmask = a->dmask;
    LLVMTypeRef data_type;
-   char data_type_str[8];
+   char data_type_str[32];
 
    if (atomic) {
       data_type = LLVMTypeOf(a->data[0]);
@@ -2132,6 +2199,11 @@ LLVMValueRef ac_build_image_opcode(struct ac_llvm_context *ctx, struct ac_image_
       data_type = a->d16 ? ctx->v4f16 : ctx->v4f32;
    }
 
+   if (a->tfe) {
+      data_type = LLVMStructTypeInContext(
+         ctx->context, (LLVMTypeRef[]){data_type, ctx->i32}, 2, false);
+   }
+
    if (atomic || a->opcode == ac_image_store || a->opcode == ac_image_store_mip) {
       args[num_args++] = a->data[0];
       if (a->opcode == ac_image_atomic_cmpswap)
@@ -2171,7 +2243,7 @@ LLVMValueRef ac_build_image_opcode(struct ac_llvm_context *ctx, struct ac_image_
       args[num_args++] = LLVMConstInt(ctx->i1, a->unorm, false);
    }
 
-   args[num_args++] = ctx->i32_0; /* texfailctrl */
+   args[num_args++] = a->tfe ? ctx->i32_1 : ctx->i32_0; /* texfailctrl */
    args[num_args++] = LLVMConstInt(
       ctx->i32, load ? get_load_cache_policy(ctx, a->cache_policy) : a->cache_policy, false);
 
@@ -2258,14 +2330,18 @@ LLVMValueRef ac_build_image_opcode(struct ac_llvm_context *ctx, struct ac_image_
             data_type_str, overload[0], overload[1], overload[2]);
 
    LLVMTypeRef retty;
-   if (atomic)
-      retty = data_type;
-   else if (a->opcode == ac_image_store || a->opcode == ac_image_store_mip)
+   if (a->opcode == ac_image_store || a->opcode == ac_image_store_mip)
       retty = ctx->voidt;
    else
-      retty = a->d16 ? ctx->v4f16 : ctx->v4f32;
+      retty = data_type;
 
    LLVMValueRef result = ac_build_intrinsic(ctx, intr_name, retty, args, num_args, a->attributes);
+   if (a->tfe) {
+      LLVMValueRef texel = LLVMBuildExtractValue(ctx->builder, result, 0, "");
+      LLVMValueRef code = LLVMBuildExtractValue(ctx->builder, result, 1, "");
+      result = ac_build_concat(ctx, texel, ac_to_float(ctx, code));
+   }
+
    if (!sample && !atomic && retty != ctx->voidt)
       result = ac_to_integer(ctx, result);
 
index 8c84b38..5a4a61a 100644 (file)
@@ -195,9 +195,14 @@ LLVMValueRef ac_build_gather_values_extended(struct ac_llvm_context *ctx, LLVMVa
 LLVMValueRef ac_build_gather_values(struct ac_llvm_context *ctx, LLVMValueRef *values,
                                     unsigned value_count);
 
+LLVMValueRef ac_build_concat(struct ac_llvm_context *ctx, LLVMValueRef a, LLVMValueRef b);
+
 LLVMValueRef ac_extract_components(struct ac_llvm_context *ctx, LLVMValueRef value, unsigned start,
                                    unsigned channels);
 
+LLVMValueRef ac_build_expand(struct ac_llvm_context *ctx, LLVMValueRef value,
+                             unsigned src_channels, unsigned dst_channels);
+
 LLVMValueRef ac_build_expand_to_vec4(struct ac_llvm_context *ctx, LLVMValueRef value,
                                      unsigned num_channels);
 LLVMValueRef ac_build_round(struct ac_llvm_context *ctx, LLVMValueRef value);
@@ -261,7 +266,7 @@ LLVMValueRef ac_build_buffer_load(struct ac_llvm_context *ctx, LLVMValueRef rsrc
 LLVMValueRef ac_build_buffer_load_format(struct ac_llvm_context *ctx, LLVMValueRef rsrc,
                                          LLVMValueRef vindex, LLVMValueRef voffset,
                                          unsigned num_channels, unsigned cache_policy,
-                                         bool can_speculate, bool d16);
+                                         bool can_speculate, bool d16, bool tfe);
 
 LLVMValueRef ac_build_tbuffer_load_short(struct ac_llvm_context *ctx, LLVMValueRef rsrc,
                                          LLVMValueRef voffset, LLVMValueRef soffset,
@@ -399,6 +404,7 @@ struct ac_image_args {
    bool unorm : 1;
    bool level_zero : 1;
    bool d16 : 1;        /* data and return values are 16-bit, requires GFX8+ */
+   bool tfe : 1;
    unsigned attributes; /* additional call-site specific AC_FUNC_ATTRs */
 
    LLVMValueRef resource;
index 0555fe6..b6c768f 100644 (file)
@@ -1424,13 +1424,16 @@ static nir_deref_instr *get_tex_texture_deref(const nir_tex_instr *instr)
 static LLVMValueRef build_tex_intrinsic(struct ac_nir_context *ctx, const nir_tex_instr *instr,
                                         struct ac_image_args *args)
 {
+   assert((!args->tfe || !args->d16) && "unsupported");
+
    if (instr->sampler_dim == GLSL_SAMPLER_DIM_BUF) {
       unsigned mask = nir_ssa_def_components_read(&instr->dest.ssa);
 
       assert(instr->dest.is_ssa);
       return ac_build_buffer_load_format(&ctx->ac, args->resource, args->coords[0], ctx->ac.i32_0,
                                          util_last_bit(mask), 0, true,
-                                         instr->dest.ssa.bit_size == 16);
+                                         instr->dest.ssa.bit_size == 16,
+                                         args->tfe);
    }
 
    args->opcode = ac_image_sample;
@@ -2298,7 +2301,9 @@ static void get_image_coords(struct ac_nir_context *ctx, const nir_intrinsic_ins
    count = image_type_to_components_count(dim, is_array);
 
    if (is_ms && (instr->intrinsic == nir_intrinsic_image_deref_load ||
-                 instr->intrinsic == nir_intrinsic_bindless_image_load)) {
+                 instr->intrinsic == nir_intrinsic_bindless_image_load ||
+                 instr->intrinsic == nir_intrinsic_image_deref_sparse_load ||
+                 instr->intrinsic == nir_intrinsic_bindless_image_sparse_load)) {
       LLVMValueRef fmask_load_address[3];
 
       fmask_load_address[0] = LLVMBuildExtractElement(ctx->ac.builder, src0, masks[0], "");
@@ -2420,6 +2425,7 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx, const nir_intri
    struct ac_image_args args = {0};
 
    args.cache_policy = get_cache_policy(ctx, access, false, false);
+   args.tfe = instr->intrinsic == nir_intrinsic_image_deref_sparse_load;
 
    if (dim == GLSL_SAMPLER_DIM_BUF) {
       unsigned num_channels = util_last_bit(nir_ssa_def_components_read(&instr->dest.ssa));
@@ -2435,8 +2441,9 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx, const nir_intri
       bool can_speculate = access & ACCESS_CAN_REORDER;
       res = ac_build_buffer_load_format(&ctx->ac, rsrc, vindex, ctx->ac.i32_0, num_channels,
                                         args.cache_policy, can_speculate,
-                                        instr->dest.ssa.bit_size == 16);
-      res = ac_build_expand_to_vec4(&ctx->ac, res, num_channels);
+                                        instr->dest.ssa.bit_size == 16,
+                                        args.tfe);
+      res = ac_build_expand(&ctx->ac, res, num_channels, args.tfe ? 5 : 4);
 
       res = ac_trim_vector(&ctx->ac, res, instr->dest.ssa.num_components);
       res = ac_to_integer(&ctx->ac, res);
@@ -2459,12 +2466,20 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx, const nir_intri
    }
 
    if (instr->dest.ssa.bit_size == 64) {
+      LLVMValueRef code = NULL;
+      if (args.tfe) {
+         code = ac_llvm_extract_elem(&ctx->ac, res, 4);
+         res = ac_trim_vector(&ctx->ac, res, 4);
+      }
+
       res = LLVMBuildBitCast(ctx->ac.builder, res, LLVMVectorType(ctx->ac.i64, 2), "");
       LLVMValueRef x = LLVMBuildExtractElement(ctx->ac.builder, res, ctx->ac.i32_0, "");
       LLVMValueRef w = LLVMBuildExtractElement(ctx->ac.builder, res, ctx->ac.i32_1, "");
 
-      LLVMValueRef values[4] = {x, ctx->ac.i64_0, ctx->ac.i64_0, w};
-      res = ac_build_gather_values(&ctx->ac, values, 4);
+      if (code)
+         code = LLVMBuildZExt(ctx->ac.builder, code, ctx->ac.i64, "");
+      LLVMValueRef values[5] = {x, ctx->ac.i64_0, ctx->ac.i64_0, w, code};
+      res = ac_build_gather_values(&ctx->ac, values, 4 + args.tfe);
    }
 
    return exit_waterfall(ctx, &wctx, res);
@@ -3583,6 +3598,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
       result = visit_image_load(ctx, instr, true);
       break;
    case nir_intrinsic_image_deref_load:
+   case nir_intrinsic_image_deref_sparse_load:
       result = visit_image_load(ctx, instr, false);
       break;
    case nir_intrinsic_bindless_image_store:
@@ -4441,9 +4457,16 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
 
    assert(instr->dest.is_ssa);
    args.d16 = instr->dest.ssa.bit_size == 16;
+   args.tfe = instr->is_sparse;
 
    result = build_tex_intrinsic(ctx, instr, &args);
 
+   LLVMValueRef code = NULL;
+   if (instr->is_sparse) {
+      code = ac_llvm_extract_elem(&ctx->ac, result, 4);
+      result = ac_trim_vector(&ctx->ac, result, 4);
+   }
+
    if (instr->op == nir_texop_query_levels)
       result =
          LLVMBuildExtractElement(ctx->ac.builder, result, LLVMConstInt(ctx->ac.i32, 3, false), "");
@@ -4462,9 +4485,12 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
       LLVMValueRef two = LLVMConstInt(ctx->ac.i32, 2, false);
       LLVMValueRef layers = LLVMBuildExtractElement(ctx->ac.builder, result, two, "");
       result = LLVMBuildInsertElement(ctx->ac.builder, result, layers, ctx->ac.i32_1, "");
-   } else if (instr->dest.ssa.num_components != 4)
+   } else if (nir_tex_instr_result_size(instr) != 4)
       result = ac_trim_vector(&ctx->ac, result, instr->dest.ssa.num_components);
 
+   if (instr->is_sparse)
+      result = ac_build_concat(&ctx->ac, result, code);
+
 write_result:
    if (result) {
       assert(instr->dest.is_ssa);
index 10e4668..4c94f2c 100644 (file)
@@ -460,7 +460,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
    if (key->opt.cs_indexed) {
       for (unsigned i = 0; i < 3; i++) {
          index[i] = ac_build_buffer_load_format(&ctx->ac, input_indexbuf, index[i], ctx->ac.i32_0,
-                                                1, 0, true, false);
+                                                1, 0, true, false, false);
          index[i] = ac_to_integer(&ctx->ac, index[i]);
       }
    }
index c280b58..19c011f 100644 (file)
@@ -158,7 +158,7 @@ static void load_input_vs(struct si_shader_context *ctx, unsigned input_index, L
    for (unsigned i = 0; i < num_fetches; ++i) {
       LLVMValueRef voffset = LLVMConstInt(ctx->ac.i32, fetch_stride * i, 0);
       fetches[i] = ac_build_buffer_load_format(&ctx->ac, vb_desc, vertex_index, voffset,
-                                               channels_per_fetch, 0, true, false);
+                                               channels_per_fetch, 0, true, false, false);
    }
 
    if (num_fetches == 1 && channels_per_fetch > 1) {