From 3994fb1e19ce90d56559de68451c20a3c7bd51a0 Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Wed, 17 May 2023 12:01:36 +1000 Subject: [PATCH] llvmpipe/cs: add task/mesh shader support to compute shader builder. This allows generating task and mesh variants of compute shaders. It adds: - vertex and primitive outputs support - aos writing. - payload support - mesh iface for the output and count callbacks. - draw_id - multiple iteration support to the exec fn to allow launches in multiple passes to reduce memory usage Reviewed-by: Roland Scheidegger Part-of: --- src/gallium/drivers/llvmpipe/lp_jit.h | 3 + src/gallium/drivers/llvmpipe/lp_state_cs.c | 391 ++++++++++++++++++++++++++++- src/gallium/drivers/llvmpipe/lp_state_cs.h | 4 + 3 files changed, 389 insertions(+), 9 deletions(-) diff --git a/src/gallium/drivers/llvmpipe/lp_jit.h b/src/gallium/drivers/llvmpipe/lp_jit.h index c1e006a..68de17a 100644 --- a/src/gallium/drivers/llvmpipe/lp_jit.h +++ b/src/gallium/drivers/llvmpipe/lp_jit.h @@ -49,6 +49,7 @@ struct lp_fragment_shader_variant; struct lp_compute_shader_variant; struct lp_rast_state; struct llvmpipe_screen; +struct vertex_header; struct lp_jit_viewport { @@ -374,6 +375,8 @@ typedef void uint32_t grid_size_y, uint32_t grid_size_z, uint32_t work_dim, + uint32_t draw_id, + struct vertex_header *io, /* mesh shader only */ struct lp_jit_cs_thread_data *thread_data); void diff --git a/src/gallium/drivers/llvmpipe/lp_state_cs.c b/src/gallium/drivers/llvmpipe/lp_state_cs.c index 023b79a..071538f 100644 --- a/src/gallium/drivers/llvmpipe/lp_state_cs.c +++ b/src/gallium/drivers/llvmpipe/lp_state_cs.c @@ -26,6 +26,7 @@ #include "util/u_memory.h" #include "util/os_time.h" #include "util/u_dump.h" +#include "util/u_prim.h" #include "util/u_string.h" #include "tgsi/tgsi_dump.h" #include "tgsi/tgsi_parse.h" @@ -33,12 +34,14 @@ #include "gallivm/lp_bld_debug.h" #include "gallivm/lp_bld_intr.h" #include "gallivm/lp_bld_flow.h" +#include "gallivm/lp_bld_pack.h" #include 
"gallivm/lp_bld_gather.h" #include "gallivm/lp_bld_coro.h" #include "gallivm/lp_bld_nir.h" #include "gallivm/lp_bld_jit_sample.h" #include "lp_state_cs.h" #include "lp_context.h" +#include "lp_setup_context.h" #include "lp_debug.h" #include "lp_state.h" #include "lp_perf.h" @@ -52,6 +55,8 @@ #include "nir_serialize.h" #include "draw/draw_context.h" +#include "draw/draw_llvm.h" +#include "draw/draw_mesh_prim.h" /** Fragment shader number (for debugging) */ static unsigned cs_no = 0; @@ -60,12 +65,19 @@ static unsigned mesh_no = 0; struct lp_cs_job_info { unsigned grid_size[3]; + unsigned iter_size[3]; unsigned grid_base[3]; unsigned block_size[3]; unsigned req_local_mem; unsigned work_dim; + unsigned draw_id; bool zero_initialize_shared_memory; + bool use_iters; struct lp_cs_exec *current; + struct vertex_header *io; + size_t io_stride; + void *payload; + size_t payload_stride; }; enum { @@ -81,6 +93,8 @@ enum { CS_ARG_GRID_SIZE_Y, CS_ARG_GRID_SIZE_Z, CS_ARG_WORK_DIM, + CS_ARG_DRAW_ID, + CS_ARG_VERTEX_DATA, CS_ARG_PER_THREAD_DATA, CS_ARG_OUTER_COUNT, CS_ARG_CORO_X_LOOPS = CS_ARG_OUTER_COUNT, @@ -90,9 +104,214 @@ enum { CS_ARG_CORO_BLOCK_Z_SIZE, CS_ARG_CORO_IDX, CS_ARG_CORO_MEM, + CS_ARG_CORO_OUTPUTS, CS_ARG_MAX, }; +struct lp_mesh_llvm_iface { + struct lp_build_mesh_iface base; + + LLVMValueRef vertex_count; + LLVMValueRef prim_count; + LLVMValueRef outputs; +}; + +static inline const struct lp_mesh_llvm_iface * +lp_mesh_llvm_iface(const struct lp_build_mesh_iface *iface) +{ + return (const struct lp_mesh_llvm_iface *)iface; +} + + +static LLVMTypeRef +create_mesh_jit_output_type_deref(struct gallivm_state *gallivm) +{ + LLVMTypeRef float_type = LLVMFloatTypeInContext(gallivm->context); + LLVMTypeRef output_array; + + output_array = LLVMArrayType(float_type, TGSI_NUM_CHANNELS); /* num channels */ + output_array = LLVMArrayType(output_array, PIPE_MAX_SHADER_OUTPUTS); /* num attrs per vertex */ + return output_array; +} + +static void 
+lp_mesh_llvm_emit_store_output(const struct lp_build_mesh_iface *mesh_iface, + struct lp_build_context *bld, + unsigned name, + boolean is_vindex_indirect, + LLVMValueRef vertex_index, + boolean is_aindex_indirect, + LLVMValueRef attrib_index, + boolean is_sindex_indirect, + LLVMValueRef swizzle_index, + LLVMValueRef value, + LLVMValueRef mask_vec) +{ + const struct lp_mesh_llvm_iface *mesh = lp_mesh_llvm_iface(mesh_iface); + struct gallivm_state *gallivm = bld->gallivm; + LLVMBuilderRef builder = gallivm->builder; + LLVMValueRef indices[3]; + LLVMValueRef res; + struct lp_type type = bld->type; + LLVMTypeRef output_type = create_mesh_jit_output_type_deref(gallivm); + + if (is_vindex_indirect || is_aindex_indirect || is_sindex_indirect) { + for (int i = 0; i < type.length; ++i) { + LLVMValueRef idx = lp_build_const_int32(gallivm, i); + LLVMValueRef vert_chan_index = vertex_index ? vertex_index : lp_build_const_int32(gallivm, 0); + LLVMValueRef attr_chan_index = attrib_index; + LLVMValueRef swiz_chan_index = swizzle_index; + LLVMValueRef channel_vec; + + if (is_vindex_indirect) { + vert_chan_index = LLVMBuildExtractElement(builder, + vertex_index, idx, ""); + } + if (is_aindex_indirect) { + attr_chan_index = LLVMBuildExtractElement(builder, + attrib_index, idx, ""); + } + + if (is_sindex_indirect) { + swiz_chan_index = LLVMBuildExtractElement(builder, + swizzle_index, idx, ""); + } + + indices[0] = vert_chan_index; + indices[1] = attr_chan_index; + indices[2] = swiz_chan_index; + + channel_vec = LLVMBuildGEP2(builder, output_type, mesh->outputs, indices, 3, ""); + + res = LLVMBuildExtractElement(builder, value, idx, ""); + + struct lp_build_if_state ifthen; + LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, mask_vec, lp_build_const_int_vec(gallivm, bld->type, 0), ""); + cond = LLVMBuildExtractElement(gallivm->builder, cond, idx, ""); + lp_build_if(&ifthen, gallivm, cond); + LLVMBuildStore(builder, res, channel_vec); + lp_build_endif(&ifthen); + } + } 
else { + indices[0] = vertex_index ? vertex_index : lp_build_const_int32(gallivm, 0); + indices[1] = attrib_index; + indices[2] = swizzle_index; + + res = LLVMBuildGEP2(builder, output_type, mesh->outputs, indices, 3, ""); + for (unsigned i = 0; i < type.length; ++i) { + LLVMValueRef idx = lp_build_const_int32(gallivm, i); + LLVMValueRef val = LLVMBuildExtractElement(builder, value, idx, ""); + + struct lp_build_if_state ifthen; + LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, mask_vec, lp_build_const_int_vec(gallivm, bld->type, 0), ""); + cond = LLVMBuildExtractElement(gallivm->builder, cond, idx, ""); + lp_build_if(&ifthen, gallivm, cond); + LLVMBuildStore(builder, val, res); + lp_build_endif(&ifthen); + } + } +} + +static void +lp_mesh_emit_vertex_and_primitive_count(const struct lp_build_mesh_iface *mesh_iface, + struct lp_build_context *bld, + LLVMValueRef vertices_count, + LLVMValueRef primitives_count) +{ + const struct lp_mesh_llvm_iface *mesh = lp_mesh_llvm_iface(mesh_iface); + struct gallivm_state *gallivm = bld->gallivm; + + LLVMBuildStore(gallivm->builder, vertices_count, mesh->vertex_count); + LLVMBuildStore(gallivm->builder, primitives_count, mesh->prim_count); +} + +static void +mesh_convert_to_aos(struct gallivm_state *gallivm, + nir_shader *nir, + bool vert_only, + LLVMTypeRef io_type, + LLVMValueRef io, + LLVMValueRef outputs, + LLVMValueRef clipmask, + LLVMValueRef vertex_index, + struct lp_type soa_type, + int primid_slot, + boolean need_edgeflag) +{ + LLVMBuilderRef builder = gallivm->builder; + LLVMValueRef inds[3]; + LLVMTypeRef output_type = create_mesh_jit_output_type_deref(gallivm); +#if DEBUG_STORE + lp_build_printf(gallivm, " # storing begin\n"); +#endif + int first_per_prim_attrib = -1; + nir_foreach_shader_out_variable(var, nir) { + if (var->data.per_primitive) { + first_per_prim_attrib = var->data.driver_location; + break; + } + } + nir_foreach_shader_out_variable(var, nir) { + + if (vert_only && 
var->data.per_primitive) + continue; + if (!vert_only && !var->data.per_primitive) + continue; + int attrib = var->data.driver_location; + int slots = glsl_count_attribute_slots(glsl_get_array_element(var->type), false); + + for (unsigned s = 0; s < slots; s++) { + LLVMValueRef soa[TGSI_NUM_CHANNELS]; + LLVMValueRef aos[LP_MAX_VECTOR_WIDTH / 32]; + for (unsigned chan = 0; chan < TGSI_NUM_CHANNELS; ++chan) { + inds[0] = vertex_index; + inds[1] = lp_build_const_int32(gallivm, attrib); + inds[2] = lp_build_const_int32(gallivm, chan); + + LLVMValueRef res = LLVMBuildGEP2(builder, output_type, outputs, inds, 3, ""); + LLVMTypeRef single_type = (attrib == primid_slot) ? lp_build_int_elem_type(gallivm, soa_type) : lp_build_elem_type(gallivm, soa_type); + LLVMValueRef out = LLVMBuildLoad2(builder, single_type, res, ""); + lp_build_name(out, "output%u.%c", attrib, "xyzw"[chan]); +#if DEBUG_STORE + lp_build_printf(gallivm, "output %d : %d ", + LLVMConstInt(LLVMInt32TypeInContext(gallivm->context), + attrib, 0), + LLVMConstInt(LLVMInt32TypeInContext(gallivm->context), + chan, 0)); + lp_build_print_value(gallivm, "val = ", out); + { + LLVMValueRef iv = + LLVMBuildBitCast(builder, out, lp_build_int_elem_type(gallivm, soa_type), ""); + + lp_build_print_value(gallivm, " ival = ", iv); + } +#endif + soa[chan] = out; + } + LLVMTypeRef float_type = LLVMFloatTypeInContext(gallivm->context); + aos[0] = LLVMGetUndef(LLVMVectorType(float_type, 4)); + for (unsigned i = 0; i < 4; i++) + aos[0] = LLVMBuildInsertElement(builder, aos[0], soa[i], lp_build_const_int32(gallivm, i), ""); + int aos_attrib = attrib; + if (var->data.per_primitive) + aos_attrib -= first_per_prim_attrib; + draw_store_aos_array(gallivm, + soa_type, + io_type, + io, + NULL, + aos, + aos_attrib, + clipmask, + need_edgeflag, var->data.per_primitive); + attrib++; + } + } +#if DEBUG_STORE + lp_build_printf(gallivm, " # storing end\n"); +#endif +} + static void generate_compute(struct llvmpipe_context *lp, struct 
lp_compute_shader *shader, @@ -108,15 +327,25 @@ generate_compute(struct llvmpipe_context *lp, LLVMValueRef block_x_size_arg, block_y_size_arg, block_z_size_arg; LLVMValueRef grid_x_arg, grid_y_arg, grid_z_arg; LLVMValueRef grid_size_x_arg, grid_size_y_arg, grid_size_z_arg; - LLVMValueRef work_dim_arg, thread_data_ptr; + LLVMValueRef work_dim_arg, draw_id_arg, thread_data_ptr, io_ptr; LLVMBasicBlockRef block; LLVMBuilderRef builder; struct lp_build_sampler_soa *sampler; struct lp_build_image_soa *image; LLVMValueRef function, coro; struct lp_type cs_type; + struct lp_mesh_llvm_iface mesh_iface; + bool is_mesh = false; unsigned i; + LLVMValueRef output_array = NULL; + if (shader->base.type == PIPE_SHADER_IR_NIR) { + struct nir_shader *nir = shader->base.ir.nir; + if (nir->info.stage == MESA_SHADER_MESH) { + is_mesh = true; + } + } + /* * This function has two parts * a) setup the coroutine execution environment loop. @@ -146,6 +375,11 @@ generate_compute(struct llvmpipe_context *lp, arg_types[CS_ARG_GRID_SIZE_Y] = int32_type; /* grid_size_y */ arg_types[CS_ARG_GRID_SIZE_Z] = int32_type; /* grid_size_z */ arg_types[CS_ARG_WORK_DIM] = int32_type; /* work dim */ + arg_types[CS_ARG_DRAW_ID] = int32_type; /* draw id */ + if (variant->jit_vertex_header_ptr_type) + arg_types[CS_ARG_VERTEX_DATA] = variant->jit_vertex_header_ptr_type; /* mesh shaders only */ + else + arg_types[CS_ARG_VERTEX_DATA] = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); /* mesh shaders only */ arg_types[CS_ARG_PER_THREAD_DATA] = variant->jit_cs_thread_data_ptr_type; /* per thread data */ arg_types[CS_ARG_CORO_X_LOOPS] = int32_type; /* coro only - num X loops */ arg_types[CS_ARG_CORO_PARTIALS] = int32_type; /* coro only - partials */ @@ -154,11 +388,13 @@ generate_compute(struct llvmpipe_context *lp, arg_types[CS_ARG_CORO_BLOCK_Z_SIZE] = int32_type; /* coro block_z_size */ arg_types[CS_ARG_CORO_IDX] = int32_type; /* coro idx */ arg_types[CS_ARG_CORO_MEM] = 
LLVMPointerType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), 0); + arg_types[CS_ARG_CORO_OUTPUTS] = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); /* mesh shaders only */ + func_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context), arg_types, CS_ARG_OUTER_COUNT, 0); coro_func_type = LLVMFunctionType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), - arg_types, CS_ARG_MAX, 0); + arg_types, CS_ARG_MAX - (!is_mesh), 0); function = LLVMAddFunction(gallivm->module, func_name, func_type); LLVMSetFunctionCallConv(function, LLVMCCallConv); @@ -169,7 +405,7 @@ generate_compute(struct llvmpipe_context *lp, variant->function = function; - for (i = 0; i < CS_ARG_MAX; ++i) { + for (i = 0; i < CS_ARG_MAX - !is_mesh; ++i) { if (LLVMGetTypeKind(arg_types[i]) == LLVMPointerTypeKind) { lp_add_function_attr(coro, i + 1, LP_FUNC_ATTR_NOALIAS); if (i < CS_ARG_OUTER_COUNT) @@ -192,6 +428,8 @@ generate_compute(struct llvmpipe_context *lp, grid_size_y_arg = LLVMGetParam(function, CS_ARG_GRID_SIZE_Y); grid_size_z_arg = LLVMGetParam(function, CS_ARG_GRID_SIZE_Z); work_dim_arg = LLVMGetParam(function, CS_ARG_WORK_DIM); + draw_id_arg = LLVMGetParam(function, CS_ARG_DRAW_ID); + io_ptr = LLVMGetParam(function, CS_ARG_VERTEX_DATA); thread_data_ptr = LLVMGetParam(function, CS_ARG_PER_THREAD_DATA); lp_build_name(context_ptr, "context"); @@ -206,7 +444,9 @@ generate_compute(struct llvmpipe_context *lp, lp_build_name(grid_size_y_arg, "grid_size_y"); lp_build_name(grid_size_z_arg, "grid_size_z"); lp_build_name(work_dim_arg, "work_dim"); + lp_build_name(draw_id_arg, "draw_id"); lp_build_name(thread_data_ptr, "thread_data"); + lp_build_name(io_ptr, "vertex_io"); block = LLVMAppendBasicBlockInContext(gallivm->context, function, "entry"); builder = gallivm->builder; @@ -217,6 +457,12 @@ generate_compute(struct llvmpipe_context *lp, key->nr_sampler_views)); image = lp_bld_llvm_image_soa_create(lp_cs_variant_key_images(key), key->nr_images); + if 
(is_mesh) { + struct nir_shader *nir = shader->base.ir.nir; + LLVMTypeRef output_type = create_mesh_jit_output_type_deref(gallivm); + output_array = lp_build_array_alloca(gallivm, output_type, lp_build_const_int32(gallivm, align(MAX2(nir->info.mesh.max_primitives_out, nir->info.mesh.max_vertices_out), 8)), "outputs"); + } + struct lp_build_loop_state loop_state[4]; LLVMValueRef num_x_loop; LLVMValueRef vec_length = lp_build_const_int32(gallivm, cs_type.length); @@ -265,6 +511,8 @@ generate_compute(struct llvmpipe_context *lp, args[CS_ARG_GRID_SIZE_Y] = grid_size_y_arg; args[CS_ARG_GRID_SIZE_Z] = grid_size_z_arg; args[CS_ARG_WORK_DIM] = work_dim_arg; + args[CS_ARG_DRAW_ID] = draw_id_arg; + args[CS_ARG_VERTEX_DATA] = io_ptr; args[CS_ARG_PER_THREAD_DATA] = thread_data_ptr; args[CS_ARG_CORO_X_LOOPS] = num_x_loop; args[CS_ARG_CORO_PARTIALS] = partials; @@ -284,6 +532,10 @@ generate_compute(struct llvmpipe_context *lp, args[CS_ARG_CORO_IDX] = coro_hdl_idx; args[CS_ARG_CORO_MEM] = coro_mem; + + if (is_mesh) + args[CS_ARG_CORO_OUTPUTS] = output_array; + LLVMValueRef coro_entry = LLVMBuildGEP2(gallivm->builder, hdl_ptr_type, coro_hdls, &coro_hdl_idx, 1, ""); LLVMValueRef coro_hdl = LLVMBuildLoad2(gallivm->builder, hdl_ptr_type, coro_entry, "coro_hdl"); @@ -293,7 +545,7 @@ generate_compute(struct llvmpipe_context *lp, lp_build_const_int32(gallivm, 0), ""); /* first time here - call the coroutine function entry point */ lp_build_if(&ifstate, gallivm, cmp); - LLVMValueRef coro_ret = LLVMBuildCall2(gallivm->builder, coro_func_type, coro, args, CS_ARG_MAX, ""); + LLVMValueRef coro_ret = LLVMBuildCall2(gallivm->builder, coro_func_type, coro, args, CS_ARG_MAX - !is_mesh, ""); LLVMBuildStore(gallivm->builder, coro_ret, coro_entry); lp_build_else(&ifstate); /* subsequent calls for this invocation - check if done. 
*/ @@ -344,6 +596,8 @@ generate_compute(struct llvmpipe_context *lp, grid_size_y_arg = LLVMGetParam(coro, CS_ARG_GRID_SIZE_Y); grid_size_z_arg = LLVMGetParam(coro, CS_ARG_GRID_SIZE_Z); work_dim_arg = LLVMGetParam(coro, CS_ARG_WORK_DIM); + draw_id_arg = LLVMGetParam(coro, CS_ARG_DRAW_ID); + io_ptr = LLVMGetParam(coro, CS_ARG_VERTEX_DATA); thread_data_ptr = LLVMGetParam(coro, CS_ARG_PER_THREAD_DATA); num_x_loop = LLVMGetParam(coro, CS_ARG_CORO_X_LOOPS); partials = LLVMGetParam(coro, CS_ARG_CORO_PARTIALS); @@ -352,12 +606,15 @@ generate_compute(struct llvmpipe_context *lp, block_z_size_arg = LLVMGetParam(coro, CS_ARG_CORO_BLOCK_Z_SIZE); LLVMValueRef coro_idx = LLVMGetParam(coro, CS_ARG_CORO_IDX); coro_mem = LLVMGetParam(coro, CS_ARG_CORO_MEM); + if (is_mesh) + output_array = LLVMGetParam(coro, CS_ARG_CORO_OUTPUTS); block = LLVMAppendBasicBlockInContext(gallivm->context, coro, "entry"); LLVMPositionBuilderAtEnd(builder, block); { LLVMValueRef consts_ptr; LLVMValueRef ssbo_ptr; LLVMValueRef shared_ptr; + LLVMValueRef payload_ptr; LLVMValueRef kernel_args_ptr; struct lp_build_mask_context mask; struct lp_bld_tgsi_system_values system_values; @@ -372,6 +629,9 @@ generate_compute(struct llvmpipe_context *lp, shared_ptr = lp_jit_cs_thread_data_shared(gallivm, variant->jit_cs_thread_data_type, thread_data_ptr); + payload_ptr = lp_jit_cs_thread_data_payload(gallivm, + variant->jit_cs_thread_data_type, + thread_data_ptr); LLVMValueRef coro_num_hdls = LLVMBuildMul(gallivm->builder, num_x_loop, block_y_size_arg, ""); coro_num_hdls = LLVMBuildMul(gallivm->builder, coro_num_hdls, block_z_size_arg, ""); @@ -410,6 +670,7 @@ generate_compute(struct llvmpipe_context *lp, system_values.grid_size = LLVMBuildInsertElement(builder, system_values.grid_size, gstids[i], lp_build_const_int32(gallivm, i), ""); system_values.work_dim = work_dim_arg; + system_values.draw_id = draw_id_arg; /* subgroup_id = ((z * block_size_x * block_size_y) + (y * block_size_x) + x) / subgroup_size * @@ -470,6 
+731,16 @@ generate_compute(struct llvmpipe_context *lp, coro_info.suspend = sus_block; coro_info.cleanup = clean_block; + if (is_mesh) { + LLVMValueRef vertex_count = lp_build_alloca(gallivm, LLVMInt32TypeInContext(gallivm->context), "vertex_count"); + LLVMValueRef primitive_count = lp_build_alloca(gallivm, LLVMInt32TypeInContext(gallivm->context), "prim_count"); + mesh_iface.base.emit_store_output = lp_mesh_llvm_emit_store_output; + mesh_iface.base.emit_vertex_and_primitive_count = lp_mesh_emit_vertex_and_primitive_count; + mesh_iface.vertex_count = vertex_count; + mesh_iface.prim_count = primitive_count; + mesh_iface.outputs = output_array; + } + struct lp_build_tgsi_params params; memset(&params, 0, sizeof(params)); @@ -486,11 +757,13 @@ generate_compute(struct llvmpipe_context *lp, params.ssbo_ptr = ssbo_ptr; params.image = image; params.shared_ptr = shared_ptr; + params.payload_ptr = payload_ptr; params.coro = &coro_info; params.kernel_args = kernel_args_ptr; params.aniso_filter_table = lp_jit_resources_aniso_filter_table(gallivm, variant->jit_resources_type, resources_ptr); + params.mesh_iface = &mesh_iface.base; if (shader->base.type == PIPE_SHADER_IR_TGSI) lp_build_tgsi_soa(gallivm, shader->base.tokens, &params, NULL); @@ -498,6 +771,73 @@ generate_compute(struct llvmpipe_context *lp, lp_build_nir_soa(gallivm, shader->base.ir.nir, &params, NULL); + if (is_mesh) { + LLVMTypeRef i32t = LLVMInt32TypeInContext(gallivm->context); + LLVMValueRef clipmask = lp_build_const_int_vec(gallivm, + lp_int_type(cs_type), 0); + + struct lp_build_if_state iter0state; + LLVMValueRef is_iter0 = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, coro_idx, + lp_build_const_int32(gallivm, 0), ""); + LLVMValueRef vertex_count = LLVMBuildLoad2(gallivm->builder, i32t, mesh_iface.vertex_count, ""); + LLVMValueRef prim_count = LLVMBuildLoad2(gallivm->builder, i32t, mesh_iface.prim_count, ""); + + LLVMValueRef vert_count_ptr, prim_count_ptr; + LLVMValueRef indices = lp_build_const_int32(gallivm, 1); + 
vert_count_ptr = LLVMBuildGEP2(gallivm->builder, i32t, io_ptr, &indices, 1, ""); + indices = lp_build_const_int32(gallivm, 2); + prim_count_ptr = LLVMBuildGEP2(gallivm->builder, i32t, io_ptr, &indices, 1, ""); + + lp_build_if(&iter0state, gallivm, is_iter0); + LLVMBuildStore(gallivm->builder, vertex_count, vert_count_ptr); + LLVMBuildStore(gallivm->builder, prim_count, prim_count_ptr); + lp_build_endif(&iter0state); + + LLVMBasicBlockRef resume = lp_build_insert_new_block(gallivm, "resume"); + + lp_build_coro_suspend_switch(gallivm, params.coro, resume, false); + LLVMPositionBuilderAtEnd(gallivm->builder, resume); + + vertex_count = LLVMBuildLoad2(gallivm->builder, i32t, vert_count_ptr, ""); + prim_count = LLVMBuildLoad2(gallivm->builder, i32t, prim_count_ptr, ""); + + nir_shader *nir = shader->base.ir.nir; + int per_prim_count = util_bitcount64(nir->info.per_primitive_outputs); + int out_count = util_bitcount64(nir->info.outputs_written); + int per_vert_count = out_count - per_prim_count; + int vsize = (sizeof(struct vertex_header) + per_vert_count * 4 * sizeof(float)) * 8; + int psize = (per_prim_count * 4 * sizeof(float)) * 8; + struct lp_build_loop_state vertex_loop_state; + + lp_build_loop_begin(&vertex_loop_state, gallivm, + lp_build_const_int32(gallivm, 0)); + LLVMValueRef io; + io = LLVMBuildPtrToInt(gallivm->builder, io_ptr, LLVMInt64TypeInContext(gallivm->context), ""); + io = LLVMBuildAdd(builder, io, LLVMBuildZExt(builder, LLVMBuildMul(builder, vertex_loop_state.counter, lp_build_const_int32(gallivm, vsize), ""), LLVMInt64TypeInContext(gallivm->context), ""), ""); + io = LLVMBuildIntToPtr(gallivm->builder, io, LLVMPointerType(LLVMVoidTypeInContext(gallivm->context), 0), ""); + mesh_convert_to_aos(gallivm, shader->base.ir.nir, true, variant->jit_vertex_header_type, + io, output_array, clipmask, + vertex_loop_state.counter, lp_elem_type(cs_type), -1, FALSE); + lp_build_loop_end_cond(&vertex_loop_state, + vertex_count, + NULL, LLVMIntUGE); + + struct 
lp_build_loop_state prim_loop_state; + lp_build_loop_begin(&prim_loop_state, gallivm, + lp_build_const_int32(gallivm, 0)); + io = LLVMBuildPtrToInt(gallivm->builder, io_ptr, LLVMInt64TypeInContext(gallivm->context), ""); + LLVMValueRef prim_offset = LLVMBuildMul(builder, prim_loop_state.counter, lp_build_const_int32(gallivm, psize), ""); + prim_offset = LLVMBuildAdd(builder, prim_offset, lp_build_const_int32(gallivm, vsize * (nir->info.mesh.max_vertices_out + 8)), ""); + io = LLVMBuildAdd(builder, io, LLVMBuildZExt(builder, prim_offset, LLVMInt64TypeInContext(gallivm->context), ""), ""); + io = LLVMBuildIntToPtr(gallivm->builder, io, LLVMPointerType(LLVMVoidTypeInContext(gallivm->context), 0), ""); + mesh_convert_to_aos(gallivm, shader->base.ir.nir, false, variant->jit_prim_type, + io, output_array, clipmask, + prim_loop_state.counter, lp_elem_type(cs_type), -1, FALSE); + lp_build_loop_end_cond(&prim_loop_state, + prim_count, + NULL, LLVMIntUGE); + } + mask_val = lp_build_mask_end(&mask); lp_build_coro_suspend_switch(gallivm, &coro_info, NULL, true); @@ -837,7 +1177,8 @@ generate_variant(struct llvmpipe_context *lp, memset(variant, 0, sizeof(*variant)); char module_name[64]; - const char *shname = "cs"; + const char *shname = sh_type == PIPE_SHADER_MESH ? "ms" : + (sh_type == PIPE_SHADER_TASK ? 
"ts" : "cs"); snprintf(module_name, sizeof(module_name), "%s%u_variant%u", shname, shader->no, shader->variants_created); @@ -871,6 +1212,16 @@ generate_variant(struct llvmpipe_context *lp, lp_jit_init_cs_types(variant); + if (sh_type == PIPE_SHADER_MESH) { + struct nir_shader *nir = shader->base.ir.nir; + int per_prim_count = util_bitcount64(nir->info.per_primitive_outputs); + int out_count = util_bitcount64(nir->info.outputs_written); + int per_vert_count = out_count - per_prim_count; + variant->jit_vertex_header_type = lp_build_create_jit_vertex_header_type(variant->gallivm, per_vert_count); + variant->jit_vertex_header_ptr_type = LLVMPointerType(variant->jit_vertex_header_type, 0); + variant->jit_prim_type = LLVMArrayType(LLVMArrayType(LLVMFloatTypeInContext(variant->gallivm->context), 4), per_prim_count); + } + generate_compute(lp, shader, variant); gallivm_compile_module(variant->gallivm); @@ -1443,19 +1794,41 @@ cs_exec_fn(void *init_data, int iter_idx, struct lp_cs_local_mem *lmem) memset(lmem->local_mem_ptr, 0, job_info->req_local_mem); thread_data.shared = lmem->local_mem_ptr; - unsigned grid_z = iter_idx / (job_info->grid_size[0] * job_info->grid_size[1]); - unsigned grid_y = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1]))) / job_info->grid_size[0]; - unsigned grid_x = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1])) - (grid_y * job_info->grid_size[0])); + thread_data.payload = job_info->payload; + + unsigned grid_z, grid_y, grid_x; + + if (job_info->use_iters) { + grid_z = iter_idx / (job_info->iter_size[0] * job_info->iter_size[1]); + grid_y = (iter_idx - (grid_z * (job_info->iter_size[0] * job_info->iter_size[1]))) / job_info->iter_size[0]; + grid_x = (iter_idx - (grid_z * (job_info->iter_size[0] * job_info->iter_size[1])) - (grid_y * job_info->iter_size[0])); + } else { + grid_z = iter_idx / (job_info->grid_size[0] * job_info->grid_size[1]); + grid_y = (iter_idx - (grid_z * (job_info->grid_size[0] * 
job_info->grid_size[1]))) / job_info->grid_size[0]; + grid_x = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1])) - (grid_y * job_info->grid_size[0])); + } grid_z += job_info->grid_base[2]; grid_y += job_info->grid_base[1]; grid_x += job_info->grid_base[0]; struct lp_compute_shader_variant *variant = job_info->current->variant; + + void *io_ptr = NULL; + if (job_info->io) { + size_t io_offset = job_info->io_stride * iter_idx; + io_ptr = (char *)job_info->io + io_offset; + } + if (thread_data.payload) { + size_t payload_offset = job_info->payload_stride * iter_idx; + thread_data.payload = (char *)thread_data.payload + payload_offset; + } variant->jit_function(&job_info->current->jit_context, &job_info->current->jit_resources, job_info->block_size[0], job_info->block_size[1], job_info->block_size[2], grid_x, grid_y, grid_z, - job_info->grid_size[0], job_info->grid_size[1], job_info->grid_size[2], job_info->work_dim, + job_info->grid_size[0], job_info->grid_size[1], job_info->grid_size[2], + job_info->work_dim, job_info->draw_id, + io_ptr, &thread_data); } diff --git a/src/gallium/drivers/llvmpipe/lp_state_cs.h b/src/gallium/drivers/llvmpipe/lp_state_cs.h index 45c0aa8..855f195 100644 --- a/src/gallium/drivers/llvmpipe/lp_state_cs.h +++ b/src/gallium/drivers/llvmpipe/lp_state_cs.h @@ -86,6 +86,10 @@ struct lp_compute_shader_variant LLVMTypeRef jit_resources_ptr_type; LLVMTypeRef jit_cs_thread_data_ptr_type; + /* for mesh shaders */ + LLVMTypeRef jit_vertex_header_type; + LLVMTypeRef jit_vertex_header_ptr_type; + LLVMTypeRef jit_prim_type; LLVMValueRef function; lp_jit_cs_func jit_function; -- 2.7.4