gallivm: use masked intrinsics for global and scratch access.
authorDave Airlie <airlied@redhat.com>
Mon, 14 Nov 2022 08:00:10 +0000 (18:00 +1000)
committerMarge Bot <emma+marge@anholt.net>
Wed, 16 Nov 2022 23:31:54 +0000 (23:31 +0000)
This seems to improve luxmark scores for me on the luxball scene
from numbers in the 4-500 range to 5-700 range.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Tested-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19736>

src/gallium/auxiliary/gallivm/lp_bld_gather.c
src/gallium/auxiliary/gallivm/lp_bld_gather.h
src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c

index 2f25068..b93251b 100644 (file)
@@ -598,3 +598,50 @@ lp_build_gather_values(struct gallivm_state * gallivm,
    }
    return vec;
 }
+
+LLVMValueRef
+lp_build_masked_gather(struct gallivm_state *gallivm,
+                       unsigned length,
+                       unsigned bit_size,
+                       LLVMTypeRef vec_type,
+                       LLVMValueRef offset_ptr,
+                       LLVMValueRef exec_mask)
+{
+   LLVMBuilderRef builder = gallivm->builder;
+   LLVMValueRef args[4];
+   char intrin_name[64];
+
+   snprintf(intrin_name, 64, "llvm.masked.gather.v%ui%u.v%up0i%u",
+            length, bit_size, length, bit_size);
+   args[0] = offset_ptr;
+   args[1] = lp_build_const_int32(gallivm, bit_size / 8);
+   args[2] = LLVMBuildICmp(builder, LLVMIntNE, exec_mask,
+                           LLVMConstNull(LLVMTypeOf(exec_mask)), "");
+   args[3] = LLVMConstNull(vec_type);
+   return lp_build_intrinsic(builder, intrin_name, vec_type,
+                             args, 4, 0);
+
+}
+
+void
+lp_build_masked_scatter(struct gallivm_state *gallivm,
+                        unsigned length,
+                        unsigned bit_size,
+                        LLVMValueRef offset_ptr,
+                        LLVMValueRef value_vec,
+                        LLVMValueRef exec_mask)
+{
+   LLVMBuilderRef builder = gallivm->builder;
+   LLVMValueRef args[4];
+   char intrin_name[64];
+
+   snprintf(intrin_name, 64, "llvm.masked.scatter.v%ui%u.v%up0i%u",
+            length, bit_size, length, bit_size);
+   args[0] = value_vec;
+   args[1] = offset_ptr;
+   args[2] = lp_build_const_int32(gallivm, bit_size / 8);
+   args[3] = LLVMBuildICmp(builder, LLVMIntNE, exec_mask,
+                           LLVMConstNull(LLVMTypeOf(exec_mask)), "");
+   lp_build_intrinsic(builder, intrin_name, LLVMVoidTypeInContext(gallivm->context),
+                      args, 4, 0);
+}
index 7930864..5fabed9 100644 (file)
@@ -66,4 +66,20 @@ lp_build_gather_values(struct gallivm_state * gallivm,
                        LLVMValueRef * values,
                        unsigned value_count);
 
+LLVMValueRef
+lp_build_masked_gather(struct gallivm_state *gallivm,
+                       unsigned length,
+                       unsigned bit_size,
+                       LLVMTypeRef vec_type,
+                       LLVMValueRef offset_ptr,
+                       LLVMValueRef exec_mask);
+
+void
+lp_build_masked_scatter(struct gallivm_state *gallivm,
+                        unsigned length,
+                        unsigned bit_size,
+                        LLVMValueRef offset_ptr,
+                        LLVMValueRef value_vec,
+                        LLVMValueRef exec_mask);
+
 #endif /* LP_BLD_GATHER_H_ */
index 443bf4f..57c953a 100644 (file)
@@ -822,6 +822,41 @@ static LLVMValueRef global_addr_to_ptr(struct gallivm_state *gallivm, LLVMValueR
    return addr_ptr;
 }
 
+static LLVMValueRef global_addr_to_ptr_vec(struct gallivm_state *gallivm, LLVMValueRef addr_ptr, unsigned length, unsigned bit_size)
+{
+   LLVMBuilderRef builder = gallivm->builder;
+   switch (bit_size) {
+   case 8:
+      addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), length), "");
+      break;
+   case 16:
+      addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt16TypeInContext(gallivm->context), 0), length), "");
+      break;
+   case 32:
+   default:
+      addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt32TypeInContext(gallivm->context), 0), length), "");
+      break;
+   case 64:
+      addr_ptr = LLVMBuildIntToPtr(builder, addr_ptr, LLVMVectorType(LLVMPointerType(LLVMInt64TypeInContext(gallivm->context), 0), length), "");
+      break;
+   }
+   return addr_ptr;
+}
+
+static LLVMValueRef lp_vec_add_offset_ptr(struct lp_build_nir_context *bld_base,
+                                          unsigned bit_size,
+                                          LLVMValueRef ptr,
+                                          LLVMValueRef offset)
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+   struct lp_build_context *uint_bld = &bld_base->uint_bld;
+   LLVMValueRef result = LLVMBuildPtrToInt(builder, ptr, bld_base->uint64_bld.vec_type, "");
+   offset = LLVMBuildZExt(builder, offset, bld_base->uint64_bld.vec_type, "");
+   result = LLVMBuildAdd(builder, offset, result, "");
+   return global_addr_to_ptr_vec(gallivm, result, uint_bld->type.length, bit_size);
+}
+
 static void emit_load_global(struct lp_build_nir_context *bld_base,
                              unsigned nc,
                              unsigned bit_size,
@@ -855,30 +890,14 @@ static void emit_load_global(struct lp_build_nir_context *bld_base,
    }
 
    for (unsigned c = 0; c < nc; c++) {
-      LLVMValueRef result = lp_build_alloca(gallivm, res_bld->vec_type, "");
-      struct lp_build_loop_state loop_state;
-      lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
-
-      struct lp_build_if_state ifthen;
-      LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
-      cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
-      lp_build_if(&ifthen, gallivm, cond);
+      LLVMValueRef chan_offset = lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8));
 
-      LLVMValueRef addr_ptr = LLVMBuildExtractElement(gallivm->builder, addr,
-                                                      loop_state.counter, "");
-      addr_ptr = global_addr_to_ptr(gallivm, addr_ptr, bit_size);
-
-      LLVMValueRef value_ptr = lp_build_pointer_get2(builder, res_bld->elem_type,
-                                                     addr_ptr, lp_build_const_int32(gallivm, c));
-
-      LLVMValueRef temp_res;
-      temp_res = LLVMBuildLoad2(builder, res_bld->vec_type, result, "");
-      temp_res = LLVMBuildInsertElement(builder, temp_res, value_ptr, loop_state.counter, "");
-      LLVMBuildStore(builder, temp_res, result);
-      lp_build_endif(&ifthen);
-      lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
-                             NULL, LLVMIntUGE);
-      outval[c] = LLVMBuildLoad2(builder, res_bld->vec_type, result, "");
+      outval[c] = lp_build_masked_gather(gallivm, res_bld->type.length,
+                                         bit_size,
+                                         res_bld->vec_type,
+                                         lp_vec_add_offset_ptr(bld_base, bit_size, addr, chan_offset),
+                                         exec_mask);
+      outval[c] = LLVMBuildBitCast(builder, outval[c], res_bld->vec_type, "");
    }
 }
 
@@ -898,40 +917,14 @@ static void emit_store_global(struct lp_build_nir_context *bld_base,
       if (!(writemask & (1u << c)))
          continue;
       LLVMValueRef val = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, "");
+      LLVMValueRef chan_offset = lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8));
 
-      struct lp_build_loop_state loop_state;
-      lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
-      LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val,
-                                                       loop_state.counter, "");
-
-      LLVMValueRef addr_ptr = LLVMBuildExtractElement(gallivm->builder, addr,
-                                                      loop_state.counter, "");
-      addr_ptr = global_addr_to_ptr(gallivm, addr_ptr, bit_size);
-      switch (bit_size) {
-      case 8:
-         value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt8TypeInContext(gallivm->context), "");
-         break;
-      case 16:
-         value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt16TypeInContext(gallivm->context), "");
-         break;
-      case 32:
-         value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt32TypeInContext(gallivm->context), "");
-         break;
-      case 64:
-         value_ptr = LLVMBuildBitCast(builder, value_ptr, LLVMInt64TypeInContext(gallivm->context), "");
-         break;
-      default:
-         break;
-      }
-      struct lp_build_if_state ifthen;
-
-      LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
-      cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
-      lp_build_if(&ifthen, gallivm, cond);
-      lp_build_pointer_set(builder, addr_ptr, lp_build_const_int32(gallivm, c), value_ptr);
-      lp_build_endif(&ifthen);
-      lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
-                             NULL, LLVMIntUGE);
+      struct lp_build_context *out_bld = get_int_bld(bld_base, false, bit_size);
+      val = LLVMBuildBitCast(builder, val, out_bld->vec_type, "");
+      lp_build_masked_scatter(gallivm, out_bld->type.length, bit_size,
+                              lp_vec_add_offset_ptr(bld_base, bit_size,
+                                                    addr, chan_offset),
+                              val, exec_mask);
    }
 }
 
@@ -2616,46 +2609,25 @@ emit_load_scratch(struct lp_build_nir_context *bld_base,
    struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
    struct lp_build_context *load_bld;
-   LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);;
-   uint32_t shift_val = bit_size_to_shift_size(bit_size);
+   LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);
    LLVMValueRef exec_mask = mask_vec(bld_base);
-
+   LLVMValueRef scratch_ptr_vec = lp_build_broadcast(gallivm,
+                                                     LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), uint_bld->type.length),
+                                                     bld->scratch_ptr);
    load_bld = get_int_bld(bld_base, true, bit_size);
 
    offset = lp_build_add(uint_bld, offset, thread_offsets);
-   offset = lp_build_shr_imm(uint_bld, offset, shift_val);
-   for (unsigned c = 0; c < nc; c++) {
-      LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c));
-
-      LLVMValueRef result = lp_build_alloca(gallivm, load_bld->vec_type, "");
-      struct lp_build_loop_state loop_state;
-      lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
-
-      struct lp_build_if_state ifthen;
-      LLVMValueRef cond, temp_res;
 
-      loop_index = LLVMBuildExtractElement(gallivm->builder, loop_index,
-                                           loop_state.counter, "");
-      cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
-      cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
-
-      lp_build_if(&ifthen, gallivm, cond);
-      LLVMValueRef scalar;
-      LLVMValueRef ptr2 = LLVMBuildBitCast(builder, bld->scratch_ptr, LLVMPointerType(load_bld->elem_type, 0), "");
-      scalar = lp_build_pointer_get2(builder, load_bld->elem_type, ptr2, loop_index);
+   for (unsigned c = 0; c < nc; c++) {
+      LLVMValueRef chan_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8)));
 
-      temp_res = LLVMBuildLoad2(builder, load_bld->vec_type, result, "");
-      temp_res = LLVMBuildInsertElement(builder, temp_res, scalar, loop_state.counter, "");
-      LLVMBuildStore(builder, temp_res, result);
-      lp_build_else(&ifthen);
-      temp_res = LLVMBuildLoad2(builder, load_bld->vec_type, result, "");
-      LLVMValueRef zero = lp_build_zero_bits(gallivm, bit_size, false);
-      temp_res = LLVMBuildInsertElement(builder, temp_res, zero, loop_state.counter, "");
-      LLVMBuildStore(builder, temp_res, result);
-      lp_build_endif(&ifthen);
-      lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
-                                NULL, LLVMIntUGE);
-      outval[c] = LLVMBuildLoad2(gallivm->builder, load_bld->vec_type, result, "");
+      outval[c] = lp_build_masked_gather(gallivm, load_bld->type.length, bit_size,
+                                         load_bld->vec_type,
+                                         lp_vec_add_offset_ptr(bld_base, bit_size,
+                                                               scratch_ptr_vec,
+                                                               chan_offset),
+                                         exec_mask);
+      outval[c] = LLVMBuildBitCast(builder, outval[c], load_bld->vec_type, "");
    }
 }
 
@@ -2670,43 +2642,28 @@ emit_store_scratch(struct lp_build_nir_context *bld_base,
    struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
    struct lp_build_context *store_bld;
-   LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);;
-   uint32_t shift_val = bit_size_to_shift_size(bit_size);
+   LLVMValueRef thread_offsets = get_scratch_thread_offsets(gallivm, uint_bld->type, bld->scratch_size);
+   LLVMValueRef scratch_ptr_vec = lp_build_broadcast(gallivm,
+                                                     LLVMVectorType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), uint_bld->type.length),
+                                                     bld->scratch_ptr);
    store_bld = get_int_bld(bld_base, true, bit_size);
 
    LLVMValueRef exec_mask = mask_vec(bld_base);
    offset = lp_build_add(uint_bld, offset, thread_offsets);
-   offset = lp_build_shr_imm(uint_bld, offset, shift_val);
 
    for (unsigned c = 0; c < nc; c++) {
       if (!(writemask & (1u << c)))
          continue;
       LLVMValueRef val = (nc == 1) ? dst : LLVMBuildExtractValue(builder, dst, c, "");
-      LLVMValueRef loop_index = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c));
 
-      struct lp_build_loop_state loop_state;
-      lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
+      LLVMValueRef chan_offset = lp_build_add(uint_bld, offset, lp_build_const_int_vec(gallivm, uint_bld->type, c * (bit_size / 8)));
 
-      LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, val,
-                                                       loop_state.counter, "");
-      value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, store_bld->elem_type, "");
-
-      struct lp_build_if_state ifthen;
-      LLVMValueRef cond;
+      val = LLVMBuildBitCast(builder, val, store_bld->vec_type, "");
 
-      loop_index = LLVMBuildExtractElement(gallivm->builder, loop_index,
-                                                        loop_state.counter, "");
-
-      cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, exec_mask, uint_bld->zero, "");
-      cond = LLVMBuildExtractElement(gallivm->builder, cond, loop_state.counter, "");
-      lp_build_if(&ifthen, gallivm, cond);
-
-      LLVMValueRef ptr2 = LLVMBuildBitCast(builder, bld->scratch_ptr, LLVMPointerType(store_bld->elem_type, 0), "");
-      lp_build_pointer_set(builder, ptr2, loop_index, value_ptr);
-
-      lp_build_endif(&ifthen);
-      lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, uint_bld->type.length),
-                             NULL, LLVMIntUGE);
+      lp_build_masked_scatter(gallivm, store_bld->type.length, bit_size,
+                              lp_vec_add_offset_ptr(bld_base, bit_size,
+                                                    scratch_ptr_vec, chan_offset),
+                              val, exec_mask);
    }
 }