[wasm] Improve jiterpreter jit calls (#81071)
author Katelyn Gadd <kg@luminance.org>
Wed, 25 Jan 2023 02:34:16 +0000 (18:34 -0800)
committer GitHub <noreply@github.com>
Wed, 25 Jan 2023 02:34:16 +0000 (18:34 -0800)
* Misc. cleanups and jiterpreter refactorings
* Optimize the ldloca + tnn.load pair in direct jit calls
* Fix argOffsets size calculation
* Turn stats back off

src/mono/mono/mini/interp/interp.c
src/mono/mono/mini/interp/jiterpreter.c
src/mono/mono/mini/interp/jiterpreter.h
src/mono/mono/utils/options-def.h
src/mono/wasm/runtime/cwraps.ts
src/mono/wasm/runtime/jiterpreter-jit-call.ts

index 61cb38404c849d14072f3a643a3517d37ffe4df9..e0da9d735c3cb4d5845bd092ce2c2ce798d86b09 100644 (file)
@@ -2659,8 +2659,11 @@ do_jit_call (ThreadContext *context, stackval *ret_sp, stackval *sp, InterpFrame
        cinfo = (JitCallInfo*)rmethod->jit_call_info;
 
 #if JITERPRETER_ENABLE_JIT_CALL_TRAMPOLINES
-       // FIXME: thread safety
+       // The jiterpreter will compile a unique thunk for each do_jit_call call site if it is hot
+       //  enough to justify it. At that point we can invoke the thunk to efficiently do most of
+       //  the work that would normally be done by do_jit_call
        if (mono_opt_jiterpreter_jit_call_enabled) {
+               // FIXME: Thread safety for the thunk pointer
                WasmJitCallThunk thunk = cinfo->jiterp_thunk;
                if (thunk) {
                        MonoFtnDesc ftndesc = {0};
@@ -2671,27 +2674,41 @@ do_jit_call (ThreadContext *context, stackval *ret_sp, stackval *sp, InterpFrame
                                mono_opt_jiterpreter_wasm_eh_enabled ||
                                (mono_aot_mode != MONO_AOT_MODE_LLVMONLY_INTERP)
                        ) {
+                               // WASM EH is available or we are otherwise in a situation where we know
+                               //  that the jiterpreter thunk was compiled with exception handling built-in
+                               //  so we can just invoke it directly and errors will be handled
                                thunk (ret_sp, sp, &ftndesc, &thrown);
                        } else {
+                               // Call a special JS function that will invoke the compiled jiterpreter thunk
+                               //  and trap errors for us to set the thrown flag
                                mono_interp_invoke_wasm_jit_call_trampoline (
                                        thunk, ret_sp, sp, &ftndesc, &thrown
                                );
                        }
                        interp_pop_lmf (&ext);
+
+                       // We reuse do_jit_call's epilogue to do things like propagate thrown exceptions
+                       //  and sign-extend return values instead of inlining that logic into every thunk
                        goto epilogue;
                } else {
+                       // FIXME: thread safety for the hit count
                        int count = cinfo->hit_count;
+                       // If our hit count just reached the threshold, we request that a thunk be jitted
+                       //  for this specific call site. It will go into a queue and wait until there
+                       //  are enough jit calls waiting to be compiled into one WASM module
                        if (count == mono_opt_jiterpreter_jit_call_trampoline_hit_count) {
-                               void *fn = cinfo->no_wrapper ? cinfo->addr : cinfo->wrapper;
                                mono_interp_jit_wasm_jit_call_trampoline (
-                                       rmethod->method, rmethod, cinfo, fn,
-                                       rmethod->hasthis, rmethod->param_count,
+                                       rmethod->method, rmethod, cinfo,
                                        rmethod->arg_offsets, mono_aot_mode == MONO_AOT_MODE_LLVMONLY_INTERP
                                );
                        } else {
                                int excess = count - mono_opt_jiterpreter_jit_call_queue_flush_threshold;
                                if (excess <= 0)
                                        cinfo->hit_count++;
+                               // If our hit count just reached the flush threshold, that means that we
+                               //  previously requested compilation for this call site and it didn't
+                               //  happen yet. We will request a flush of the entire queue this one
+                               //  time which will probably result in it being compiled
                                if (excess == 0)
                                        mono_interp_flush_jitcall_queue ();
                        }
index e338dd5fae9c7c9dd462ff7820e6362569372475..50d4745dcf71aca904bf3109ab1f0ec56bf13490 100644 (file)
@@ -956,6 +956,12 @@ mono_jiterp_get_hashcode (MonoObject ** ppObj)
        return mono_object_hash_internal (obj);
 }
 
+EMSCRIPTEN_KEEPALIVE int
+mono_jiterp_get_signature_has_this (MonoMethodSignature *sig)
+{
+       return sig->hasthis;
+}
+
 EMSCRIPTEN_KEEPALIVE MonoType *
 mono_jiterp_get_signature_return_type (MonoMethodSignature *sig)
 {
index 5961651d0908bfdc99d7a618e84bf4a0a69f1a24..81504080693dfcf6214fc51b67d04e6a8af4be38 100644 (file)
@@ -63,8 +63,7 @@ mono_interp_tier_prepare_jiterpreter (
 //  or JitCallInfo
 extern void
 mono_interp_jit_wasm_jit_call_trampoline (
-       MonoMethod *method, void *rmethod, void *cinfo, void *func,
-       gboolean has_this, int param_count,
+       MonoMethod *method, void *rmethod, void *cinfo,
        guint32 *arg_offsets, gboolean catch_exceptions
 );
 
index 265504c2a183f3a42990969e9959a61b5b914943..6e130d744a13f21929b099ee3c31d3ba5469f1e9 100644 (file)
@@ -95,7 +95,7 @@ DEFINE_BOOL(jiterpreter_call_resume_enabled, "jiterpreter-call-resume-enabled",
 //  stats for options like estimateHeat, but raises overhead.
 DEFINE_BOOL(jiterpreter_disable_heuristic, "jiterpreter-disable-heuristic", FALSE, "Always insert trace entry points for more accurate statistics")
 // Automatically prints stats at app exit or when jiterpreter_dump_stats is called
-DEFINE_BOOL(jiterpreter_stats_enabled, "jiterpreter-stats-enabled", TRUE, "Automatically print jiterpreter statistics")
+DEFINE_BOOL(jiterpreter_stats_enabled, "jiterpreter-stats-enabled", FALSE, "Automatically print jiterpreter statistics")
 // Continue counting hits for traces that fail to compile and use it to estimate
 //  the relative importance of the opcode that caused them to abort
 DEFINE_BOOL(jiterpreter_estimate_heat, "jiterpreter-estimate-heat", FALSE, "Maintain accurate hit count for all trace entry points")
index 86c2c8c291a5b53e16cd07f495a7e1ea524f9b5a..b5b25a5617633dfaf8fd674e253854236df6c7c9 100644 (file)
@@ -115,6 +115,7 @@ const fn_signatures: SigLine[] = [
     [true, "mono_jiterp_register_jit_call_thunk", "void", ["number", "number"]],
     [true, "mono_jiterp_type_get_raw_value_size", "number", ["number"]],
     [true, "mono_jiterp_update_jit_call_dispatcher", "void", ["number"]],
+    [true, "mono_jiterp_get_signature_has_this", "number", ["number"]],
     [true, "mono_jiterp_get_signature_return_type", "number", ["number"]],
     [true, "mono_jiterp_get_signature_param_count", "number", ["number"]],
     [true, "mono_jiterp_get_signature_params", "number", ["number"]],
@@ -253,6 +254,7 @@ export interface t_Cwraps {
     mono_jiterp_adjust_abort_count(opcode: number, delta: number): number;
     mono_jiterp_register_jit_call_thunk(cinfo: number, func: number): void;
     mono_jiterp_update_jit_call_dispatcher(fn: number): void;
+    mono_jiterp_get_signature_has_this(sig: VoidPtr): number;
     mono_jiterp_get_signature_return_type(sig: VoidPtr): MonoType;
     mono_jiterp_get_signature_param_count(sig: VoidPtr): number;
     mono_jiterp_get_signature_params(sig: VoidPtr): VoidPtr;
index b17426121a9a3632dbcb5668720b8087bed1bdb1..ad70a480055159f993e5ebe9307b309e4b846adb 100644 (file)
@@ -43,7 +43,7 @@ struct _JitCallInfo {
 
 const offsetOfAddr = 0,
     // offsetOfExtraArg = 4,
-    // offsetOfWrapper = 8,
+    offsetOfWrapper = 8,
     offsetOfSig = 12,
     offsetOfArgInfo = 16,
     offsetOfRetMt = 24,
@@ -51,8 +51,7 @@ const offsetOfAddr = 0,
     JIT_ARG_BYVAL = 0;
 
 const maxJitQueueLength = 6,
-    maxSharedQueueLength = 12,
-    flushParamThreshold = 7;
+    maxSharedQueueLength = 12;
     // sizeOfStackval = 8;
 
 let trampBuilder : WasmBuilder;
@@ -70,53 +69,63 @@ class TrampolineInfo {
     hasThisReference: boolean;
     hasReturnValue: boolean;
     noWrapper: boolean;
+    // The number of managed arguments (not including the this-reference or return val address)
     paramCount: number;
+    // The managed type of each argument, not including the this-reference
+    paramTypes: MonoType[];
+    // The interpreter stack offset of each argument, in bytes. Indexes are one-based if
+    //  the method has a this-reference (thisp is arg 0) and zero-based for static methods.
+    // The return value address is not in here either because it's always at a fixed location.
     argOffsets: number[];
     catchExceptions: boolean;
     target: number; // either cinfo->wrapper or cinfo->addr, depending
     addr: number; // always cinfo->addr
+    wrapper: number; // always cinfo->wrapper
     name: string;
     result: number;
     queue: NativePointer[] = [];
     signature: VoidPtr;
-    signatureParamCount: number;
-    signatureParamTypes: MonoType[];
-    signatureReturnType: MonoType;
+    returnType: MonoType;
     wasmNativeReturnType: WasmValtype;
     wasmNativeSignature: WasmValtype[];
     enableDirect: boolean;
 
     constructor (
         method: MonoMethod, rmethod: VoidPtr, cinfo: VoidPtr,
-        has_this: boolean, param_count: number,
-        arg_offsets: VoidPtr, catch_exceptions: boolean, func: number
+        arg_offsets: VoidPtr, catch_exceptions: boolean
     ) {
         this.method = method;
         this.rmethod = rmethod;
+        this.catchExceptions = catch_exceptions;
         this.cinfo = cinfo;
         this.addr = getU32(<any>cinfo + offsetOfAddr);
-        this.hasThisReference = has_this;
-        this.paramCount = param_count;
-        this.catchExceptions = catch_exceptions;
-        this.argOffsets = new Array(param_count);
+        this.wrapper = getU32(<any>cinfo + offsetOfWrapper);
         this.signature = <any>getU32(<any>cinfo + offsetOfSig);
-        this.signatureReturnType = cwraps.mono_jiterp_get_signature_return_type(this.signature);
-        this.signatureParamCount = cwraps.mono_jiterp_get_signature_param_count(this.signature);
         this.noWrapper = getU8(<any>cinfo + offsetOfNoWrapper) !== 0;
-        const ptr = cwraps.mono_jiterp_get_signature_params(this.signature);
-        this.signatureParamTypes = new Array(this.signatureParamCount);
-        for (let i = 0; i < this.signatureParamCount; i++)
-            this.signatureParamTypes[i] = <any>getU32(<any>ptr + (i * 4));
         this.hasReturnValue = getI32(<any>cinfo + offsetOfRetMt) !== -1;
-        for (let i = 0, c = param_count + (has_this ? 1 : 0); i < c; i++)
+
+        this.returnType = cwraps.mono_jiterp_get_signature_return_type(this.signature);
+        this.paramCount = cwraps.mono_jiterp_get_signature_param_count(this.signature);
+        this.hasThisReference = cwraps.mono_jiterp_get_signature_has_this(this.signature) !== 0;
+
+        const ptr = cwraps.mono_jiterp_get_signature_params(this.signature);
+        this.paramTypes = new Array(this.paramCount);
+        for (let i = 0; i < this.paramCount; i++)
+            this.paramTypes[i] = <any>getU32(<any>ptr + (i * 4));
+
+        // See initialize_arg_offsets for where this array is built
+        const argOffsetCount = this.paramCount + (this.hasThisReference ? 1 : 0);
+        this.argOffsets = new Array(this.paramCount);
+        for (let i = 0; i < argOffsetCount; i++)
             this.argOffsets[i] = <any>getU32(<any>arg_offsets + (i * 4));
-        this.target = func;
+
+        this.target = this.noWrapper ? this.addr : this.wrapper;
         this.result = 0;
 
-        this.wasmNativeReturnType = this.signatureReturnType && this.hasReturnValue
-            ? (wasmTypeFromCilOpcode as any)[cwraps.mono_jiterp_type_to_stind(this.signatureReturnType)]
+        this.wasmNativeReturnType = this.returnType && this.hasReturnValue
+            ? (wasmTypeFromCilOpcode as any)[cwraps.mono_jiterp_type_to_stind(this.returnType)]
             : WasmValtype.void;
-        this.wasmNativeSignature = this.signatureParamTypes.map(
+        this.wasmNativeSignature = this.paramTypes.map(
             monoType => (wasmTypeFromCilOpcode as any)[cwraps.mono_jiterp_type_to_ldind(monoType)]
         );
         this.enableDirect = getOptions().directJitCalls &&
@@ -174,8 +183,7 @@ export function mono_interp_invoke_wasm_jit_call_trampoline (
 }
 
 export function mono_interp_jit_wasm_jit_call_trampoline (
-    method: MonoMethod, rmethod: VoidPtr, cinfo: VoidPtr, func: number,
-    has_this: number, param_count: number,
+    method: MonoMethod, rmethod: VoidPtr, cinfo: VoidPtr,
     arg_offsets: VoidPtr, catch_exceptions: number
 ) : void {
     // multiple cinfos can share the same target function, so for that scenario we want to
@@ -202,8 +210,8 @@ export function mono_interp_jit_wasm_jit_call_trampoline (
     }
 
     const info = new TrampolineInfo(
-        method, rmethod, cinfo, has_this !== 0, param_count,
-        arg_offsets, catch_exceptions !== 0, func
+        method, rmethod, cinfo,
+        arg_offsets, catch_exceptions !== 0
     );
     targetCache[cacheKey] = info;
     jitQueue.push(info);
@@ -211,9 +219,7 @@ export function mono_interp_jit_wasm_jit_call_trampoline (
     // we don't want the queue to get too long, both because jitting too many trampolines
     //  at once can hit the 4kb limit and because it makes it more likely that we will
     //  fail to jit them early enough
-    // HACK: we also want to flush the queue when we get a function with many parameters,
-    //  since it's going to generate a lot more code and push us closer to 4kb
-    if ((info.paramCount >= flushParamThreshold) || (jitQueue.length >= maxJitQueueLength))
+    if (jitQueue.length >= maxJitQueueLength)
         mono_interp_flush_jitcall_queue();
 }
 
@@ -343,10 +349,6 @@ export function mono_interp_flush_jitcall_queue () : void {
     let rejected = true, threw = false;
 
     const trampImports : Array<[string, string, Function | number]> = [
-        ["stackSave", "stackSave", Module.stackSave],
-        ["stackAlloc", "stackAlloc", Module.stackAlloc],
-        ["stackRestore", "stackRestore", Module.stackRestore],
-        ["trace_entry", "trace_entry", mono_jiterp_trace_wrapper_entry],
     ];
 
     try {
@@ -366,32 +368,10 @@ export function mono_interp_flush_jitcall_queue () : void {
                 "thrown": WasmValtype.i32,
             }, WasmValtype.void
         );
-        builder.defineType(
-            "stackAlloc", {
-                "bytes": WasmValtype.i32,
-            }, WasmValtype.i32
-        );
-        builder.defineType(
-            "stackSave", {
-            }, WasmValtype.i32
-        );
-        builder.defineType(
-            "stackRestore", {
-                "sp": WasmValtype.i32,
-            }, WasmValtype.void
-        );
-        builder.defineType(
-            "trace_entry", {
-                "nameIndex": WasmValtype.i32,
-                "expected": WasmValtype.i32,
-                "actual": WasmValtype.i32,
-            }, WasmValtype.void
-        );
 
         for (let i = 0; i < jitQueue.length; i++) {
             const info = jitQueue[i];
 
-            const actualParamCount = (info.hasThisReference ? 1 : 0) + (info.hasReturnValue ? 1 : 0) + info.paramCount;
             const sig : any = {};
 
             if (info.enableDirect) {
@@ -403,8 +383,12 @@ export function mono_interp_flush_jitcall_queue () : void {
 
                 sig["rgctx"] = WasmValtype.i32;
             } else {
+                const actualParamCount = (info.hasThisReference ? 1 : 0) +
+                    (info.hasReturnValue ? 1 : 0) + info.paramCount;
+
                 for (let j = 0; j < actualParamCount; j++)
                     sig[`arg${j}`] = WasmValtype.i32;
+
                 sig["ftndesc"] = WasmValtype.i32;
             }
 
@@ -534,6 +518,9 @@ export function mono_interp_flush_jitcall_queue () : void {
         // FIXME
         if (threw || (!rejected && ((trace >= 2) || dumpWrappers))) {
             console.log(`// MONO_WASM: ${jitQueue.length} jit call wrappers generated, blob follows //`);
+            for (let i = 0; i < jitQueue.length; i++)
+                console.log(`// #${i} === ${jitQueue[i].name} hasThis=${jitQueue[i].hasThisReference} hasRet=${jitQueue[i].hasReturnValue} wasmArgTypes=${jitQueue[i].wasmNativeSignature}`);
+
             let s = "", j = 0;
             try {
                 if (builder.inSection)
@@ -689,7 +676,10 @@ function generate_wasm_body (
             mono_mb_emit_ldarg (mb, 0);
     */
     if (info.hasThisReference) {
-        append_ldloc(builder, 0, WasmOpcode.i32_load);
+        // The this-reference is always the first argument
+        // Note that currently info.argOffsets[0] will always be 0, but it's best to
+        //  read it from the array in case this behavior changes later.
+        append_ldloc(builder, info.argOffsets[0], WasmOpcode.i32_load);
         stack_index++;
     }
 
@@ -707,17 +697,11 @@ function generate_wasm_body (
             // pass the first four bytes of the stackval data union,
             //  which is 'p' where pointers live
             append_ldloc(builder, svalOffset, WasmOpcode.i32_load);
-        } else {
-            // pass the address of the stackval data union
-            append_ldloca(builder, svalOffset);
-        }
-
-        if (info.enableDirect) {
+        } else if (info.enableDirect) {
             // The wrapper call convention is byref for all args. Now we convert it to the native calling convention
-            const loadCilOp = cwraps.mono_jiterp_type_to_ldind(info.signatureParamTypes[i]);
-            mono_assert(loadCilOp, () => `No load opcode for ${info.signatureParamTypes[i]}`);
+            const loadCilOp = cwraps.mono_jiterp_type_to_ldind(info.paramTypes[i]);
+            mono_assert(loadCilOp, () => `No load opcode for ${info.paramTypes[i]}`);
 
-            // We already performed a ldarg up above, so now we have the address that would've been passed to the wrapper
             /*
                 if (m_type_is_byref (sig->params [i])) {
                     mono_mb_emit_ldarg (mb, args_start + i);
@@ -732,19 +716,21 @@ function generate_wasm_body (
             */
 
             if (loadCilOp === CilOpcodes.DUMMY_BYREF) {
-                // Nothing to do
+                // pass the address of the stackval data union
+                append_ldloca(builder, svalOffset);
             } else {
                 const loadWasmOp = (wasmOpcodeFromCilOpcode as any)[loadCilOp];
                 if (!loadWasmOp) {
-                    console.error(`No wasm load op for arg #${i} type ${info.signatureParamTypes[i]} cil opcode ${loadCilOp}`);
+                    console.error(`No wasm load op for arg #${i} type ${info.paramTypes[i]} cil opcode ${loadCilOp}`);
                     return false;
                 }
 
                 // FIXME: LDOBJ is not implemented
-                // TODO: Optimize ldloca->this into a single load-with-offset
-                builder.appendU8(loadWasmOp);
-                builder.appendMemarg(0, 0);
+                append_ldloc(builder, svalOffset, loadWasmOp);
             }
+        } else {
+            // pass the address of the stackval data union
+            append_ldloca(builder, svalOffset);
         }
     }
 
@@ -756,6 +742,10 @@ function generate_wasm_body (
     mono_mb_emit_byte (mb, CEE_LDIND_I);
     */
 
+    // We have to pass the ftndesc through from do_jit_call because the target function needs
+    //  a rgctx value, which is not constant for a given wrapper if the target function is shared
+    //  for multiple InterpMethods. We pass ftndesc instead of rgctx so that we can pass the
+    //  address to gsharedvt wrappers without having to do our own stackAlloc
     builder.local("ftndesc");
     if (info.enableDirect || info.noWrapper) {
         // Native calling convention wants an rgctx, not a ftndesc. The rgctx
@@ -790,10 +780,10 @@ function generate_wasm_body (
 
     // The stack should now contain [ret_sp, retval], so write retval through the return address
     if (info.hasReturnValue && info.enableDirect) {
-        const storeCilOp = cwraps.mono_jiterp_type_to_stind(info.signatureReturnType);
+        const storeCilOp = cwraps.mono_jiterp_type_to_stind(info.returnType);
         const storeWasmOp = (wasmOpcodeFromCilOpcode as any)[storeCilOp];
         if (!storeWasmOp) {
-            console.error(`No wasm store op for return type ${info.signatureReturnType} cil opcode ${storeCilOp}`);
+            console.error(`No wasm store op for return type ${info.returnType} cil opcode ${storeCilOp}`);
             return false;
         }