}
/**
- * Given a list of shader part functions, build a wrapper function that
+ * Given two parts (LS/HS or ES/GS) of a merged shader, build a wrapper function that
* runs them in sequence to form a monolithic shader.
*/
-void si_build_wrapper_function(struct si_shader_context *ctx, struct ac_llvm_pointer *parts,
- unsigned num_parts, unsigned main_part,
- unsigned next_shader_first_part,
- enum ac_arg_type *main_arg_types, bool same_thread_count)
+static void si_build_wrapper_function(struct si_shader_context *ctx,
+ struct ac_llvm_pointer parts[2],
+ bool same_thread_count)
{
LLVMBuilderRef builder = ctx->ac.builder;
- /* PS epilog has one arg per color component; gfx9 merged shader
- * prologs need to forward 40 SGPRs.
- */
- LLVMValueRef initial[AC_MAX_ARGS], out[AC_MAX_ARGS];
- LLVMTypeRef function_type;
- unsigned num_first_params;
- unsigned num_out, initial_num_out;
- ASSERTED unsigned num_out_sgpr; /* used in debug checks */
- ASSERTED unsigned initial_num_out_sgpr; /* used in debug checks */
- unsigned num_sgprs, num_vgprs;
- unsigned gprs;
-
- memset(ctx->args, 0, sizeof(*ctx->args));
-
- for (unsigned i = 0; i < num_parts; ++i) {
+
+ for (unsigned i = 0; i < 2; ++i) {
ac_add_function_attr(ctx->ac.context, parts[i].value, -1, "alwaysinline");
LLVMSetLinkage(parts[i].value, LLVMPrivateLinkage);
}
- /* The parameters of the wrapper function correspond to those of the
- * first part in terms of SGPRs and VGPRs, but we use the types of the
- * main part to get the right types. This is relevant for the
- * dereferenceable attribute on descriptor table pointers.
- */
- num_sgprs = 0;
- num_vgprs = 0;
-
- function_type = parts[0].pointee_type;
- num_first_params = LLVMCountParamTypes(function_type);
-
- for (unsigned i = 0; i < num_first_params; ++i) {
- LLVMValueRef param = LLVMGetParam(parts[0].value, i);
-
- if (ac_is_sgpr_param(param)) {
- assert(num_vgprs == 0);
- num_sgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
- } else {
- num_vgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
- }
- }
-
- gprs = 0;
- while (gprs < num_sgprs + num_vgprs) {
- LLVMValueRef param = LLVMGetParam(parts[main_part].value, ctx->args->ac.arg_count);
- LLVMTypeRef type = LLVMTypeOf(param);
- unsigned size = ac_get_type_size(type) / 4;
- enum ac_arg_type arg_type = main_arg_types[ctx->args->ac.arg_count];
- assert(arg_type != AC_ARG_INVALID);
-
- ac_add_arg(&ctx->args->ac, gprs < num_sgprs ? AC_ARG_SGPR : AC_ARG_VGPR, size, arg_type, NULL);
-
- assert(ac_is_sgpr_param(param) == (gprs < num_sgprs));
- assert(gprs + size <= num_sgprs + num_vgprs &&
- (gprs >= num_sgprs || gprs + size <= num_sgprs));
-
- gprs += size;
- }
-
- /* Prepare the return type. */
- unsigned num_returns = 0;
- LLVMTypeRef returns[AC_MAX_ARGS], last_func_type, return_type;
+ si_llvm_create_func(ctx, "wrapper", NULL, 0, si_get_max_workgroup_size(ctx->shader));
- last_func_type = parts[num_parts - 1].pointee_type;
- return_type = LLVMGetReturnType(last_func_type);
-
- switch (LLVMGetTypeKind(return_type)) {
- case LLVMStructTypeKind:
- num_returns = LLVMCountStructElementTypes(return_type);
- assert(num_returns <= ARRAY_SIZE(returns));
- LLVMGetStructElementTypes(return_type, returns);
- break;
- case LLVMVoidTypeKind:
- break;
- default:
- unreachable("unexpected type");
- }
-
- si_llvm_create_func(ctx, "wrapper", returns, num_returns,
- si_get_max_workgroup_size(ctx->shader));
-
- if (si_is_merged_shader(ctx->shader) && !same_thread_count)
+ if (same_thread_count) {
+ si_init_exec_from_input(ctx, ctx->args->ac.merged_wave_info, 0);
+ } else {
ac_init_exec_full_mask(&ctx->ac);
- /* Record the arguments of the function as if they were an output of
- * a previous part.
- */
- num_out = 0;
- num_out_sgpr = 0;
-
- for (unsigned i = 0; i < ctx->args->ac.arg_count; ++i) {
- LLVMValueRef param = LLVMGetParam(ctx->main_fn.value, i);
- LLVMTypeRef param_type = LLVMTypeOf(param);
- LLVMTypeRef out_type = ctx->args->ac.args[i].file == AC_ARG_SGPR ? ctx->ac.i32 : ctx->ac.f32;
- unsigned size = ac_get_type_size(param_type) / 4;
-
- if (size == 1) {
- if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
- param = LLVMBuildPtrToInt(builder, param, ctx->ac.i32, "");
- param_type = ctx->ac.i32;
- }
-
- if (param_type != out_type)
- param = LLVMBuildBitCast(builder, param, out_type, "");
- out[num_out++] = param;
- } else {
- LLVMTypeRef vector_type = LLVMVectorType(out_type, size);
-
- if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
- param = LLVMBuildPtrToInt(builder, param, ctx->ac.i64, "");
- param_type = ctx->ac.i64;
- }
-
- if (param_type != vector_type)
- param = LLVMBuildBitCast(builder, param, vector_type, "");
-
- for (unsigned j = 0; j < size; ++j)
- out[num_out++] =
- LLVMBuildExtractElement(builder, param, LLVMConstInt(ctx->ac.i32, j, 0), "");
- }
+ LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info);
+ count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
- if (ctx->args->ac.args[i].file == AC_ARG_SGPR)
- num_out_sgpr = num_out;
+ LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
+ ac_build_ifcc(&ctx->ac, ena, 6506);
}
- memcpy(initial, out, sizeof(out));
- initial_num_out = num_out;
- initial_num_out_sgpr = num_out_sgpr;
-
- /* Now chain the parts. */
- LLVMValueRef ret = NULL;
- for (unsigned part = 0; part < num_parts; ++part) {
- LLVMValueRef in[AC_MAX_ARGS];
- LLVMTypeRef ret_type;
- unsigned out_idx = 0;
- unsigned num_params = LLVMCountParams(parts[part].value);
-
- /* Merged shaders are executed conditionally depending
- * on the number of enabled threads passed in the input SGPRs. */
- if (si_is_multi_part_shader(ctx->shader) && part == 0) {
- if (same_thread_count) {
- struct ac_arg arg;
- arg.arg_index = 3;
- arg.used = true;
-
- si_init_exec_from_input(ctx, arg, 0);
- } else {
- LLVMValueRef ena, count = initial[3];
+ LLVMValueRef params[AC_MAX_ARGS];
+ unsigned num_params = LLVMCountParams(ctx->main_fn.value);
+ LLVMGetParams(ctx->main_fn.value, params);
- count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
- ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
- ac_build_ifcc(&ctx->ac, ena, 6506);
- }
- }
+ /* wrapper function has same parameter as first part shader */
+ LLVMValueRef ret =
+ ac_build_call(&ctx->ac, parts[0].pointee_type, parts[0].value, params, num_params);
- /* Derive arguments for the next part from outputs of the
- * previous one.
- */
- for (unsigned param_idx = 0; param_idx < num_params; ++param_idx) {
- LLVMValueRef param;
- LLVMTypeRef param_type;
- bool is_sgpr;
- unsigned param_size;
- LLVMValueRef arg = NULL;
-
- param = LLVMGetParam(parts[part].value, param_idx);
- param_type = LLVMTypeOf(param);
- param_size = ac_get_type_size(param_type) / 4;
- is_sgpr = ac_is_sgpr_param(param);
-
- if (is_sgpr) {
- ac_add_function_attr(ctx->ac.context, parts[part].value, param_idx + 1, "inreg");
- } else if (out_idx < num_out_sgpr) {
- /* Skip returned SGPRs the current part doesn't
- * declare on the input. */
- out_idx = num_out_sgpr;
- }
+ if (same_thread_count) {
+ LLVMTypeRef type = LLVMTypeOf(ret);
+ assert(LLVMGetTypeKind(type) == LLVMStructTypeKind);
+
+ /* output of first part shader is the input of the second part */
+ num_params = LLVMCountStructElementTypes(type);
+ assert(num_params == LLVMCountParams(parts[1].value));
- assert(out_idx + param_size <= (is_sgpr ? num_out_sgpr : num_out));
+ for (unsigned i = 0; i < num_params; i++) {
+ params[i] = LLVMBuildExtractValue(builder, ret, i, "");
- if (param_size == 1)
- arg = out[out_idx];
- else
- arg = ac_build_gather_values(&ctx->ac, &out[out_idx], param_size);
+ /* Convert return value to same type as next shader's input param. */
+ LLVMTypeRef ret_type = LLVMTypeOf(params[i]);
+ LLVMTypeRef param_type = LLVMTypeOf(LLVMGetParam(parts[1].value, i));
+ assert(ac_get_type_size(ret_type) == 4);
+ assert(ac_get_type_size(param_type) == 4);
- if (LLVMTypeOf(arg) != param_type) {
+ if (ret_type != param_type) {
if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
- if (LLVMGetPointerAddressSpace(param_type) == AC_ADDR_SPACE_CONST_32BIT) {
- arg = LLVMBuildBitCast(builder, arg, ctx->ac.i32, "");
- arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
- } else {
- arg = LLVMBuildBitCast(builder, arg, ctx->ac.i64, "");
- arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
- }
+ assert(LLVMGetPointerAddressSpace(param_type) == AC_ADDR_SPACE_CONST_32BIT);
+ assert(ret_type == ctx->ac.i32);
+
+ params[i] = LLVMBuildIntToPtr(builder, params[i], param_type, "");
} else {
- arg = LLVMBuildBitCast(builder, arg, param_type, "");
+ params[i] = LLVMBuildBitCast(builder, params[i], param_type, "");
}
}
-
- in[param_idx] = arg;
- out_idx += param_size;
}
+ } else {
+ ac_build_endif(&ctx->ac, 6506);
- ret = ac_build_call(&ctx->ac, parts[part].pointee_type, parts[part].value, in, num_params);
-
- if (!same_thread_count &&
- si_is_multi_part_shader(ctx->shader) && part + 1 == next_shader_first_part) {
- ac_build_endif(&ctx->ac, 6506);
-
- /* The second half of the merged shader should use
- * the inputs from the toplevel (wrapper) function,
- * not the return value from the last call.
- *
- * That's because the last call was executed condi-
- * tionally, so we can't consume it in the main
- * block.
- */
- memcpy(out, initial, sizeof(initial));
- num_out = initial_num_out;
- num_out_sgpr = initial_num_out_sgpr;
-
- /* Execute the second shader conditionally based on the number of
- * enabled threads there.
- */
- if (ctx->stage == MESA_SHADER_TESS_CTRL) {
- LLVMValueRef ena, count = initial[3];
+ if (ctx->stage == MESA_SHADER_TESS_CTRL) {
+ LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info);
+ count = LLVMBuildLShr(builder, count, LLVMConstInt(ctx->ac.i32, 8, 0), "");
+ count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
- count = LLVMBuildLShr(builder, count, LLVMConstInt(ctx->ac.i32, 8, 0), "");
- count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
- ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
- ac_build_ifcc(&ctx->ac, ena, 6507);
- }
- continue;
+ LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
+ ac_build_ifcc(&ctx->ac, ena, 6507);
}
- /* Extract the returned GPRs. */
- ret_type = LLVMTypeOf(ret);
- num_out = 0;
- num_out_sgpr = 0;
-
- if (LLVMGetTypeKind(ret_type) != LLVMVoidTypeKind) {
- assert(LLVMGetTypeKind(ret_type) == LLVMStructTypeKind);
-
- unsigned ret_size = LLVMCountStructElementTypes(ret_type);
-
- for (unsigned i = 0; i < ret_size; ++i) {
- LLVMValueRef val = LLVMBuildExtractValue(builder, ret, i, "");
-
- assert(num_out < ARRAY_SIZE(out));
- out[num_out++] = val;
-
- if (LLVMTypeOf(val) == ctx->ac.i32) {
- assert(num_out_sgpr + 1 == num_out);
- num_out_sgpr = num_out;
+ /* The second half of the merged shader should use
+ * the inputs from the toplevel (wrapper) function,
+ * not the return value from the last call.
+ *
+ * That's because the last call was executed condi-
+ * tionally, so we can't consume it in the main
+ * block.
+ */
+ unsigned num_part_params = LLVMCountParams(parts[1].value);
+ for (unsigned i = 0, j = 0; i < num_part_params; i++) {
+ LLVMValueRef param = LLVMGetParam(parts[1].value, i);
+ LLVMTypeRef type = LLVMTypeOf(param);
+
+ bool found = false;
+ for ( ; j < num_params; j++) {
+ /* skip different type params */
+ if (LLVMTypeOf(params[j]) == type) {
+ params[i] = params[j++];
+ found = true;
+ break;
}
}
+ assert(found);
}
+
+ num_params = num_part_params;
}
+ ac_build_call(&ctx->ac, parts[1].pointee_type, parts[1].value, params, num_params);
+
/* Close the conditional wrapping the second shader. */
- if (ctx->stage == MESA_SHADER_TESS_CTRL &&
- !same_thread_count && si_is_multi_part_shader(ctx->shader))
+ if (ctx->stage == MESA_SHADER_TESS_CTRL && !same_thread_count)
ac_build_endif(&ctx->ac, 6507);
- if (LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind)
- LLVMBuildRetVoid(builder);
- else
- LLVMBuildRet(builder, ret);
+ LLVMBuildRetVoid(builder);
}
static LLVMValueRef si_llvm_load_intrinsic(struct ac_shader_abi *abi, nir_intrinsic_instr *intrin)
parts[0] = ctx.main_fn;
- /* Preserve main arguments. */
- enum ac_arg_type main_arg_types[AC_MAX_ARGS];
- for (int i = 0; i < ctx.args->ac.arg_count; i++)
- main_arg_types[i] = ctx.args->ac.args[i].type;
- main_arg_types[MIN2(AC_MAX_ARGS - 1, ctx.args->ac.arg_count)] = AC_ARG_INVALID;
-
/* Reset the shader context. */
ctx.shader = shader;
ctx.stage = sel->stage;
bool same_thread_count = shader->key.ge.opt.same_patch_vertices;
- si_build_wrapper_function(&ctx, parts, 2, 0, 1, main_arg_types,
- same_thread_count);
+ si_build_wrapper_function(&ctx, parts, same_thread_count);
}
si_llvm_optimize_module(&ctx);