From 734f3388a339c7ad110dd6da33d4b3da0c0bdebf Mon Sep 17 00:00:00 2001 From: Alexander Kyte Date: Wed, 24 Apr 2019 10:53:47 -0400 Subject: [PATCH] [llvm] Propagate nonnull attribute through LLVM IR (mono/mono#13697) ## Summary This change allows LLVM to identify unnecessary null checks. We mark GOT accesses (including those made by LDSTR) and new object calls as nonnull. Since LLVM won't propagate this, we propagate from definition to usage, across casts, and from usage to definition (conditionally). This enables it to remove a significant portion of some benchmarks. In real-world code, we can expect to see this remove the unnecessary null checks made by private methods. ## Example ### C#: Note that the constant strings are trivially non-null. This PR spots and propagates that. ``` static void ThrowIfNull(string s) { if (s == null) ThrowArgumentNullException(); } static void ThrowArgumentNullException() { throw new ArgumentNullException(); } [MethodImpl(MethodImplOptions.NoInlining)] static int Bench(string a, string b, string c, string d) { ThrowIfNull(a); ThrowIfNull(b); ThrowIfNull(c); ThrowIfNull(d); return a.Length + b.Length + c.Length + d.Length; } [Benchmark(Description = nameof(NoThrowInline))] public int Test() => Bench("a", "bc", "def", "ghij"); ``` ### Before: ``` define hidden monocc i32 @NoThrowInline_MainClass_Bench_string_string_string_string(i64* %arg_a, i64* %arg_b, i64* %arg_c, i64* %arg_d) mono/mono#6 gc "mono" { BB0: br label %INIT_BB1 INIT_BB1: ; preds = %BB0 br label %INITED_BB2 INITED_BB2: ; preds = %INIT_BB1 br label %BB3 BB3: ; preds = %INITED_BB2 br label %BB2 BB2: ; preds = %BB3 notail call monocc void @NoThrowInline_MainClass_ThrowIfNull_string(i64* %arg_a) notail call monocc void @NoThrowInline_MainClass_ThrowIfNull_string(i64* %arg_b) notail call monocc void @NoThrowInline_MainClass_ThrowIfNull_string(i64* %arg_c) notail call monocc void @NoThrowInline_MainClass_ThrowIfNull_string(i64* %arg_d) %0 = bitcast i64* %arg_a to i32* %1 = getelementptr i32, i32* %0, i32 4 %t50 = load volatile i32, i32* %1 %2 = bitcast i64* %arg_b to i32* %3 = getelementptr i32, i32* %2, i32 4 %t52 = load volatile i32, i32* %3 %t53 = add i32 %t50, %t52 %4 = bitcast i64* %arg_c to i32* %5 = getelementptr i32, i32* %4, i32 4 %t55 = load volatile i32, i32* %5 %t56 = add i32 %t53, %t55 %6 = bitcast i64* %arg_d to i32* %7 = getelementptr i32, i32* %6, i32 4 %t58 = load volatile i32, i32* %7 %t60 = add i32 %t56, %t58 br label %BB1 BB1: ; preds = %BB2 ret i32 %t60 } ``` ### After: Note: safepoint in below code is added by backend, not part of this change ``` define hidden monocc i32 @NoThrowInline_MainClass_Bench_string_string_string_string(i64* nonnull %arg_a, i64* nonnull %arg_b, i64* nonnull %arg_c, i64* nonnull %arg_d) mono/mono#6 gc "mono" { BB0: %0 = getelementptr i64, i64* %arg_a, i64 2 %1 = bitcast i64* %0 to i32* %t50 = load volatile i32, i32* %1, align 4 %2 = getelementptr i64, i64* %arg_b, i64 2 %3 = bitcast i64* %2 to i32* %t52 = load volatile i32, i32* %3, align 4 %t53 = add i32 %t52, %t50 %4 = getelementptr i64, i64* %arg_c, i64 2 %5 = bitcast i64* %4 to i32* %t55 = load volatile i32, i32* %5, align 4 %t56 = add i32 %t53, %t55 %6 = getelementptr i64, i64* %arg_d, i64 2 %7 = bitcast i64* %6 to i32* %t58 = load volatile i32, i32* %7, align 4 %t60 = add i32 %t56, %t58 %8 = load i64*, i64** getelementptr inbounds ([37 x i64*], [37 x i64*]* @mono_aot_NoThrowInline_llvm_got, i64 0, i64 7), align 8 %9 = load i64, i64* %8, align 4 %10 = icmp eq i64 %9, 0 br i1 %10, label %gc.safepoint_poll.exit, label %gc.safepoint_poll.poll.i gc.safepoint_poll.poll.i: ; preds = %BB0 %11 = load void ()*, void ()** bitcast (i64** getelementptr inbounds ([37 x i64*], [37 x i64*]* @mono_aot_NoThrowInline_llvm_got, i64 0, i64 25) to void ()**), align 8 call void %11() mono/mono#8 br label %gc.safepoint_poll.exit gc.safepoint_poll.exit: ; preds = %BB0, %gc.safepoint_poll.poll.i ret i32 %t60 } ``` ## Dependencies This depends on https://github.com/mono/linker/pull/528. Commit migrated from https://github.com/mono/mono/commit/0d9e13f983ffa82e8e2360072dc67a4ee3989324 --- src/mono/mono/mini/aot-compiler.c | 73 +++++++++++ src/mono/mono/mini/aot-compiler.h | 1 + src/mono/mono/mini/mini-llvm-cpp.cpp | 127 ++++++++++++++++++- src/mono/mono/mini/mini-llvm-cpp.h | 18 +++ src/mono/mono/mini/mini-llvm.c | 233 ++++++++++++++++++++++++++++++++++- 5 files changed, 444 insertions(+), 8 deletions(-) diff --git a/src/mono/mono/mini/aot-compiler.c b/src/mono/mono/mini/aot-compiler.c index f9e045a..fe5dc7e 100644 --- a/src/mono/mono/mini/aot-compiler.c +++ b/src/mono/mono/mini/aot-compiler.c @@ -4353,6 +4353,79 @@ add_gc_wrappers (MonoAotCompile *acfg) } } +static gboolean +contains_disable_reflection_attribute (MonoCustomAttrInfo *cattr) +{ + for (int i = 0; i < cattr->num_attrs; ++i) { + MonoCustomAttrEntry *attr = &cattr->attrs [i]; + + if (!attr->ctor) + return FALSE; + + if (strcmp (m_class_get_name_space (attr->ctor->klass), "System.Runtime.CompilerServices")) + return FALSE; + + if (strcmp (m_class_get_name (attr->ctor->klass), "DisablePrivateReflectionAttribute")) + return FALSE; + } + + return TRUE; +} + +gboolean +mono_aot_can_specialize (MonoMethod *method) +{ + if (!method) + return FALSE; + + if (method->wrapper_type != MONO_WRAPPER_NONE) + return FALSE; + + // If it's not private, we can't specialize + if ((method->flags & METHOD_ATTRIBUTE_MEMBER_ACCESS_MASK) != METHOD_ATTRIBUTE_PRIVATE) + return FALSE; + + // If it has the attribute disabling the specialization, we can't specialize + // + // Set by linker, indicates that the method can be found through reflection + // and that call-site specialization shouldn't be done. + // + // Important that this attribute is used for *nothing else* + // + // If future authors make use of it (to disable more optimizations), + // change this place to use a new attribute. + ERROR_DECL (cattr_error); + MonoCustomAttrInfo *cattr = mono_custom_attrs_from_class_checked (method->klass, cattr_error); + + if (!is_ok (cattr_error)) { + mono_error_cleanup (cattr_error); + goto cleanup_false; + } else if (cattr && contains_disable_reflection_attribute (cattr)) { + goto cleanup_true; + } + + cattr = mono_custom_attrs_from_method_checked (method, cattr_error); + + if (!is_ok (cattr_error)) { + mono_error_cleanup (cattr_error); + goto cleanup_false; + } else if (cattr && contains_disable_reflection_attribute (cattr)) { + goto cleanup_true; + } else { + goto cleanup_false; + } + +cleanup_false: + if (cattr) + mono_custom_attrs_free (cattr); + return FALSE; + +cleanup_true: + if (cattr) + mono_custom_attrs_free (cattr); + return TRUE; +} + static void add_wrappers (MonoAotCompile *acfg) { diff --git a/src/mono/mono/mini/aot-compiler.h b/src/mono/mono/mini/aot-compiler.h index e5fd43f..6b8803b 100644 --- a/src/mono/mono/mini/aot-compiler.h +++ b/src/mono/mono/mini/aot-compiler.h @@ -22,6 +22,7 @@ char* mono_aot_get_plt_symbol (MonoJumpInfoType type, gconstpointe char* mono_aot_get_direct_call_symbol (MonoJumpInfoType type, gconstpointer data) MONO_LLVM_INTERNAL; int mono_aot_get_method_index (MonoMethod *method) MONO_LLVM_INTERNAL; MonoJumpInfo* mono_aot_patch_info_dup (MonoJumpInfo* ji) MONO_LLVM_INTERNAL; +gboolean mono_aot_can_specialize (MonoMethod *method) MONO_LLVM_INTERNAL; #endif diff --git a/src/mono/mono/mini/mini-llvm-cpp.cpp b/src/mono/mono/mini/mini-llvm-cpp.cpp index dcdc5ed..0b3065e 100644 --- a/src/mono/mono/mini/mini-llvm-cpp.cpp +++ b/src/mono/mono/mini/mini-llvm-cpp.cpp @@ -241,10 +241,133 @@ mono_llvm_set_preserveall_cc (LLVMValueRef func) unwrap(func)->setCallingConv (CallingConv::PreserveAll); } +// Note that in future versions of LLVM, CallInst and InvokeInst +// share a CallBase parent class that would make the below methods +// look much better + void -mono_llvm_set_call_preserveall_cc (LLVMValueRef func) +mono_llvm_set_call_preserveall_cc (LLVMValueRef wrapped_calli) { - unwrap(func)->setCallingConv (CallingConv::PreserveAll); +#if LLVM_API_VERSION > 100 + Instruction *calli = unwrap (wrapped_calli); + + if (isa (calli)) + dyn_cast(calli)->setCallingConv (CallingConv::PreserveAll); + else + dyn_cast(calli)->setCallingConv (CallingConv::PreserveAll); +#else + unwrap(wrapped_calli)->setCallingConv (CallingConv::PreserveAll); +#endif +} + +void +mono_llvm_set_call_nonnull_arg (LLVMValueRef wrapped_calli, int argNo) +{ +#if LLVM_API_VERSION > 100 + Instruction *calli = unwrap (wrapped_calli); + + if (isa (calli)) + dyn_cast(calli)->addParamAttr (argNo, Attribute::NonNull); + else + dyn_cast(calli)->addParamAttr (argNo, Attribute::NonNull); +#endif +} + +void +mono_llvm_set_call_nonnull_ret (LLVMValueRef wrapped_calli) +{ +#if LLVM_API_VERSION > 100 + Instruction *calli = unwrap (wrapped_calli); + + if (isa (calli)) + dyn_cast(calli)->addAttribute (AttributeList::ReturnIndex, Attribute::NonNull); + else + dyn_cast(calli)->addAttribute (AttributeList::ReturnIndex, Attribute::NonNull); +#endif +} + +void +mono_llvm_set_func_nonnull_arg (LLVMValueRef func, int argNo) +{ +#if LLVM_API_VERSION > 100 + unwrap(func)->addParamAttr (argNo, Attribute::NonNull); +#endif +} + +gboolean +mono_llvm_is_nonnull (LLVMValueRef wrapped) +{ +#if LLVM_API_VERSION > 100 + // Argument to function + Value *val = unwrap (wrapped); + + while (val) { + if (Argument *arg = dyn_cast (val)) { + return arg->hasNonNullAttr (); + } else if (CallInst *calli = dyn_cast (val)) { + return calli->hasRetAttr (Attribute::NonNull); + } else if (InvokeInst *calli = dyn_cast (val)) { + return calli->hasRetAttr (Attribute::NonNull); + } else if (LoadInst *loadi = dyn_cast (val)) { + return loadi->getMetadata("nonnull") != nullptr; // nonnull + } else if (Instruction *inst = dyn_cast (val)) { + // If not a load or a function argument, the only case for us to + // consider is that it's a bitcast. If so, recurse on what was casted. + if (inst->getOpcode () == LLVMBitCast) { + val = inst->getOperand (0); + continue; + } + + return FALSE; + } else { + return FALSE; + } + } + +#endif + return FALSE; +} + +GSList * +mono_llvm_calls_using (LLVMValueRef wrapped_local) +{ + GSList *usages = NULL; + Value *local = unwrap (wrapped_local); + + for (User *user : local->users ()) { + if (isa (user) || isa (user)) { + usages = g_slist_prepend (usages, wrap (user)); + } + } + + return usages; +} + +LLVMValueRef * +mono_llvm_call_args (LLVMValueRef wrapped_calli) +{ + Value *calli = unwrap(wrapped_calli); + CallInst *call = dyn_cast (calli); + InvokeInst *invoke = dyn_cast (calli); + g_assert (call || invoke); + + unsigned int numOperands = 0; + + if (call) + numOperands = call->getNumArgOperands (); + else + numOperands = invoke->getNumArgOperands (); + + LLVMValueRef *ret = g_malloc (sizeof (LLVMValueRef) * numOperands); + + for (int i=0; i < numOperands; i++) { + if (call) + ret [i] = wrap (call->getArgOperand (i)); + else + ret [i] = wrap (invoke->getArgOperand (i)); + } + + return ret; } void diff --git a/src/mono/mono/mini/mini-llvm-cpp.h b/src/mono/mono/mini/mini-llvm-cpp.h index 46a123f..60607f3 100644 --- a/src/mono/mono/mini/mini-llvm-cpp.h +++ b/src/mono/mono/mini/mini-llvm-cpp.h @@ -104,6 +104,24 @@ void mono_llvm_set_call_preserveall_cc (LLVMValueRef call); void +mono_llvm_set_call_nonnull_arg (LLVMValueRef calli, int argNo); + +void +mono_llvm_set_call_nonnull_ret (LLVMValueRef calli); + +void +mono_llvm_set_func_nonnull_arg (LLVMValueRef func, int argNo); + +GSList * +mono_llvm_calls_using (LLVMValueRef wrapped_local); + +LLVMValueRef * +mono_llvm_call_args (LLVMValueRef calli); + +gboolean +mono_llvm_is_nonnull (LLVMValueRef val); + +void mono_llvm_set_call_notailcall (LLVMValueRef call); void diff --git a/src/mono/mono/mini/mini-llvm.c b/src/mono/mono/mini/mini-llvm.c index da2ffcb..8152f49 100644 --- a/src/mono/mono/mini/mini-llvm.c +++ b/src/mono/mono/mini/mini-llvm.c @@ -65,6 +65,8 @@ typedef struct { GHashTable *plt_entries; GHashTable *plt_entries_ji; GHashTable *method_to_lmethod; + GHashTable *method_to_call_info; + GHashTable *lvalue_to_lcalls; GHashTable *direct_callables; char **bb_names; int bb_names_len; @@ -371,6 +373,8 @@ static void emit_cond_system_exception (EmitContext *ctx, MonoBasicBlock *bb, co static LLVMValueRef get_intrins_by_name (EmitContext *ctx, const char *name); static LLVMValueRef get_intrins (EmitContext *ctx, int id); static void llvm_jit_finalize_method (EmitContext *ctx); +static void mono_llvm_nonnull_state_update (EmitContext *ctx, LLVMValueRef lcall, MonoMethod *call_method, LLVMValueRef *args, int num_params); +static void mono_llvm_propagate_nonnull_final (GHashTable *all_specializable, MonoLLVMModule *module); static inline void set_failure (EmitContext *ctx, const char *message) @@ -2055,6 +2059,19 @@ set_metadata_flag (LLVMValueRef v, const char *flag_name) } static void +set_nonnull_load_flag (LLVMValueRef v) +{ + LLVMValueRef md_arg; + int md_kind; + const char *flag_name; + + flag_name = "nonnull"; + md_kind = LLVMGetMDKindID (flag_name, strlen (flag_name)); + md_arg = LLVMMDString ("", strlen ("")); + LLVMSetMetadata (v, md_kind, LLVMMDNode (&md_arg, 1)); +} + +static void set_invariant_load_flag (LLVMValueRef v) { LLVMValueRef md_arg; @@ -4069,6 +4086,13 @@ process_call (EmitContext *ctx, MonoBasicBlock *bb, LLVMBuilderRef *builder_ref, */ lcall = emit_call (ctx, bb, &builder, callee, args, LLVMCountParamTypes (llvm_sig)); + mono_llvm_nonnull_state_update (ctx, lcall, call->method, args, LLVMCountParamTypes (llvm_sig)); + + // If we just allocated an object, it's not null. + if (call->method && call->method->wrapper_type == MONO_WRAPPER_ALLOC) { + mono_llvm_set_call_nonnull_ret (lcall); + } + if (ins->opcode != OP_TAILCALL && ins->opcode != OP_TAILCALL_MEMBASE && LLVMGetInstructionOpcode (lcall) == LLVMCall) mono_llvm_set_call_notailcall (lcall); @@ -5042,10 +5066,21 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb) } else if (ins->opcode == OP_RCOMPARE) { cmp = LLVMBuildFCmp (builder, fpcond_to_llvm_cond [rel], convert (ctx, lhs, LLVMFloatType ()), convert (ctx, rhs, LLVMFloatType ()), ""); } else if (ins->opcode == OP_COMPARE_IMM) { - if (LLVMGetTypeKind (LLVMTypeOf (lhs)) == LLVMPointerTypeKind && ins->inst_imm == 0) - cmp = LLVMBuildICmp (builder, cond_to_llvm_cond [rel], lhs, LLVMConstNull (LLVMTypeOf (lhs)), ""); - else - cmp = LLVMBuildICmp (builder, cond_to_llvm_cond [rel], convert (ctx, lhs, IntPtrType ()), LLVMConstInt (IntPtrType (), ins->inst_imm, FALSE), ""); + LLVMIntPredicate llvm_pred = cond_to_llvm_cond [rel]; + if (LLVMGetTypeKind (LLVMTypeOf (lhs)) == LLVMPointerTypeKind && ins->inst_imm == 0) { + // We are emitting a NULL check for a pointer + gboolean nonnull = mono_llvm_is_nonnull (lhs); + + if (nonnull && llvm_pred == LLVMIntEQ) + cmp = LLVMConstInt (LLVMInt1Type (), FALSE, FALSE); + else if (nonnull && llvm_pred == LLVMIntNE) + cmp = LLVMConstInt (LLVMInt1Type (), TRUE, FALSE); + else + cmp = LLVMBuildICmp (builder, llvm_pred, lhs, LLVMConstNull (LLVMTypeOf (lhs)), ""); + + } else { + cmp = LLVMBuildICmp (builder, llvm_pred, convert (ctx, lhs, IntPtrType ()), LLVMConstInt (IntPtrType (), ins->inst_imm, FALSE), ""); + } } else if (ins->opcode == OP_LCOMPARE_IMM) { cmp = LLVMBuildICmp (builder, cond_to_llvm_cond [rel], lhs, rhs, ""); } @@ -5876,8 +5911,13 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb) values [ins->dreg] = LLVMBuildLoad (builder, got_entry_addr, name); g_free (name); /* Can't use this in llvmonly mode since the got slots are initialized by the methods themselves */ - if (!cfg->llvm_only || mono_aot_is_shared_got_offset (got_offset)) + if (!cfg->llvm_only || mono_aot_is_shared_got_offset (got_offset)) { + /* Can't use this in llvmonly mode since the got slots are initialized by the methods themselves */ set_invariant_load_flag (values [ins->dreg]); + } + + if (ji->type == MONO_PATCH_INFO_LDSTR) + set_nonnull_load_flag (values [ins->dreg]); break; } case OP_NOT_REACHED: @@ -8121,6 +8161,7 @@ after_codegen: if (ctx->module->method_to_lmethod) g_hash_table_insert (ctx->module->method_to_lmethod, cfg->method, ctx->lmethod); + if (ctx->module->idx_to_lmethod) g_hash_table_insert (ctx->module->idx_to_lmethod, GINT_TO_POINTER (cfg->method_index), ctx->lmethod); @@ -8971,8 +9012,9 @@ mono_llvm_create_aot_module (MonoAssembly *assembly, const char *global_prefix, module->plt_entries = g_hash_table_new (g_str_hash, g_str_equal); module->plt_entries_ji = g_hash_table_new (NULL, NULL); module->direct_callables = g_hash_table_new (g_str_hash, g_str_equal); - module->method_to_lmethod = g_hash_table_new (NULL, NULL); module->idx_to_lmethod = g_hash_table_new (NULL, NULL); + module->method_to_lmethod = g_hash_table_new (NULL, NULL); + module->method_to_call_info = g_hash_table_new (NULL, NULL); module->idx_to_unbox_tramp = g_hash_table_new (NULL, NULL); module->callsite_list = g_ptr_array_new (); } @@ -8993,6 +9035,8 @@ mono_llvm_fixup_aot_module (void) * with llvm. */ + GHashTable *specializable = g_hash_table_new (NULL, NULL); + GHashTable *patches_to_null = g_hash_table_new (mono_patch_info_hash, mono_patch_info_equal); for (int sindex = 0; sindex < module->callsite_list->len; ++sindex) { CallSite *site = (CallSite*)g_ptr_array_index (module->callsite_list, sindex); @@ -9005,6 +9049,10 @@ mono_llvm_fixup_aot_module (void) if (lmethod && !(method->iflags & METHOD_IMPL_ATTRIBUTE_SYNCHRONIZED)) { mono_llvm_replace_uses_of (placeholder, lmethod); + + if (mono_aot_can_specialize (method)) + g_hash_table_insert (specializable, lmethod, method); + g_hash_table_insert (patches_to_null, site->ji, site->ji); } else { int got_offset = compute_aot_got_offset (module, site->ji, site->type); @@ -9025,6 +9073,9 @@ mono_llvm_fixup_aot_module (void) g_free (site); } + mono_llvm_propagate_nonnull_final (specializable, module); + g_hash_table_destroy (specializable); + for (int i = 0; i < module->cfgs->len; ++i) { /* * Nullify the patches pointing to direct calls. This is needed to @@ -9336,6 +9387,167 @@ emit_aot_file_info (MonoLLVMModule *module) } } +typedef struct { + LLVMValueRef lmethod; + int argument; +} NonnullPropWorkItem; + +static void +mono_llvm_nonnull_state_update (EmitContext *ctx, LLVMValueRef lcall, MonoMethod *call_method, LLVMValueRef *args, int num_params) +{ +#if LLVM_API_VERSION > 100 + if (!ctx->module->llvm_disable_self_init && mono_aot_can_specialize (call_method)) { + int num_passed = LLVMGetNumArgOperands (lcall); + g_assert (num_params <= num_passed); + + g_assert (ctx->module->method_to_call_info); + GArray *call_site_union = (GArray *) g_hash_table_lookup (ctx->module->method_to_call_info, call_method); + + if (!call_site_union) { + call_site_union = g_array_sized_new (FALSE, TRUE, sizeof (gint32), num_params); + int zero = 0; + for (int i = 0; i < num_params; i++) + g_array_insert_val (call_site_union, i, zero); + } + + for (int i = 0; i < num_params; i++) { + if (mono_llvm_is_nonnull (args [i])) { + g_assert (i < LLVMGetNumArgOperands (lcall)); + mono_llvm_set_call_nonnull_arg (lcall, i); + } else { + gint32 *nullable_count = &g_array_index (call_site_union, gint32, i); + *nullable_count = *nullable_count + 1; + } + } + + g_hash_table_insert (ctx->module->method_to_call_info, call_method, call_site_union); + } +#endif +} + +static void +mono_llvm_propagate_nonnull_final (GHashTable *all_specializable, MonoLLVMModule *module) +{ +#if LLVM_API_VERSION > 100 + // When we first traverse the mini IL, we mark the things that are + // nonnull (the roots). Then, for all of the methods that can be specialized, we + // see if their call sites have nonnull attributes. + + // If so, we mark the function's param. This param has uses to propagate + // the attribute to. This propagation can trigger a need to mark more attributes + // non-null, and so on and so forth. + GSList *queue = NULL; + + GHashTableIter iter; + LLVMValueRef lmethod; + MonoMethod *method; + g_hash_table_iter_init (&iter, all_specializable); + while (g_hash_table_iter_next (&iter, (void**)&lmethod, (void**)&method)) { + GArray *call_site_union = (GArray *) g_hash_table_lookup (module->method_to_call_info, method); + + // Basic sanity checking + if (call_site_union) + g_assert (call_site_union->len == LLVMCountParams (lmethod)); + + // Add root to work queue + for (int i = 0; call_site_union && i < call_site_union->len; i++) { + if (g_array_index (call_site_union, gint32, i) == 0) { + NonnullPropWorkItem *item = g_malloc (sizeof (NonnullPropWorkItem)); + item->lmethod = lmethod; + item->argument = i; + queue = g_slist_prepend (queue, item); + } + } + } + + // This is essentially reference counting, and we are propagating + // the refcount decrement here. We have less work to do than we may otherwise + // because we are only working with a set of subgraphs of specializable functions. + // + // We rely on being able to see all of the references in the graph. + // This is ensured by the function mono_aot_can_specialize. Everything in + // all_specializable is a function that can be specialized, and is the resulting + // node in the graph after all of the subsitutions are done. + // + // Anything disrupting the direct calls made with self-init will break this optimization. + + while (queue) { + // Update the queue state. + // Our only other per-iteration responsibility is now to free current + NonnullPropWorkItem *current = (NonnullPropWorkItem *) queue->data; + queue = queue->next; + g_assert (current->argument < LLVMCountParams (current->lmethod)); + + // Does the actual leaf-node work here + // Mark the function argument as nonnull for LLVM + mono_llvm_set_func_nonnull_arg (current->lmethod, current->argument); + + // The rest of this is for propagating forward nullability changes + // to calls that use the argument that is now nullable. + + // Get the actual LLVM value of the argument, so we can see which call instructions + // used that argument + LLVMValueRef caller_argument = LLVMGetParam (current->lmethod, current->argument); + + // Iterate over the calls using the newly-non-nullable argument + GSList *calls = mono_llvm_calls_using (caller_argument); + for (GSList *cursor = calls; cursor != NULL; cursor = cursor->next) { + + LLVMValueRef lcall = (LLVMValueRef) cursor->data; + LLVMValueRef callee_lmethod = LLVMGetCalledValue (lcall); + + // If this wasn't a direct call for which mono_aot_can_specialize is true, + // this lookup won't find a MonoMethod. + MonoMethod *callee_method = (MonoMethod *) g_hash_table_lookup (all_specializable, callee_lmethod); + if (!callee_method) + continue; + + // Decrement number of nullable refs at that func's arg offset + GArray *call_site_union = (GArray *) g_hash_table_lookup (module->method_to_call_info, callee_method); + + // It has module-local callers and is specializable, should have seen this call site + // and inited this + g_assert (call_site_union); + + // The function *definition* parameter arity should always be consistent + int max_params = LLVMCountParams (callee_lmethod); + if (call_site_union->len != max_params) { + mono_llvm_dump_value (callee_lmethod); + g_assert_not_reached (); + } + + // Get the values that correspond to the parameters passed to the call + // that used our argument + LLVMValueRef *operands = mono_llvm_call_args (lcall); + for (int call_argument = 0; call_argument < max_params; call_argument++) { + // Every time we used the newly-non-nullable argument, decrement the nullable + // refcount for that function. + if (caller_argument == operands [call_argument]) { + gint32 *nullable_count = &g_array_index (call_site_union, gint32, call_argument); + g_assert (*nullable_count > 0); + *nullable_count = *nullable_count - 1; + + // If we caused that callee's parameter to become newly nullable, add to work queue + if (*nullable_count == 0) { + NonnullPropWorkItem *item = g_malloc (sizeof (NonnullPropWorkItem)); + item->lmethod = callee_lmethod; + item->argument = call_argument; + queue = g_slist_prepend (queue, item); + } + } + } + g_free (operands); + + // Update nullability refcount information for the callee now + g_hash_table_insert (module->method_to_call_info, callee_method, call_site_union); + } + g_slist_free (calls); + + g_free (current); + } +#endif +} + /* * Emit the aot module into the LLVM bitcode file FILENAME. */ @@ -9409,6 +9621,8 @@ mono_llvm_emit_aot_module (const char *filename, const char *cu_name) MonoJumpInfo *ji; LLVMValueRef callee; + GHashTable *specializable = g_hash_table_new (NULL, NULL); + g_hash_table_iter_init (&iter, module->plt_entries_ji); while (g_hash_table_iter_next (&iter, (void**)&ji, (void**)&callee)) { if (mono_aot_is_direct_callable (ji)) { @@ -9418,10 +9632,17 @@ mono_llvm_emit_aot_module (const char *filename, const char *cu_name) /* The types might not match because the caller might pass an rgctx */ if (lmethod && LLVMTypeOf (callee) == LLVMTypeOf (lmethod)) { mono_llvm_replace_uses_of (callee, lmethod); + + if (!module->llvm_disable_self_init && mono_aot_can_specialize (ji->data.method)) + g_hash_table_insert (specializable, lmethod, ji->data.method); mono_aot_mark_unused_llvm_plt_entry (ji); } } } + + mono_llvm_propagate_nonnull_final (specializable, module); + + g_hash_table_destroy (specializable); } /* Note: You can still dump an invalid bitcode file by running `llvm-dis` -- 2.7.4