gallivm/nir: add support for mesh shader outputs.
author Dave Airlie <airlied@redhat.com>
Wed, 17 May 2023 01:43:23 +0000 (11:43 +1000)
committer Dave Airlie <airlied@redhat.com>
Mon, 5 Jun 2023 19:01:47 +0000 (05:01 +1000)
Mesh shaders can have vertex and primitive outputs, and act a bit
like TCS shaders. Add the callback to allow the driver to decide
how to store these.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23066>

src/gallium/auxiliary/gallivm/lp_bld_nir.c
src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
src/gallium/auxiliary/gallivm/lp_bld_tgsi.h

index 9321a13..3432c3d 100644 (file)
@@ -1494,8 +1494,10 @@ visit_store_var(struct lp_build_nir_context *bld_base,
    if (var) {
       bool tcs_out = bld_base->shader->info.stage == MESA_SHADER_TESS_CTRL &&
          var->data.mode == nir_var_shader_out && !var->data.patch;
+      bool mesh_out = bld_base->shader->info.stage == MESA_SHADER_MESH &&
+         var->data.mode == nir_var_shader_out;
       get_deref_offset(bld_base, deref, false, NULL,
-                       tcs_out ? &indir_vertex_index : NULL,
+                       (tcs_out || mesh_out) ? &indir_vertex_index : NULL,
                        &const_index, &indir_index);
 
       /* Skip stores definitely outside of the array bounds
index 4f2b5f3..10c74a8 100644 (file)
@@ -650,30 +650,126 @@ static void emit_store_tcs_chan(struct lp_build_nir_context *bld_base,
       LLVMValueRef split_vals[2];
       LLVMValueRef swizzle_index_val2 = lp_build_const_int32(gallivm, swizzle + 1);
       emit_store_64bit_split(bld_base, chan_val, split_vals);
-      bld->tcs_iface->emit_store_output(bld->tcs_iface, &bld_base->base, 0,
-                                        indir_vertex_index ? true : false,
-                                        indir_vertex_index,
-                                        indir_index ? true : false,
-                                        attrib_index_val,
-                                        false, swizzle_index_val,
-                                        split_vals[0], exec_mask);
-      bld->tcs_iface->emit_store_output(bld->tcs_iface, &bld_base->base, 0,
-                                        indir_vertex_index ? true : false,
-                                        indir_vertex_index,
-                                        indir_index ? true : false,
-                                        attrib_index_val,
-                                        false, swizzle_index_val2,
-                                        split_vals[1], exec_mask);
+      if (bld->mesh_iface) {
+         bld->mesh_iface->emit_store_output(bld->mesh_iface, &bld_base->base, 0,
+                                           indir_vertex_index ? true : false,
+                                           indir_vertex_index,
+                                           indir_index ? true : false,
+                                           attrib_index_val,
+                                           false, swizzle_index_val,
+                                           split_vals[0], exec_mask);
+         bld->mesh_iface->emit_store_output(bld->mesh_iface, &bld_base->base, 0,
+                                           indir_vertex_index ? true : false,
+                                           indir_vertex_index,
+                                           indir_index ? true : false,
+                                           attrib_index_val,
+                                           false, swizzle_index_val2,
+                                           split_vals[1], exec_mask);
+      } else {
+         bld->tcs_iface->emit_store_output(bld->tcs_iface, &bld_base->base, 0,
+                                           indir_vertex_index ? true : false,
+                                           indir_vertex_index,
+                                           indir_index ? true : false,
+                                           attrib_index_val,
+                                           false, swizzle_index_val,
+                                           split_vals[0], exec_mask);
+         bld->tcs_iface->emit_store_output(bld->tcs_iface, &bld_base->base, 0,
+                                           indir_vertex_index ? true : false,
+                                           indir_vertex_index,
+                                           indir_index ? true : false,
+                                           attrib_index_val,
+                                           false, swizzle_index_val2,
+                                           split_vals[1], exec_mask);
+      }
+   } else {
+      chan_val = LLVMBuildBitCast(builder, chan_val, bld_base->base.vec_type, "");
+      if (bld->mesh_iface) {
+         bld->mesh_iface->emit_store_output(bld->mesh_iface, &bld_base->base, 0,
+                                           indir_vertex_index ? true : false,
+                                           indir_vertex_index,
+                                           indir_index && !is_compact ? true : false,
+                                           attrib_index_val,
+                                           indir_index && is_compact ? true : false,
+                                           swizzle_index_val,
+                                           chan_val, exec_mask);
+      } else {
+         bld->tcs_iface->emit_store_output(bld->tcs_iface, &bld_base->base, 0,
+                                           indir_vertex_index ? true : false,
+                                           indir_vertex_index,
+                                           indir_index && !is_compact ? true : false,
+                                           attrib_index_val,
+                                           indir_index && is_compact ? true : false,
+                                           swizzle_index_val,
+                                           chan_val, exec_mask);
+      }
+   }
+}
+
+static void emit_store_mesh_chan(struct lp_build_nir_context *bld_base,
+                                 bool is_compact,
+                                 unsigned bit_size,
+                                 unsigned location,
+                                 unsigned const_index,
+                                 LLVMValueRef indir_vertex_index,
+                                 LLVMValueRef indir_index,
+                                 unsigned comp,
+                                 unsigned chan,
+                                 LLVMValueRef chan_val)
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
+   LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder;
+   unsigned swizzle = chan;
+   if (bit_size == 64) {
+      swizzle += const_index;
+      swizzle *= 2;
+      swizzle += comp;
+      if (swizzle >= 4) {
+         swizzle -= 4;
+         location++;
+      }
+   } else
+      swizzle += comp;
+   LLVMValueRef attrib_index_val;
+   LLVMValueRef swizzle_index_val = lp_build_const_int32(gallivm, swizzle);
+
+   if (indir_index) {
+      if (is_compact) {
+         swizzle_index_val = lp_build_add(&bld_base->uint_bld, indir_index, lp_build_const_int_vec(gallivm, bld_base->uint_bld.type, swizzle));
+         attrib_index_val = lp_build_const_int32(gallivm, location);
+      } else
+         attrib_index_val = lp_build_add(&bld_base->uint_bld, indir_index, lp_build_const_int_vec(gallivm, bld_base->uint_bld.type, location));
+   } else
+      attrib_index_val = lp_build_const_int32(gallivm, location + const_index);
+   LLVMValueRef exec_mask = mask_vec(bld_base);
+   if (bit_size == 64) {
+      LLVMValueRef split_vals[2];
+      LLVMValueRef swizzle_index_val2 = lp_build_const_int32(gallivm, swizzle + 1);
+      emit_store_64bit_split(bld_base, chan_val, split_vals);
+      bld->mesh_iface->emit_store_output(bld->mesh_iface, &bld_base->base, 0,
+                                         indir_vertex_index ? true : false,
+                                         indir_vertex_index,
+                                         indir_index ? true : false,
+                                         attrib_index_val,
+                                         false, swizzle_index_val,
+                                         split_vals[0], exec_mask);
+      bld->mesh_iface->emit_store_output(bld->mesh_iface, &bld_base->base, 0,
+                                         indir_vertex_index ? true : false,
+                                         indir_vertex_index,
+                                         indir_index ? true : false,
+                                         attrib_index_val,
+                                         false, swizzle_index_val2,
+                                         split_vals[1], exec_mask);
    } else {
       chan_val = LLVMBuildBitCast(builder, chan_val, bld_base->base.vec_type, "");
-      bld->tcs_iface->emit_store_output(bld->tcs_iface, &bld_base->base, 0,
-                                        indir_vertex_index ? true : false,
-                                        indir_vertex_index,
-                                        indir_index && !is_compact ? true : false,
-                                        attrib_index_val,
-                                        indir_index && is_compact ? true : false,
-                                        swizzle_index_val,
-                                        chan_val, exec_mask);
+      bld->mesh_iface->emit_store_output(bld->mesh_iface, &bld_base->base, 0,
+                                         indir_vertex_index ? true : false,
+                                         indir_vertex_index,
+                                         indir_index && !is_compact ? true : false,
+                                         attrib_index_val,
+                                         indir_index && is_compact ? true : false,
+                                         swizzle_index_val,
+                                         chan_val, exec_mask);
    }
 }
 
@@ -710,7 +806,9 @@ static void emit_store_var(struct lp_build_nir_context *bld_base,
       for (unsigned chan = 0; chan < num_components; chan++) {
          if (writemask & (1u << chan)) {
             LLVMValueRef chan_val = (num_components == 1) ? dst : LLVMBuildExtractValue(builder, dst, chan, "");
-            if (bld->tcs_iface) {
+            if (bld->mesh_iface) {
+               emit_store_mesh_chan(bld_base, var->data.compact, bit_size, location, const_index, indir_vertex_index, indir_index, comp, chan, chan_val);
+            } else if (bld->tcs_iface) {
                emit_store_tcs_chan(bld_base, var->data.compact, bit_size, location, const_index, indir_vertex_index, indir_index, comp, chan, chan_val);
             } else
                emit_store_chan(bld_base, deref_mode, bit_size, location + const_index, comp, chan, chan_val);
index 21dbc4e..b9fe368 100644 (file)
@@ -508,6 +508,17 @@ struct lp_build_tes_iface
 
 struct lp_build_mesh_iface
 {
+   void (*emit_store_output)(const struct lp_build_mesh_iface *mesh_iface,
+                             struct lp_build_context * bld,
+                             unsigned name,
+                             boolean is_vindex_indirect,
+                             LLVMValueRef vertex_index,
+                             boolean is_aindex_indirect,
+                             LLVMValueRef attrib_index,
+                             boolean is_sindex_indirect,
+                             LLVMValueRef swizzle_index,
+                             LLVMValueRef value,
+                             LLVMValueRef mask_vec);
    void (*emit_vertex_and_primitive_count)(const struct lp_build_mesh_iface *mesh_iface,
                                            struct lp_build_context *bld,
                                            LLVMValueRef vertices_count,