radeonsi: simplify si_build_wrapper_function
authorQiang Yu <yuq825@gmail.com>
Tue, 11 Jul 2023 09:56:29 +0000 (17:56 +0800)
committerMarge Bot <emma+marge@anholt.net>
Mon, 24 Jul 2023 01:49:21 +0000 (01:49 +0000)
We only need it to merge LS/HS or ES/GS now, prolog and epilog have
been lowered in nir already. So we just need to handle two parts and
they are sure to be first and second stage of a merged shader.

This also remove the needs SGPRs must be before VGPRs, which is required
by following commits to move some SGPRs after VGPRs.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24204>

src/gallium/drivers/radeonsi/si_shader_internal.h
src/gallium/drivers/radeonsi/si_shader_llvm.c

index febea08..0f2af78 100644 (file)
@@ -181,11 +181,6 @@ LLVMValueRef si_insert_input_ptr(struct si_shader_context *ctx, LLVMValueRef ret
 LLVMValueRef si_prolog_get_internal_bindings(struct si_shader_context *ctx);
 LLVMValueRef si_unpack_param(struct si_shader_context *ctx, struct ac_arg param, unsigned rshift,
                              unsigned bitwidth);
-void si_build_wrapper_function(struct si_shader_context *ctx, struct ac_llvm_pointer *parts,
-                               unsigned num_parts, unsigned main_part,
-                               unsigned next_shader_first_part,
-                               enum ac_arg_type *main_arg_types,
-                               bool same_thread_count);
 bool si_llvm_compile_shader(struct si_screen *sscreen, struct ac_llvm_compiler *compiler,
                             struct si_shader *shader, struct si_shader_args *args,
                             struct util_debug_callback *debug, struct nir_shader *nir);
index a52b12c..dd04b1a 100644 (file)
@@ -388,282 +388,117 @@ static void si_llvm_declare_compute_memory(struct si_shader_context *ctx)
 }
 
 /**
- * Given a list of shader part functions, build a wrapper function that
+ * Given two parts (LS/HS or ES/GS) of a merged shader, build a wrapper function that
  * runs them in sequence to form a monolithic shader.
  */
-void si_build_wrapper_function(struct si_shader_context *ctx, struct ac_llvm_pointer *parts,
-                               unsigned num_parts, unsigned main_part,
-                               unsigned next_shader_first_part,
-                               enum ac_arg_type *main_arg_types, bool same_thread_count)
+static void si_build_wrapper_function(struct si_shader_context *ctx,
+                                      struct ac_llvm_pointer parts[2],
+                                      bool same_thread_count)
 {
    LLVMBuilderRef builder = ctx->ac.builder;
-   /* PS epilog has one arg per color component; gfx9 merged shader
-    * prologs need to forward 40 SGPRs.
-    */
-   LLVMValueRef initial[AC_MAX_ARGS], out[AC_MAX_ARGS];
-   LLVMTypeRef function_type;
-   unsigned num_first_params;
-   unsigned num_out, initial_num_out;
-   ASSERTED unsigned num_out_sgpr;         /* used in debug checks */
-   ASSERTED unsigned initial_num_out_sgpr; /* used in debug checks */
-   unsigned num_sgprs, num_vgprs;
-   unsigned gprs;
-
-   memset(ctx->args, 0, sizeof(*ctx->args));
-
-   for (unsigned i = 0; i < num_parts; ++i) {
+
+   for (unsigned i = 0; i < 2; ++i) {
       ac_add_function_attr(ctx->ac.context, parts[i].value, -1, "alwaysinline");
       LLVMSetLinkage(parts[i].value, LLVMPrivateLinkage);
    }
 
-   /* The parameters of the wrapper function correspond to those of the
-    * first part in terms of SGPRs and VGPRs, but we use the types of the
-    * main part to get the right types. This is relevant for the
-    * dereferenceable attribute on descriptor table pointers.
-    */
-   num_sgprs = 0;
-   num_vgprs = 0;
-
-   function_type = parts[0].pointee_type;
-   num_first_params = LLVMCountParamTypes(function_type);
-
-   for (unsigned i = 0; i < num_first_params; ++i) {
-      LLVMValueRef param = LLVMGetParam(parts[0].value, i);
-
-      if (ac_is_sgpr_param(param)) {
-         assert(num_vgprs == 0);
-         num_sgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
-      } else {
-         num_vgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
-      }
-   }
-
-   gprs = 0;
-   while (gprs < num_sgprs + num_vgprs) {
-      LLVMValueRef param = LLVMGetParam(parts[main_part].value, ctx->args->ac.arg_count);
-      LLVMTypeRef type = LLVMTypeOf(param);
-      unsigned size = ac_get_type_size(type) / 4;
-      enum ac_arg_type arg_type = main_arg_types[ctx->args->ac.arg_count];
-      assert(arg_type != AC_ARG_INVALID);
-
-      ac_add_arg(&ctx->args->ac, gprs < num_sgprs ? AC_ARG_SGPR : AC_ARG_VGPR, size, arg_type, NULL);
-
-      assert(ac_is_sgpr_param(param) == (gprs < num_sgprs));
-      assert(gprs + size <= num_sgprs + num_vgprs &&
-             (gprs >= num_sgprs || gprs + size <= num_sgprs));
-
-      gprs += size;
-   }
-
-   /* Prepare the return type. */
-   unsigned num_returns = 0;
-   LLVMTypeRef returns[AC_MAX_ARGS], last_func_type, return_type;
+   si_llvm_create_func(ctx, "wrapper", NULL, 0, si_get_max_workgroup_size(ctx->shader));
 
-   last_func_type = parts[num_parts - 1].pointee_type;
-   return_type = LLVMGetReturnType(last_func_type);
-
-   switch (LLVMGetTypeKind(return_type)) {
-   case LLVMStructTypeKind:
-      num_returns = LLVMCountStructElementTypes(return_type);
-      assert(num_returns <= ARRAY_SIZE(returns));
-      LLVMGetStructElementTypes(return_type, returns);
-      break;
-   case LLVMVoidTypeKind:
-      break;
-   default:
-      unreachable("unexpected type");
-   }
-
-   si_llvm_create_func(ctx, "wrapper", returns, num_returns,
-                       si_get_max_workgroup_size(ctx->shader));
-
-   if (si_is_merged_shader(ctx->shader) && !same_thread_count)
+   if (same_thread_count) {
+      si_init_exec_from_input(ctx, ctx->args->ac.merged_wave_info, 0);
+   } else {
       ac_init_exec_full_mask(&ctx->ac);
 
-   /* Record the arguments of the function as if they were an output of
-    * a previous part.
-    */
-   num_out = 0;
-   num_out_sgpr = 0;
-
-   for (unsigned i = 0; i < ctx->args->ac.arg_count; ++i) {
-      LLVMValueRef param = LLVMGetParam(ctx->main_fn.value, i);
-      LLVMTypeRef param_type = LLVMTypeOf(param);
-      LLVMTypeRef out_type = ctx->args->ac.args[i].file == AC_ARG_SGPR ? ctx->ac.i32 : ctx->ac.f32;
-      unsigned size = ac_get_type_size(param_type) / 4;
-
-      if (size == 1) {
-         if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
-            param = LLVMBuildPtrToInt(builder, param, ctx->ac.i32, "");
-            param_type = ctx->ac.i32;
-         }
-
-         if (param_type != out_type)
-            param = LLVMBuildBitCast(builder, param, out_type, "");
-         out[num_out++] = param;
-      } else {
-         LLVMTypeRef vector_type = LLVMVectorType(out_type, size);
-
-         if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
-            param = LLVMBuildPtrToInt(builder, param, ctx->ac.i64, "");
-            param_type = ctx->ac.i64;
-         }
-
-         if (param_type != vector_type)
-            param = LLVMBuildBitCast(builder, param, vector_type, "");
-
-         for (unsigned j = 0; j < size; ++j)
-            out[num_out++] =
-               LLVMBuildExtractElement(builder, param, LLVMConstInt(ctx->ac.i32, j, 0), "");
-      }
+      LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info);
+      count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
 
-      if (ctx->args->ac.args[i].file == AC_ARG_SGPR)
-         num_out_sgpr = num_out;
+      LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
+      ac_build_ifcc(&ctx->ac, ena, 6506);
    }
 
-   memcpy(initial, out, sizeof(out));
-   initial_num_out = num_out;
-   initial_num_out_sgpr = num_out_sgpr;
-
-   /* Now chain the parts. */
-   LLVMValueRef ret = NULL;
-   for (unsigned part = 0; part < num_parts; ++part) {
-      LLVMValueRef in[AC_MAX_ARGS];
-      LLVMTypeRef ret_type;
-      unsigned out_idx = 0;
-      unsigned num_params = LLVMCountParams(parts[part].value);
-
-      /* Merged shaders are executed conditionally depending
-       * on the number of enabled threads passed in the input SGPRs. */
-      if (si_is_multi_part_shader(ctx->shader) && part == 0) {
-         if (same_thread_count) {
-            struct ac_arg arg;
-            arg.arg_index = 3;
-            arg.used = true;
-
-            si_init_exec_from_input(ctx, arg, 0);
-         } else {
-            LLVMValueRef ena, count = initial[3];
+   LLVMValueRef params[AC_MAX_ARGS];
+   unsigned num_params = LLVMCountParams(ctx->main_fn.value);
+   LLVMGetParams(ctx->main_fn.value, params);
 
-            count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
-            ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
-            ac_build_ifcc(&ctx->ac, ena, 6506);
-         }
-      }
+   /* wrapper function has same parameter as first part shader */
+   LLVMValueRef ret =
+      ac_build_call(&ctx->ac, parts[0].pointee_type, parts[0].value, params, num_params);
 
-      /* Derive arguments for the next part from outputs of the
-       * previous one.
-       */
-      for (unsigned param_idx = 0; param_idx < num_params; ++param_idx) {
-         LLVMValueRef param;
-         LLVMTypeRef param_type;
-         bool is_sgpr;
-         unsigned param_size;
-         LLVMValueRef arg = NULL;
-
-         param = LLVMGetParam(parts[part].value, param_idx);
-         param_type = LLVMTypeOf(param);
-         param_size = ac_get_type_size(param_type) / 4;
-         is_sgpr = ac_is_sgpr_param(param);
-
-         if (is_sgpr) {
-            ac_add_function_attr(ctx->ac.context, parts[part].value, param_idx + 1, "inreg");
-         } else if (out_idx < num_out_sgpr) {
-            /* Skip returned SGPRs the current part doesn't
-             * declare on the input. */
-            out_idx = num_out_sgpr;
-         }
+   if (same_thread_count) {
+      LLVMTypeRef type = LLVMTypeOf(ret);
+      assert(LLVMGetTypeKind(type) == LLVMStructTypeKind);
+
+      /* output of first part shader is the input of the second part */
+      num_params = LLVMCountStructElementTypes(type);
+      assert(num_params == LLVMCountParams(parts[1].value));
 
-         assert(out_idx + param_size <= (is_sgpr ? num_out_sgpr : num_out));
+      for (unsigned i = 0; i < num_params; i++) {
+         params[i] = LLVMBuildExtractValue(builder, ret, i, "");
 
-         if (param_size == 1)
-            arg = out[out_idx];
-         else
-            arg = ac_build_gather_values(&ctx->ac, &out[out_idx], param_size);
+         /* Convert return value to same type as next shader's input param. */
+         LLVMTypeRef ret_type = LLVMTypeOf(params[i]);
+         LLVMTypeRef param_type = LLVMTypeOf(LLVMGetParam(parts[1].value, i));
+         assert(ac_get_type_size(ret_type) == 4);
+         assert(ac_get_type_size(param_type) == 4);
 
-         if (LLVMTypeOf(arg) != param_type) {
+         if (ret_type != param_type) {
             if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
-               if (LLVMGetPointerAddressSpace(param_type) == AC_ADDR_SPACE_CONST_32BIT) {
-                  arg = LLVMBuildBitCast(builder, arg, ctx->ac.i32, "");
-                  arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
-               } else {
-                  arg = LLVMBuildBitCast(builder, arg, ctx->ac.i64, "");
-                  arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
-               }
+               assert(LLVMGetPointerAddressSpace(param_type) == AC_ADDR_SPACE_CONST_32BIT);
+               assert(ret_type == ctx->ac.i32);
+
+               params[i] = LLVMBuildIntToPtr(builder, params[i], param_type, "");
             } else {
-               arg = LLVMBuildBitCast(builder, arg, param_type, "");
+               params[i] = LLVMBuildBitCast(builder, params[i], param_type, "");
             }
          }
-
-         in[param_idx] = arg;
-         out_idx += param_size;
       }
+   } else {
+      ac_build_endif(&ctx->ac, 6506);
 
-      ret = ac_build_call(&ctx->ac, parts[part].pointee_type, parts[part].value, in, num_params);
-
-      if (!same_thread_count &&
-          si_is_multi_part_shader(ctx->shader) && part + 1 == next_shader_first_part) {
-         ac_build_endif(&ctx->ac, 6506);
-
-         /* The second half of the merged shader should use
-          * the inputs from the toplevel (wrapper) function,
-          * not the return value from the last call.
-          *
-          * That's because the last call was executed condi-
-          * tionally, so we can't consume it in the main
-          * block.
-          */
-         memcpy(out, initial, sizeof(initial));
-         num_out = initial_num_out;
-         num_out_sgpr = initial_num_out_sgpr;
-
-         /* Execute the second shader conditionally based on the number of
-          * enabled threads there.
-          */
-         if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-            LLVMValueRef ena, count = initial[3];
+      if (ctx->stage == MESA_SHADER_TESS_CTRL) {
+         LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info);
+         count = LLVMBuildLShr(builder, count, LLVMConstInt(ctx->ac.i32, 8, 0), "");
+         count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
 
-            count = LLVMBuildLShr(builder, count, LLVMConstInt(ctx->ac.i32, 8, 0), "");
-            count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
-            ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
-            ac_build_ifcc(&ctx->ac, ena, 6507);
-         }
-         continue;
+         LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
+         ac_build_ifcc(&ctx->ac, ena, 6507);
       }
 
-      /* Extract the returned GPRs. */
-      ret_type = LLVMTypeOf(ret);
-      num_out = 0;
-      num_out_sgpr = 0;
-
-      if (LLVMGetTypeKind(ret_type) != LLVMVoidTypeKind) {
-         assert(LLVMGetTypeKind(ret_type) == LLVMStructTypeKind);
-
-         unsigned ret_size = LLVMCountStructElementTypes(ret_type);
-
-         for (unsigned i = 0; i < ret_size; ++i) {
-            LLVMValueRef val = LLVMBuildExtractValue(builder, ret, i, "");
-
-            assert(num_out < ARRAY_SIZE(out));
-            out[num_out++] = val;
-
-            if (LLVMTypeOf(val) == ctx->ac.i32) {
-               assert(num_out_sgpr + 1 == num_out);
-               num_out_sgpr = num_out;
+      /* The second half of the merged shader should use
+       * the inputs from the toplevel (wrapper) function,
+       * not the return value from the last call.
+       *
+       * That's because the last call was executed condi-
+       * tionally, so we can't consume it in the main
+       * block.
+       */
+      unsigned num_part_params = LLVMCountParams(parts[1].value);
+      for (unsigned i = 0, j = 0; i < num_part_params; i++) {
+         LLVMValueRef param = LLVMGetParam(parts[1].value, i);
+         LLVMTypeRef type = LLVMTypeOf(param);
+
+         bool found = false;
+         for ( ; j < num_params; j++) {
+            /* skip different type params */
+            if (LLVMTypeOf(params[j]) == type) {
+               params[i] = params[j++];
+               found = true;
+               break;
             }
          }
+         assert(found);
       }
+
+      num_params = num_part_params;
    }
 
+   ac_build_call(&ctx->ac, parts[1].pointee_type, parts[1].value, params, num_params);
+
    /* Close the conditional wrapping the second shader. */
-   if (ctx->stage == MESA_SHADER_TESS_CTRL &&
-       !same_thread_count && si_is_multi_part_shader(ctx->shader))
+   if (ctx->stage == MESA_SHADER_TESS_CTRL && !same_thread_count)
       ac_build_endif(&ctx->ac, 6507);
 
-   if (LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind)
-      LLVMBuildRetVoid(builder);
-   else
-      LLVMBuildRet(builder, ret);
+   LLVMBuildRetVoid(builder);
 }
 
 static LLVMValueRef si_llvm_load_intrinsic(struct ac_shader_abi *abi, nir_intrinsic_instr *intrin)
@@ -1053,19 +888,12 @@ bool si_llvm_compile_shader(struct si_screen *sscreen, struct ac_llvm_compiler *
 
       parts[0] = ctx.main_fn;
 
-      /* Preserve main arguments. */
-      enum ac_arg_type main_arg_types[AC_MAX_ARGS];
-      for (int i = 0; i < ctx.args->ac.arg_count; i++)
-         main_arg_types[i] = ctx.args->ac.args[i].type;
-      main_arg_types[MIN2(AC_MAX_ARGS - 1, ctx.args->ac.arg_count)] = AC_ARG_INVALID;
-
       /* Reset the shader context. */
       ctx.shader = shader;
       ctx.stage = sel->stage;
 
       bool same_thread_count = shader->key.ge.opt.same_patch_vertices;
-      si_build_wrapper_function(&ctx, parts, 2, 0, 1, main_arg_types,
-                                same_thread_count);
+      si_build_wrapper_function(&ctx, parts, same_thread_count);
    }
 
    si_llvm_optimize_module(&ctx);