draw: add support for per primitive aos emission
authorDave Airlie <airlied@redhat.com>
Wed, 17 May 2023 01:22:23 +0000 (11:22 +1000)
committerDave Airlie <airlied@redhat.com>
Mon, 5 Jun 2023 19:01:46 +0000 (05:01 +1000)
This add support to the aos emit code so that mesh shaders
can use it for per prim outputs.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23066>

src/gallium/auxiliary/draw/draw_llvm.c
src/gallium/auxiliary/draw/draw_llvm.h

index bd05de2..0b25bda 100644 (file)
@@ -773,6 +773,7 @@ fetch_vector(struct gallivm_state *gallivm,
 
 static void
 store_aos(struct gallivm_state *gallivm,
+          bool is_per_prim,
           LLVMTypeRef io_type,
           LLVMValueRef io_ptr,
           LLVMValueRef index,
@@ -780,19 +781,30 @@ store_aos(struct gallivm_state *gallivm,
 {
    LLVMTypeRef data_ptr_type = LLVMPointerType(lp_build_vec_type(gallivm, lp_float32_vec4_type()), 0);
    LLVMBuilderRef builder = gallivm->builder;
-   LLVMValueRef data_ptr = lp_jit_vertex_header_data(gallivm, io_type, io_ptr);
-   LLVMTypeRef data_type = LLVMStructGetTypeAtIndex(io_type, LP_JIT_VERTEX_HEADER_DATA);
+   LLVMValueRef data_ptr;
+   LLVMTypeRef data_type;
    LLVMValueRef indices[3];
 
    indices[0] = lp_build_const_int32(gallivm, 0);
    indices[1] = index;
    indices[2] = lp_build_const_int32(gallivm, 0);
 
+   if (!is_per_prim) {
+      data_ptr = lp_jit_vertex_header_data(gallivm, io_type, io_ptr);
+      data_type = LLVMStructGetTypeAtIndex(io_type, LP_JIT_VERTEX_HEADER_DATA);
+   } else {
+      data_ptr = io_ptr;
+      data_type = io_type;
+   }
+
    data_ptr = LLVMBuildGEP2(builder, data_type, data_ptr, indices, 3, "");
    data_ptr = LLVMBuildPointerCast(builder, data_ptr, data_ptr_type, "");
 
 #if DEBUG_STORE
-   lp_build_printf(gallivm, "    ---- %p storing attribute %d (io = %p)\n", data_ptr, index, io_ptr);
+   if (is_per_prim)
+      lp_build_printf(gallivm, "    ---- %p storing prim attribute %d (io = %p)\n", data_ptr, index, io_ptr);
+   else
+      lp_build_printf(gallivm, "    ---- %p storing attribute %d (io = %p)\n", data_ptr, index, io_ptr);
 #endif
 
    /* Unaligned store due to the vertex header */
@@ -853,17 +865,16 @@ adjust_mask(struct gallivm_state *gallivm,
 }
 
 
-static void
-store_aos_array(struct gallivm_state *gallivm,
-                struct lp_type soa_type,
-                LLVMTypeRef io_type,
-                LLVMValueRef io_ptr,
-                LLVMValueRef *indices,
-                LLVMValueRef* aos,
-                int attrib,
-                int num_outputs,
-                LLVMValueRef clipmask,
-                boolean need_edgeflag)
+void
+draw_store_aos_array(struct gallivm_state *gallivm,
+                     struct lp_type soa_type,
+                     LLVMTypeRef io_type,
+                     LLVMValueRef io_ptr,
+                     LLVMValueRef *indices,
+                     LLVMValueRef* aos,
+                     int attrib,
+                     LLVMValueRef clipmask,
+                     boolean need_edgeflag, bool is_per_prim)
 {
    LLVMBuilderRef builder = gallivm->builder;
    LLVMValueRef attr_index = lp_build_const_int32(gallivm, attrib);
@@ -884,7 +895,7 @@ store_aos_array(struct gallivm_state *gallivm,
       io_ptrs[i] = LLVMBuildGEP2(builder, io_type, io_ptr, &inds[i], 1, "");
    }
 
-   if (attrib == 0) {
+   if (attrib == 0 && !is_per_prim) {
       /* store vertex header for each of the n vertices */
       LLVMValueRef val, cliptmp;
       int vertex_id_pad_edgeflag;
@@ -899,13 +910,20 @@ store_aos_array(struct gallivm_state *gallivm,
       } else {
          vertex_id_pad_edgeflag = (0xffff << 16);
       }
-      val = lp_build_const_int_vec(gallivm, lp_int_type(soa_type),
-                                   vertex_id_pad_edgeflag);
+      if (vector_length == 1)
+         val = lp_build_const_int32(gallivm, vertex_id_pad_edgeflag);
+      else
+         val = lp_build_const_int_vec(gallivm, lp_int_type(soa_type),
+                                      vertex_id_pad_edgeflag);
+
       /* OR with the clipmask */
       cliptmp = LLVMBuildOr(builder, val, clipmask, "");
       for (unsigned i = 0; i < vector_length; i++) {
          LLVMValueRef id_ptr = lp_jit_vertex_header_id(gallivm, io_type, io_ptrs[i]);
-         val = LLVMBuildExtractElement(builder, cliptmp, linear_inds[i], "");
+         if (vector_length > 1)
+            val = LLVMBuildExtractElement(builder, cliptmp, linear_inds[i], "");
+         else
+            val = cliptmp;
          val = adjust_mask(gallivm, val);
 #if DEBUG_STORE
          lp_build_printf(gallivm, "io = %p, index %d, clipmask = %x\n",
@@ -917,7 +935,7 @@ store_aos_array(struct gallivm_state *gallivm,
 
    /* store for each of the n vertices */
    for (int i = 0; i < vector_length; i++) {
-      store_aos(gallivm, io_type, io_ptrs[i], attr_index, aos[i]);
+      store_aos(gallivm, is_per_prim, io_type, io_ptrs[i], attr_index, aos[i]);
    }
 }
 
@@ -981,16 +999,15 @@ convert_to_aos(struct gallivm_state *gallivm,
          }
       }
 
-      store_aos_array(gallivm,
-                      soa_type,
-                      io_type,
-                      io,
-                      indices,
-                      aos,
-                      attrib,
-                      num_outputs,
-                      clipmask,
-                      need_edgeflag);
+      draw_store_aos_array(gallivm,
+                           soa_type,
+                           io_type,
+                           io,
+                           indices,
+                           aos,
+                           attrib,
+                           clipmask,
+                           need_edgeflag, false);
    }
 #if DEBUG_STORE
    lp_build_printf(gallivm, "   # storing end\n");
index ef00bda..976dfcb 100644 (file)
@@ -680,4 +680,17 @@ draw_llvm_set_mapped_image(struct draw_context *draw,
                            uint32_t img_stride,
                            uint32_t num_samples,
                            uint32_t sample_stride);
+
+void
+draw_store_aos_array(struct gallivm_state *gallivm,
+                     struct lp_type soa_type,
+                     LLVMTypeRef io_type,
+                     LLVMValueRef io_ptr,
+                     LLVMValueRef *indices,
+                     LLVMValueRef* aos,
+                     int attrib,
+                     LLVMValueRef clipmask,
+                     boolean need_edgeflag,
+                     bool per_prim);
+
 #endif