[wasm] More accurate jiterpreter cfg size estimation; generate smaller dispatch table...
authorKatelyn Gadd <kg@luminance.org>
Thu, 23 Mar 2023 03:52:33 +0000 (20:52 -0700)
committerGitHub <noreply@github.com>
Thu, 23 Mar 2023 03:52:33 +0000 (20:52 -0700)
* More accurate cfg size estimation
* Generate smaller dispatch tables for traces with backward branches
* Make sure we never actually can dispatch to the unreachable entries in the back branch table
* If we somehow generate a module bigger than 4KB, don't try to compile it. Just log a warning
* Better cfg logging for failed branches
* Add a separate runtime option that controls whether trace monitoring will print to the log

src/mono/mono/mini/interp/jiterpreter.c
src/mono/mono/utils/options-def.h
src/mono/wasm/runtime/jiterpreter-support.ts
src/mono/wasm/runtime/jiterpreter-trace-generator.ts
src/mono/wasm/runtime/jiterpreter.ts

index 31ba66e..c6374da 100644 (file)
@@ -1342,7 +1342,6 @@ mono_jiterp_write_number_unaligned (void *dest, double value, int mode) {
 }
 
 #define TRACE_PENALTY_LIMIT 200
-#define TRACE_MONITORING_DETAILED FALSE
 
 ptrdiff_t
 mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
@@ -1377,7 +1376,8 @@ mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
                int penalty = MIN ((int)((1.0f - scaled) * TRACE_PENALTY_LIMIT), TRACE_PENALTY_LIMIT);
                info->penalty_total += penalty;
 
-               // g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty);
+               if (mono_opt_jiterpreter_trace_monitoring_log > 2)
+                       g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty);
        }
 
        gint64 hit_count = info->hit_count++ - mono_opt_jiterpreter_minimum_trace_hit_count;
@@ -1394,11 +1394,11 @@ mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
                        *(volatile JiterpreterThunk*)(ip + 1) = thunk;
                        mono_memory_barrier ();
                        *mutable_ip = MINT_TIER_ENTER_JITERPRETER;
-                       if (mono_opt_jiterpreter_stats_enabled && TRACE_MONITORING_DETAILED)
+                       if (mono_opt_jiterpreter_trace_monitoring_log > 1)
                                g_print ("trace #%d @%d '%s' accepted; average_penalty %f <= %f\n", index, ip, frame->imethod->method->name, average_penalty, threshold);
                } else {
                        traces_rejected++;
-                       if (mono_opt_jiterpreter_stats_enabled) {
+                       if (mono_opt_jiterpreter_trace_monitoring_log > 0) {
                                char * full_name = mono_method_get_full_name (frame->imethod->method);
                                g_print ("trace #%d @%d '%s' rejected; average_penalty %f > %f\n", index, ip, full_name, average_penalty, threshold);
                                g_free (full_name);
index 9637e05..b9454c0 100644 (file)
@@ -127,6 +127,8 @@ DEFINE_INT(jiterpreter_trace_monitoring_short_distance, "jiterpreter-trace-monit
 DEFINE_INT(jiterpreter_trace_monitoring_long_distance, "jiterpreter-trace-monitoring-long-distance", 10, "Traces that exit after processing this many opcodes have no exit penalty")
 // the average penalty value for a trace is compared against this threshold / 100 to decide whether to discard it
 DEFINE_INT(jiterpreter_trace_monitoring_max_average_penalty, "jiterpreter-trace-monitoring-max-average-penalty", 75, "If the average penalty value for a trace is above this value it will be rejected")
+// 0 = no monitoring, 1 = log when rejecting a trace, 2 = log when accepting or rejecting a trace, 3 = log every recorded bailout
+DEFINE_INT(jiterpreter_trace_monitoring_log, "jiterpreter-trace-monitoring-log", 0, "Logging detail level for trace monitoring")
 // After a do_jit_call call site is hit this many times, we will queue it to be jitted
 DEFINE_INT(jiterpreter_jit_call_trampoline_hit_count, "jiterpreter-jit-call-hit-count", 1000, "Queue specialized do_jit_call trampoline for JIT after this many hits")
 // After a do_jit_call call site is hit this many times without being jitted, we will flush the JIT queue
index d9ee083..1486804 100644 (file)
@@ -1007,13 +1007,13 @@ class Cfg {
     entryBlob!: CfgBlob;
     blockStack: Array<MintOpcodePtr> = [];
     dispatchTable = new Map<MintOpcodePtr, number>();
-    trace = false;
+    trace = 0;
 
     constructor (builder: WasmBuilder) {
         this.builder = builder;
     }
 
-    initialize (startOfBody: MintOpcodePtr, backBranchTargets: Uint16Array | null, trace: boolean) {
+    initialize (startOfBody: MintOpcodePtr, backBranchTargets: Uint16Array | null, trace: number) {
         this.segments.length = 0;
         this.blockStack.length = 0;
         this.startOfBody = startOfBody;
@@ -1034,9 +1034,11 @@ class Cfg {
         mono_assert(this.segments[0].type === "blob", "expected blob");
         this.entryBlob = <CfgBlob>this.segments[0];
         this.segments.length = 0;
-        this.overheadBytes += 9; // entry eip init + block + optional loop
-        if (this.backBranchTargets)
-            this.overheadBytes += 24; // some extra padding for the dispatch br_table
+        this.overheadBytes += 9; // entry disp init + block + optional loop
+        if (this.backBranchTargets) {
+            this.overheadBytes += 20; // some extra padding for the dispatch br_table
+            this.overheadBytes += this.backBranchTargets.length; // one byte for each target in the table
+        }
     }
 
     appendBlob () {
@@ -1051,6 +1053,8 @@ class Cfg {
         });
         this.lastSegmentStartIp = this.ip;
         this.lastSegmentEnd = this.builder.current.size;
+        // each segment generates a block
+        this.overheadBytes += 2;
     }
 
     startBranchBlock (ip: MintOpcodePtr, isBackBranchTarget: boolean) {
@@ -1060,9 +1064,7 @@ class Cfg {
             ip,
             isBackBranchTarget,
         });
-        this.overheadBytes += 3; // each branch block just costs us a block (2 bytes) and an end
-        if (this.backBranchTargets)
-            this.overheadBytes += 3; // size of the br_table entry for this branch target
+        this.overheadBytes += 1; // each branch block just costs us an end
     }
 
     branch (target: MintOpcodePtr, isBackward: boolean, isConditional: boolean) {
@@ -1074,9 +1076,17 @@ class Cfg {
             isBackward,
             isConditional,
         });
-        this.overheadBytes += 3; // forward branches are a constant br + depth (optimally 2 bytes)
-        if (isBackward)
-            this.overheadBytes += 4; // back branches are more complex
+        // some branches will generate bailouts instead so we allocate 4 bytes per branch
+        //  to try and balance this out and avoid underestimating too much
+        this.overheadBytes += 4; // forward branches are a constant br + depth (optimally 2 bytes)
+        if (isBackward) {
+            // get_local <cinfo>
+            // i32_const 1
+            // i32_store 0 0
+            // i32.const <n>
+            // set_local <disp>
+            this.overheadBytes += 11;
+        }
     }
 
     emitBlob (segment: CfgBlob, source: Uint8Array) {
@@ -1135,11 +1145,21 @@ class Cfg {
             // br_table <number of values starting from 0> <labels for values starting from 0> <default>
             // we have to assign disp==0 to fallthrough so that we start at the top of the fn body, then
             //  assign disp values starting from 1 to branch targets
-            this.builder.appendULeb(this.blockStack.length + 1);
+            // FIXME: Only include back branch targets that are *also* in the block stack. This is necessary
+            //  when starting a trace in the middle of a method to make the table smaller
+            this.builder.appendULeb(this.backBranchTargets.length + 1);
             this.builder.appendULeb(1); // br depth of 1 = skip the unreachable and fall through to the start
-            for (let i = 0; i < this.blockStack.length; i++) {
-                this.dispatchTable.set(this.blockStack[i], i + 1);
-                this.builder.appendULeb(i + 2); // add 2 to the depth because of the double block around it
+            for (let i = 0; i < this.backBranchTargets.length; i++) {
+                const offset = (this.backBranchTargets[i] * 2) + <any>this.startOfBody;
+                const breakDepth = this.blockStack.indexOf(offset);
+                if (breakDepth >= 0) {
+                    this.dispatchTable.set(offset, i + 1);
+                    this.builder.appendULeb(breakDepth + 2); // add 2 to the depth because of the double block around it
+                } else {
+                    // This means the back branch target is outside of the trace. It shouldn't be possible to reach this
+                    //  and we didn't add it to the dispatch table anyway
+                    this.builder.appendULeb(0);
+                }
             }
             this.builder.appendULeb(0); // for unrecognized value we br 0, which causes us to trap
             this.builder.endBlock();
@@ -1150,7 +1170,7 @@ class Cfg {
             this.blockStack.push(dispatchIp);
         }
 
-        if (this.trace)
+        if (this.trace > 1)
             console.log(`blockStack=${this.blockStack}`);
 
         for (let i = 0; i < this.segments.length; i++) {
@@ -1173,14 +1193,15 @@ class Cfg {
                 }
                 case "branch": {
                     const lookupTarget = segment.isBackward ? dispatchIp : segment.target;
-                    let indexInStack = this.blockStack.indexOf(lookupTarget);
+                    let indexInStack = this.blockStack.indexOf(lookupTarget),
+                        successfulBackBranch = false;
 
                     // Back branches will target the dispatcher loop so we need to update the dispatch index
                     //  which will be used by the loop dispatch br_table to jump to the correct location
-                    if (segment.isBackward && (indexInStack >= 0)) {
+                    if (segment.isBackward) {
                         if (this.dispatchTable.has(segment.target)) {
                             const disp = this.dispatchTable.get(segment.target)!;
-                            if (this.trace)
+                            if (this.trace > 1)
                                 console.log(`backward br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)}: disp=${disp}`);
 
                             // set the backward branch taken flag in the cinfo so that the monitoring phase
@@ -1195,23 +1216,29 @@ class Cfg {
                             // set the dispatch index for the br_table
                             this.builder.i32_const(disp);
                             this.builder.local("disp", WasmOpcode.set_local);
+                            successfulBackBranch = true;
                         } else {
-                            if (this.trace)
+                            if (this.trace > 0)
                                 console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed: back branch target not in dispatch table`);
                             indexInStack = -1;
                         }
                     }
 
-                    if (indexInStack >= 0) {
+                    if ((indexInStack >= 0) || successfulBackBranch) {
                         // Conditional branches are nested in an extra block, so the depth is +1
                         const offset = segment.isConditional ? 1 : 0;
                         this.builder.appendU8(WasmOpcode.br);
                         this.builder.appendULeb(offset + indexInStack);
-                        if (this.trace)
+                        if (this.trace > 1)
                             console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} breaking out ${offset + indexInStack + 1} level(s)`);
                     } else {
-                        if (this.trace)
-                            console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed`);
+                        if (this.trace > 0) {
+                            const base = <any>this.base;
+                            if ((segment.target >= base) && (segment.target < this.exitIp))
+                                console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed (inside of trace!)`);
+                            else if (this.trace > 1)
+                                console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed (outside of trace 0x${base.toString(16)} - 0x${(<any>this.exitIp).toString(16)})`);
+                        }
                         append_bailout(this.builder, segment.target, BailoutReason.Branch);
                     }
                     break;
@@ -1289,7 +1316,7 @@ export function append_bailout (builder: WasmBuilder, ip: MintOpcodePtr, reason:
 
 // generate a bailout that is recorded for the monitoring phase as a possible early exit.
 export function append_exit (builder: WasmBuilder, ip: MintOpcodePtr, opcodeCounter: number, reason: BailoutReason) {
-    if (opcodeCounter <= (builder.options.monitoringLongDistance + 1)) {
+    if (opcodeCounter <= (builder.options.monitoringLongDistance + 2)) {
         builder.local("cinfo");
         builder.i32_const(opcodeCounter);
         builder.appendU8(WasmOpcode.i32_store);
index e588433..fe14f1b 100644 (file)
@@ -172,8 +172,8 @@ export function generate_wasm_body (
         // HACK: Browsers set a limit of 4KB, we lower it slightly since a single opcode
         //  might generate a ton of code and we generate a bit of an epilogue after
         //  we finish
-        const maxModuleSize = 3850,
-            spaceLeft = maxModuleSize - builder.bytesGeneratedSoFar - builder.cfg.overheadBytes;
+        const maxBytesGenerated = 3840,
+            spaceLeft = maxBytesGenerated - builder.bytesGeneratedSoFar - builder.cfg.overheadBytes;
         if (builder.size >= spaceLeft) {
             // console.log(`trace too big, estimated size is ${builder.size + builder.bytesGeneratedSoFar}`);
             record_abort(traceIp, ip, traceName, "trace-too-big");
index 852199a..d938478 100644 (file)
@@ -69,7 +69,9 @@ export const
     // Always grab method full names
     useFullNames = false,
     // Use the mono_debug_count() API (set the COUNT=n env var) to limit the number of traces to compile
-    useDebugCount = false;
+    useDebugCount = false,
+    // Web browsers limit synchronous module compiles to 4KB
+    maxModuleSize = 4080;
 
 export const callTargetCounts : { [method: number] : number } = {};
 
@@ -718,7 +720,7 @@ function generate_wasm (
                 if (getU16(ip) !== MintOpcode.MINT_TIER_PREPARE_JITERPRETER)
                     throw new Error(`Expected *ip to be MINT_TIER_PREPARE_JITERPRETER but was ${getU16(ip)}`);
 
-                builder.cfg.initialize(startOfBody, backwardBranchTable, !!instrument);
+                builder.cfg.initialize(startOfBody, backwardBranchTable, instrument ? 1 : 0);
 
                 // TODO: Call generate_wasm_body before generating any of the sections and headers.
                 // This will allow us to do things like dynamically vary the number of locals, in addition
@@ -754,6 +756,10 @@ function generate_wasm (
         if (trace > 0)
             console.log(`${(<any>(builder.base)).toString(16)} ${methodFullName || traceName} generated ${buffer.length} byte(s) of wasm`);
         counters.bytesGenerated += buffer.length;
+        if (buffer.length >= maxModuleSize) {
+            console.warn(`MONO_WASM: Jiterpreter generated too much code (${buffer.length} bytes) for trace ${traceName}. Please report this issue.`);
+            return 0;
+        }
         const traceModule = new WebAssembly.Module(buffer);
 
         const traceInstance = new WebAssembly.Instance(traceModule, {