zink: implement shared load/store nir ops in ntv
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Wed, 12 Aug 2020 19:57:17 +0000 (15:57 -0400)
committerMarge Bot <eric+marge@anholt.net>
Wed, 10 Feb 2021 00:19:38 +0000 (00:19 +0000)
these access the data in the shared block variable at an offset

Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/8781>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index d1bb8ba..dab6dc7 100644 (file)
@@ -2082,6 +2082,73 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    spirv_builder_emit_store(&ctx->builder, ptr, result);
 }
 
+static void
+emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
+{
+   SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint);
+   unsigned num_components = nir_dest_num_components(intr->dest);
+   unsigned bit_size = nir_dest_bit_size(intr->dest);
+   bool qword = bit_size == 64;
+   SpvId uint_type = get_uvec_type(ctx, 32, 1);
+   SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
+                                               SpvStorageClassWorkgroup,
+                                               uint_type);
+   SpvId offset = emit_binop(ctx, SpvOpUDiv, uint_type, get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, 4));
+   SpvId constituents[num_components];
+   /* need to convert array -> vec */
+   for (unsigned i = 0; i < num_components; i++) {
+      SpvId parts[2];
+      for (unsigned j = 0; j < 1 + !!qword; j++) {
+         SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type,
+                                                        ctx->shared_block_var, &offset, 1);
+         parts[j] = spirv_builder_emit_load(&ctx->builder, uint_type, member);
+         offset = emit_binop(ctx, SpvOpIAdd, uint_type, offset, emit_uint_const(ctx, 32, 1));
+      }
+      if (qword)
+         constituents[i] = spirv_builder_emit_composite_construct(&ctx->builder, get_uvec_type(ctx, 64, 1), parts, 2);
+      else
+         constituents[i] = parts[0];
+   }
+   SpvId result;
+   if (num_components > 1)
+      result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type, constituents, num_components);
+   else
+      result = bitcast_to_uvec(ctx, constituents[0], bit_size, num_components);
+   store_dest(ctx, &intr->dest, result, nir_type_uint);
+}
+
+static void
+emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
+{
+   SpvId src = get_src(ctx, &intr->src[0]);
+   bool qword = nir_src_bit_size(intr->src[0]) == 64;
+
+   unsigned num_writes = util_bitcount(nir_intrinsic_write_mask(intr));
+   unsigned wrmask = nir_intrinsic_write_mask(intr);
+   /* this is a partial write, so we have to loop and do a per-component write */
+   SpvId uint_type = get_uvec_type(ctx, 32, 1);
+   SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
+                                               SpvStorageClassWorkgroup,
+                                               uint_type);
+   SpvId offset = emit_binop(ctx, SpvOpUDiv, uint_type, get_src(ctx, &intr->src[1]), emit_uint_const(ctx, 32, 4));
+
+   for (unsigned i = 0; num_writes; i++) {
+      if ((wrmask >> i) & 1) {
+         for (unsigned j = 0; j < 1 + !!qword; j++) {
+            unsigned comp = ((1 + !!qword) * i) + j;
+            SpvId shared_offset = emit_binop(ctx, SpvOpIAdd, uint_type, offset, emit_uint_const(ctx, 32, comp));
+            SpvId val = src;
+            if (nir_src_num_components(intr->src[0]) != 1 || qword)
+               val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &comp, 1);
+            SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type,
+                                                           ctx->shared_block_var, &shared_offset, 1);
+            spirv_builder_emit_store(&ctx->builder, member, val);
+         }
+         num_writes--;
+      }
+   }
+}
+
 /* FIXME: this is currently VERY specific to injected TCS usage */
 static void
 emit_load_push_const(struct ntv_context *ctx, nir_intrinsic_instr *intr)
@@ -2690,6 +2757,14 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       emit_load_uint_input(ctx, intr, &ctx->local_invocation_index_var, "gl_LocalInvocationIndex", SpvBuiltInLocalInvocationIndex);
       break;
 
+   case nir_intrinsic_load_shared:
+      emit_load_shared(ctx, intr);
+      break;
+
+   case nir_intrinsic_store_shared:
+      emit_store_shared(ctx, intr);
+      break;
+
    default:
       fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
               nir_intrinsic_infos[intr->intrinsic].name);