gallivm: add support for function calling
authorDave Airlie <airlied@redhat.com>
Tue, 15 Aug 2023 05:54:19 +0000 (15:54 +1000)
committerMarge Bot <emma+marge@anholt.net>
Tue, 12 Sep 2023 01:57:50 +0000 (01:57 +0000)
This adds support for calling functions in compute shaders.

Functions are passed two implicit arguments:
- the current exec mask
- a context containing all the info needed for intrinsics to work
  when not in the toplevel.

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24687>

src/gallium/auxiliary/gallivm/lp_bld_nir.c
src/gallium/auxiliary/gallivm/lp_bld_nir.h
src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
src/gallium/auxiliary/gallivm/lp_bld_tgsi.h

index aab7c80..ce49783 100644 (file)
@@ -2100,6 +2100,20 @@ visit_payload_atomic(struct lp_build_nir_context *bld_base,
                         offset, val, val2, &result[0]);
 }
 
+/*
+ * Handle nir_intrinsic_load_param: fetch the LLVM parameter of the function
+ * currently being built that corresponds to the NIR parameter index.  The
+ * first LP_RESV_FUNC_ARGS LLVM parameters are reserved, so the NIR index is
+ * offset by that amount.  A single-component load yields the parameter
+ * itself; otherwise each component is extracted from the aggregate.
+ */
+static void visit_load_param(struct lp_build_nir_context *bld_base,
+                             nir_intrinsic_instr *instr,
+                             LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   const unsigned llvm_index = nir_intrinsic_param_idx(instr) + LP_RESV_FUNC_ARGS;
+   LLVMValueRef value = LLVMGetParam(bld_base->func, llvm_index);
+
+   if (instr->num_components == 1) {
+      result[0] = value;
+      return;
+   }
+
+   for (unsigned comp = 0; comp < instr->num_components; comp++)
+      result[comp] = LLVMBuildExtractValue(gallivm->builder, value, comp, "");
+}
+
 static void
 visit_intrinsic(struct lp_build_nir_context *bld_base,
                 nir_intrinsic_instr *instr)
@@ -2305,6 +2319,9 @@ visit_intrinsic(struct lp_build_nir_context *bld_base,
                                                get_src(bld_base, instr->src[0]),
                                                get_src(bld_base, instr->src[1]));
       break;
+   case nir_intrinsic_load_param:
+      visit_load_param(bld_base, instr, result);
+      break;
    default:
       fprintf(stderr, "Unsupported intrinsic: ");
       nir_print_instr(&instr->instr, stderr);
@@ -2729,6 +2746,29 @@ visit_deref(struct lp_build_nir_context *bld_base,
    assign_ssa(bld_base, instr->def.index, result);
 }
 
+/*
+ * Emit a call for a NIR call instruction.
+ *
+ * The callee is looked up in bld_base->fns using instr->callee as the key.
+ * The argument array starts with LP_RESV_FUNC_ARGS reserved slots, left NULL
+ * here (calloc zeroes them); the bld_base->call hook is expected to fill
+ * them in (emit_call stores the exec mask and the call context pointer).
+ * The explicit NIR parameters follow the reserved slots.
+ */
+static void
+visit_call(struct lp_build_nir_context *bld_base,
+           nir_call_instr *instr)
+{
+   const unsigned num_args = instr->num_params + LP_RESV_FUNC_ARGS;
+   struct hash_entry *entry = _mesa_hash_table_search(bld_base->fns, instr->callee);
+
+   /* A missing entry would otherwise be dereferenced below. */
+   if (!entry) {
+      fprintf(stderr, "Call to a function with no registered LLVM function: ");
+      nir_print_instr(&instr->instr, stderr);
+      fprintf(stderr, "\n");
+      assert(0);
+      return;
+   }
+
+   struct lp_build_fn *fn = entry->data;
+   LLVMValueRef *args = calloc(num_args, sizeof(*args));
+
+   /* Keep the debug-build assert, but do not write through NULL when
+    * asserts are compiled out. */
+   assert(args);
+   if (!args)
+      return;
+
+   for (unsigned i = 0; i < instr->num_params; i++) {
+      LLVMValueRef arg = get_src(bld_base, instr->params[i]);
+
+      /* 32-bit values that currently have the base vector type are cast to
+       * the 32-bit integer type before being passed. */
+      if (nir_src_bit_size(instr->params[i]) == 32 && LLVMTypeOf(arg) == bld_base->base.vec_type)
+         arg = cast_type(bld_base, arg, nir_type_int, 32);
+      args[i + LP_RESV_FUNC_ARGS] = arg;
+   }
+
+   bld_base->call(bld_base, fn, num_args, args);
+   free(args);
+}
 
 static void
 visit_block(struct lp_build_nir_context *bld_base, nir_block *block)
@@ -2760,6 +2800,9 @@ visit_block(struct lp_build_nir_context *bld_base, nir_block *block)
       case nir_instr_type_deref:
          visit_deref(bld_base, nir_instr_as_deref(instr));
          break;
+      case nir_instr_type_call:
+         visit_call(bld_base, nir_instr_as_call(instr));
+         break;
       default:
          fprintf(stderr, "Unknown NIR instr type: ");
          nir_print_instr(instr, stderr);
index 14e77d7..3ef67ae 100644 (file)
 
 struct nir_shader;
 
+/*
+ * Two reserved function arguments are prepended to every function call:
+ * the exec mask and the call context.
+ */
+#define LP_RESV_FUNC_ARGS 2
+
 void lp_build_nir_soa(struct gallivm_state *gallivm,
                       struct nir_shader *shader,
                       const struct lp_build_tgsi_params *params,
@@ -55,6 +61,11 @@ void lp_build_nir_aos(struct gallivm_state *gallivm,
                       LLVMValueRef *outputs,
                       const struct lp_build_sampler_aos *sampler);
 
+struct lp_build_fn { /* LLVM function type/value pair for one callable function */
+   LLVMTypeRef fn_type; /* function type passed to LLVMBuildCall2 by emit_call */
+   LLVMValueRef fn; /* callee value passed to LLVMBuildCall2 by emit_call */
+};
+
 struct lp_build_nir_context
 {
    struct lp_build_context base;
@@ -72,12 +83,14 @@ struct lp_build_nir_context
    LLVMValueRef *ssa_defs;
    struct hash_table *regs;
    struct hash_table *vars;
+   struct hash_table *fns;
 
    /** Value range analysis hash table used in code generation. */
    struct hash_table *range_ht;
 
    LLVMValueRef aniso_filter_table;
 
+   LLVMValueRef func;
    nir_shader *shader;
 
    void (*load_ubo)(struct lp_build_nir_context *bld_base,
@@ -243,6 +256,11 @@ struct lp_build_nir_context
                                                LLVMValueRef prim_count);
    void (*launch_mesh_workgroups)(struct lp_build_nir_context *bld_base,
                                   LLVMValueRef launch_grid);
+
+   void (*call)(struct lp_build_nir_context *bld_base,
+                struct lp_build_fn *fn,
+                int num_args,
+                LLVMValueRef *args);
 //   LLVMValueRef main_function
 };
 
@@ -299,6 +317,9 @@ struct lp_build_nir_soa_context
 
    LLVMValueRef kernel_args_ptr;
    unsigned gs_vertex_streams;
+
+   LLVMTypeRef call_context_type;
+   LLVMValueRef call_context_ptr;
 };
 
 void
@@ -389,5 +410,31 @@ lp_build_nir_sample_key(gl_shader_stage stage, nir_tex_instr *instr);
 
 void lp_img_op_from_intrinsic(struct lp_img_params *params, nir_intrinsic_instr *instr);
 
+enum lp_nir_call_context_args { /* field indices of the struct type built by lp_build_cs_func_call_context() */
+   LP_NIR_CALL_CONTEXT_CONTEXT, /* pointer to the context type */
+   LP_NIR_CALL_CONTEXT_RESOURCES, /* pointer to the resources type */
+   LP_NIR_CALL_CONTEXT_SHARED, /* shared pointer (i32 pointer) */
+   LP_NIR_CALL_CONTEXT_SCRATCH, /* scratch pointer (i8 pointer) */
+   LP_NIR_CALL_CONTEXT_WORK_DIM, /* system_values.work_dim (i32) */
+   LP_NIR_CALL_CONTEXT_THREAD_ID_0, /* system_values.thread_id[0] (i32 vector) */
+   LP_NIR_CALL_CONTEXT_THREAD_ID_1, /* system_values.thread_id[1] (i32 vector) */
+   LP_NIR_CALL_CONTEXT_THREAD_ID_2, /* system_values.thread_id[2] (i32 vector) */
+   LP_NIR_CALL_CONTEXT_BLOCK_ID_0, /* system_values.block_id[0] (i32) */
+   LP_NIR_CALL_CONTEXT_BLOCK_ID_1, /* system_values.block_id[1] (i32) */
+   LP_NIR_CALL_CONTEXT_BLOCK_ID_2, /* system_values.block_id[2] (i32) */
+   LP_NIR_CALL_CONTEXT_GRID_SIZE_0, /* system_values.grid_size[0] (i32) */
+   LP_NIR_CALL_CONTEXT_GRID_SIZE_1, /* system_values.grid_size[1] (i32) */
+   LP_NIR_CALL_CONTEXT_GRID_SIZE_2, /* system_values.grid_size[2] (i32) */
+   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0, /* system_values.block_size[0] (i32) */
+   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1, /* system_values.block_size[1] (i32) */
+   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2, /* system_values.block_size[2] (i32) */
+   LP_NIR_CALL_CONTEXT_MAX_ARGS, /* number of fields; used as the array/struct element count */
+};
+
+LLVMTypeRef
+lp_build_cs_func_call_context(struct gallivm_state *gallivm, int length,
+                              LLVMTypeRef context_type, LLVMTypeRef resources_type);
+
+
 
 #endif
index baa533e..06dde5c 100644 (file)
@@ -2694,6 +2694,19 @@ emit_launch_mesh_workgroups(struct lp_build_nir_context *bld_base,
    lp_build_endif(&ifthen);
 }
 
+/*
+ * SoA implementation of the bld_base->call hook.  Fills the two reserved
+ * leading argument slots (exec mask, then call context pointer) and emits
+ * the LLVM call.  The caller owns `args` and must have sized it to hold
+ * `num_args` entries including the reserved slots.
+ */
+static void
+emit_call(struct lp_build_nir_context *bld_base,
+          struct lp_build_fn *fn,
+          int num_args,
+          LLVMValueRef *args)
+{
+   struct lp_build_nir_soa_context *soa = (struct lp_build_nir_soa_context *)bld_base;
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMValueRef exec_mask = mask_vec(bld_base);
+
+   args[0] = exec_mask;
+   args[1] = soa->call_context_ptr;
+
+   LLVMBuildCall2(gallivm->builder, fn->fn_type, fn->fn, args, num_args, "");
+}
+
 static LLVMValueRef get_scratch_thread_offsets(struct gallivm_state *gallivm,
                                                struct lp_type type,
                                                unsigned scratch_size)
@@ -2800,6 +2813,90 @@ emit_clock(struct lp_build_nir_context *bld_base,
    dst[1] = lp_build_broadcast_scalar(uint_bld, hi);
 }
 
+/*
+ * Build the LLVM struct type of the call context handed to called
+ * functions.  Field order follows enum lp_nir_call_context_args.
+ *
+ * length:          number of lanes of the per-thread i32 vectors
+ * context_type:    pointee type of the CONTEXT field
+ * resources_type:  pointee type of the RESOURCES field
+ *
+ * Returns the (non-packed) struct type.
+ */
+LLVMTypeRef
+lp_build_cs_func_call_context(struct gallivm_state *gallivm, int length,
+                              LLVMTypeRef context_type, LLVMTypeRef resources_type)
+{
+   LLVMTypeRef i8_type = LLVMInt8TypeInContext(gallivm->context);
+   LLVMTypeRef i32_type = LLVMInt32TypeInContext(gallivm->context);
+   LLVMTypeRef i32_vec_type = LLVMVectorType(i32_type, length);
+   LLVMTypeRef fields[LP_NIR_CALL_CONTEXT_MAX_ARGS];
+
+   fields[LP_NIR_CALL_CONTEXT_CONTEXT] = LLVMPointerType(context_type, 0);
+   fields[LP_NIR_CALL_CONTEXT_RESOURCES] = LLVMPointerType(resources_type, 0);
+   fields[LP_NIR_CALL_CONTEXT_SHARED] = LLVMPointerType(i32_type, 0);  /* shared_ptr */
+   fields[LP_NIR_CALL_CONTEXT_SCRATCH] = LLVMPointerType(i8_type, 0);  /* scratch ptr */
+   fields[LP_NIR_CALL_CONTEXT_WORK_DIM] = i32_type;                    /* work_dim */
+
+   /* The x/y/z enumerators of each system value are consecutive. */
+   for (unsigned axis = 0; axis < 3; axis++) {
+      fields[LP_NIR_CALL_CONTEXT_THREAD_ID_0 + axis] = i32_vec_type;   /* system_values.thread_id[axis] */
+      fields[LP_NIR_CALL_CONTEXT_BLOCK_ID_0 + axis] = i32_type;        /* system_values.block_id[axis] */
+      fields[LP_NIR_CALL_CONTEXT_GRID_SIZE_0 + axis] = i32_type;       /* system_values.grid_size[axis] */
+      fields[LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0 + axis] = i32_type;      /* system_values.block_size[axis] */
+   }
+
+   return LLVMStructTypeInContext(gallivm->context, fields, LP_NIR_CALL_CONTEXT_MAX_ARGS, 0);
+}
+
+/*
+ * Create the call context for the current shader: allocate storage of
+ * bld->call_context_type via lp_build_alloca, assemble a struct value from
+ * the builder's context/resources/shared/scratch pointers and system
+ * values, and store it.  The resulting pointer is left in
+ * bld->call_context_ptr.
+ *
+ * When no shared or scratch pointer is available a null pointer is stored
+ * instead.  The null constant must have the same type as the struct field
+ * declared in lp_build_cs_func_call_context() (i32 pointer for SHARED, i8
+ * pointer for SCRATCH); a mismatching pointer type is not a valid
+ * insertvalue operand on LLVM builds that use typed pointers.
+ */
+static void
+build_call_context(struct lp_build_nir_soa_context *bld)
+{
+   struct gallivm_state *gallivm = bld->bld_base.base.gallivm;
+
+   bld->call_context_ptr = lp_build_alloca(gallivm, bld->call_context_type, "callcontext");
+
+   LLVMValueRef shared = bld->shared_ptr;
+   if (!shared)
+      shared = LLVMConstNull(LLVMPointerType(LLVMInt32TypeInContext(gallivm->context), 0));
+
+   LLVMValueRef scratch = bld->scratch_ptr;
+   if (!scratch)
+      scratch = LLVMConstNull(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0));
+
+   LLVMValueRef call_context = LLVMGetUndef(bld->call_context_type);
+   call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                       bld->context_ptr, LP_NIR_CALL_CONTEXT_CONTEXT, "");
+   call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                       bld->resources_ptr, LP_NIR_CALL_CONTEXT_RESOURCES, "");
+   call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                       shared, LP_NIR_CALL_CONTEXT_SHARED, "");
+   call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                       scratch, LP_NIR_CALL_CONTEXT_SCRATCH, "");
+   call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                       bld->system_values.work_dim, LP_NIR_CALL_CONTEXT_WORK_DIM, "");
+
+   /* The x/y/z enumerators of each system value are consecutive, so the
+    * field index is the _0 enumerator plus the axis. */
+   for (unsigned axis = 0; axis < 3; axis++)
+      call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                          bld->system_values.thread_id[axis],
+                                          LP_NIR_CALL_CONTEXT_THREAD_ID_0 + axis, "");
+   for (unsigned axis = 0; axis < 3; axis++)
+      call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                          bld->system_values.block_id[axis],
+                                          LP_NIR_CALL_CONTEXT_BLOCK_ID_0 + axis, "");
+   for (unsigned axis = 0; axis < 3; axis++)
+      call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                          bld->system_values.grid_size[axis],
+                                          LP_NIR_CALL_CONTEXT_GRID_SIZE_0 + axis, "");
+   for (unsigned axis = 0; axis < 3; axis++)
+      call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
+                                          bld->system_values.block_size[axis],
+                                          LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0 + axis, "");
+
+   LLVMBuildStore(gallivm->builder, call_context, bld->call_context_ptr);
+}
+
 void lp_build_nir_soa_func(struct gallivm_state *gallivm,
                            struct nir_shader *shader,
                            nir_function_impl *impl,
@@ -2911,6 +3008,7 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    bld.bld_base.read_invocation = emit_read_invocation;
    bld.bld_base.helper_invocation = emit_helper_invocation;
    bld.bld_base.interp_at = emit_interp_at;
+   bld.bld_base.call = emit_call;
    bld.bld_base.load_scratch = emit_load_scratch;
    bld.bld_base.store_scratch = emit_store_scratch;
    bld.bld_base.load_const = emit_load_const;
@@ -2918,6 +3016,8 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    bld.bld_base.set_vertex_and_primitive_count = emit_set_vertex_and_primitive_count;
    bld.bld_base.launch_mesh_workgroups = emit_launch_mesh_workgroups;
 
+   bld.bld_base.fns = params->fns;
+   bld.bld_base.func = params->current_func;
    bld.mask = params->mask;
    bld.inputs = params->inputs;
    bld.outputs = outputs;
@@ -2925,6 +3025,8 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    bld.ssbo_ptr = params->ssbo_ptr;
    bld.sampler = params->sampler;
 
+   bld.context_type = params->context_type;
+   bld.context_ptr = params->context_ptr;
    bld.resources_type = params->resources_type;
    bld.resources_ptr = params->resources_ptr;
    bld.thread_data_type = params->thread_data_type;
@@ -2961,18 +3063,29 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    }
    lp_exec_mask_init(&bld.exec_mask, &bld.bld_base.int_bld);
 
-   bld.system_values = *params->system_values;
+   if (params->system_values)
+      bld.system_values = *params->system_values;
 
    bld.bld_base.shader = shader;
 
    bld.scratch_size = ALIGN(shader->scratch_size, 8);
-   if (shader->scratch_size) {
+   if (params->scratch_ptr)
+      bld.scratch_ptr = params->scratch_ptr;
+   else if (shader->scratch_size) {
       bld.scratch_ptr = lp_build_array_alloca(gallivm,
                                               LLVMInt8TypeInContext(gallivm->context),
                                               lp_build_const_int32(gallivm, bld.scratch_size * type.length),
                                               "scratch");
    }
 
+   if (shader->info.stage == MESA_SHADER_KERNEL) {
+      bld.call_context_type = lp_build_cs_func_call_context(gallivm, type.length, bld.context_type, bld.resources_type);
+      if (!params->call_context_ptr) {
+         build_call_context(&bld);
+      } else
+         bld.call_context_ptr = params->call_context_ptr;
+   }
+
    emit_prologue(&bld);
    lp_build_nir_llvm(&bld.bld_base, shader, impl);
 
index a5a049f..a4435e2 100644 (file)
@@ -289,6 +289,10 @@ struct lp_build_tgsi_params {
    const struct lp_build_fs_iface *fs_iface;
    unsigned gs_vertex_streams;
    LLVMValueRef aniso_filter_table;
+   LLVMValueRef current_func;
+   struct hash_table *fns;
+   LLVMValueRef scratch_ptr;
+   LLVMValueRef call_context_ptr;
 };
 
 void