[wasm] Implement the ENDFINALLY opcode in the jiterpreter (#84273)
author Katelyn Gadd <kg@luminance.org>
Wed, 5 Apr 2023 12:23:38 +0000 (05:23 -0700)
committer GitHub <noreply@github.com>
Wed, 5 Apr 2023 12:23:38 +0000 (05:23 -0700)
* Mark the opcode following CALL_HANDLER interpreter opcodes as a back branch target
* In the jiterpreter, record the return address of each CALL_HANDLER opcode as it is compiled
* Then, when compiling an ENDFINALLY opcode, check whether the branch target is one we recognize; if so, emit a branch, otherwise bail out
* Tweak CFG to filter out branch targets that are never used
* Add browser-bench measurement for try-finally

src/mono/mono/mini/interp/jiterpreter.c
src/mono/mono/mini/interp/transform.c
src/mono/sample/wasm/browser-bench/Exceptions.cs
src/mono/wasm/runtime/jiterpreter-support.ts
src/mono/wasm/runtime/jiterpreter-trace-generator.ts
src/mono/wasm/runtime/jiterpreter.ts

index 9fa216f..bcd55db 100644 (file)
@@ -731,6 +731,10 @@ jiterp_should_abort_trace (InterpInst *ins, gboolean *inside_branch_block)
 
                        return TRACE_CONTINUE;
 
+               case MINT_ENDFINALLY:
+                       // May produce either a backwards branch or a bailout
+                       return TRACE_CONDITIONAL_ABORT;
+
                case MINT_ICALL_V_P:
                case MINT_ICALL_V_V:
                case MINT_ICALL_P_P:
@@ -748,7 +752,6 @@ jiterp_should_abort_trace (InterpInst *ins, gboolean *inside_branch_block)
                case MINT_LEAVE_S_CHECK:
                        return TRACE_ABORT;
 
-               case MINT_ENDFINALLY:
                case MINT_RETHROW:
                case MINT_PROF_EXIT:
                case MINT_PROF_EXIT_VOID:
index e5b8719..0c0ea50 100644 (file)
@@ -8434,6 +8434,17 @@ generate_compacted_code (InterpMethod *rtm, TransformData *td)
                        if (ins->opcode == MINT_TIER_PATCHPOINT_DATA) {
                                int native_offset = (int)(ip - td->new_code);
                                patchpoint_data_index = add_patchpoint_data (td, patchpoint_data_index, native_offset, -ins->data [0]);
+#if HOST_BROWSER
+                       } else if (rtm->contains_traces && (
+                               (ins->opcode == MINT_CALL_HANDLER_S) || (ins->opcode == MINT_CALL_HANDLER)
+                       )) {
+                               // While this formally isn't a backward branch target, we want to record
+                               //  the offset of its following instruction so that the jiterpreter knows
+                               //  to generate the necessary dispatch code to enable branching back to it.
+                               ip = emit_compacted_instruction (td, ip, ins);
+                               if (backward_branch_offsets_count < BACKWARD_BRANCH_OFFSETS_SIZE)
+                                       backward_branch_offsets[backward_branch_offsets_count++] = ip - td->new_code;
+#endif
                        } else {
                                ip = emit_compacted_instruction (td, ip, ins);
                        }
index 2619fc4..ffe96e8 100644 (file)
@@ -22,6 +22,7 @@ namespace Sample
                 new TryCatchFilterInline(),
                 new TryCatchFilterThrow(),
                 new TryCatchFilterThrowApplies(),
+                new TryFinally(),
             };
         }
 
@@ -208,5 +209,26 @@ namespace Sample
                     throw new System.Exception("Reached DoThrow and threw");
             }
         }
+
+        class TryFinally : ExcMeasurement // measurement whose step runs a try body followed by its finally block
+        {
+            public override string Name => "TryFinally";
+            int j = 1; // added to i once in the try body and once in the finally block
+
+            public override void RunStep()
+            {
+                int i = 0;
+                try
+                {
+                    i += j;
+                }
+                finally
+                {
+                    i += j;
+                }
+                if (i != 2) // expects 0 + j + j == 2, i.e. both the try body and the finally block ran
+                    throw new System.Exception("Internal error");
+            }
+        }
     }
 }
index fb481f3..e84a0fe 100644 (file)
@@ -50,6 +50,7 @@ export const enum BailoutReason {
     CallDelegate,
     Debugging,
     Icall,
+    UnexpectedRetIp,
 }
 
 export const BailoutReasonNames = [
@@ -78,6 +79,7 @@ export const BailoutReasonNames = [
     "CallDelegate",
     "Debugging",
     "Icall",
+    "UnexpectedRetIp",
 ];
 
 type FunctionType = [
@@ -157,6 +159,7 @@ export class WasmBuilder {
     options!: JiterpreterOptions;
     constantSlots: Array<number> = [];
     backBranchOffsets: Array<MintOpcodePtr> = [];
+    callHandlerReturnAddresses: Array<MintOpcodePtr> = [];
     nextConstantSlot = 0;
 
     compressImportNames = false;
@@ -202,6 +205,7 @@ export class WasmBuilder {
         for (let i = 0; i < this.constantSlots.length; i++)
             this.constantSlots[i] = 0;
         this.backBranchOffsets.length = 0;
+        this.callHandlerReturnAddresses.length = 0;
 
         this.allowNullCheckOptimization = this.options.eliminateNullChecks;
     }
@@ -1009,6 +1013,7 @@ class Cfg {
     blockStack: Array<MintOpcodePtr> = [];
     backDispatchOffsets: Array<MintOpcodePtr> = [];
     dispatchTable = new Map<MintOpcodePtr, number>();
+    observedBranchTargets = new Set<MintOpcodePtr>();
     trace = 0;
 
     constructor (builder: WasmBuilder) {
@@ -1025,6 +1030,7 @@ class Cfg {
         this.lastSegmentEnd = 0;
         this.overheadBytes = 10; // epilogue
         this.dispatchTable.clear();
+        this.observedBranchTargets.clear();
         this.trace = trace;
         this.backDispatchOffsets.length = 0;
     }
@@ -1071,6 +1077,7 @@ class Cfg {
     }
 
     branch (target: MintOpcodePtr, isBackward: boolean, isConditional: boolean) {
+        this.observedBranchTargets.add(target);
         this.appendBlob();
         this.segments.push({
             type: "branch",
@@ -1140,13 +1147,19 @@ class Cfg {
             this.backDispatchOffsets.length = 0;
             // First scan the back branch target table and union it with the block stack
             // This filters down to back branch targets that are reachable inside this trace
+            // Further filter it down by only including targets we have observed a branch to
+            //  this helps for cases where the branch opcodes targeting the location were not
+            //  compiled due to an abort or some other reason
             for (let i = 0; i < this.backBranchTargets.length; i++) {
                 const offset = (this.backBranchTargets[i] * 2) + <any>this.startOfBody;
                 const breakDepth = this.blockStack.indexOf(offset);
-                if (breakDepth >= 0) {
-                    this.dispatchTable.set(offset, this.backDispatchOffsets.length + 1);
-                    this.backDispatchOffsets.push(offset);
-                }
+                if (breakDepth < 0)
+                    continue;
+                if (!this.observedBranchTargets.has(offset))
+                    continue;
+
+                this.dispatchTable.set(offset, this.backDispatchOffsets.length + 1);
+                this.backDispatchOffsets.push(offset);
             }
 
             if (this.backDispatchOffsets.length === 0) {
index ed718df..758d180 100644 (file)
@@ -27,6 +27,7 @@ import {
     traceEip, nullCheckValidation,
     abortAtJittedLoopBodies, traceNullCheckOptimizations,
     nullCheckCaching, traceBackBranches,
+    maxCallHandlerReturnAddresses,
 
     mostRecentOptions,
 
@@ -345,12 +346,13 @@ export function generateWasmBody (
             case MintOpcode.MINT_CALL_HANDLER_S:
                 if (!emit_branch(builder, ip, frame, opcode))
                     ip = abort;
-                else
+                else {
                     // Technically incorrect, but the instructions following this one may not be executed
                     //  since we might have skipped over them.
                     // FIXME: Identify when we should actually set the conditionally executed flag, perhaps
                     //  by doing a simple static flow analysis based on the displacements. Update heuristic too!
                     isConditionallyExecuted = true;
+                }
                 break;
 
             case MintOpcode.MINT_CKNULL: {
@@ -923,13 +925,41 @@ export function generateWasmBody (
                 isLowValueOpcode = true;
                 break;
 
-            case MintOpcode.MINT_ENDFINALLY:
-                // This one might make sense to partially implement, but the jump target
-                //  is computed at runtime which would make it hard to figure out where
-                //  we need to put branch targets. Not worth just doing a conditional
-                //  bailout since finally blocks always run
-                ip = abort;
+            case MintOpcode.MINT_ENDFINALLY: {
+                if (
+                    (builder.callHandlerReturnAddresses.length > 0) &&
+                    (builder.callHandlerReturnAddresses.length <= maxCallHandlerReturnAddresses)
+                ) {
+                    // console.log(`endfinally @0x${(<any>ip).toString(16)}. return addresses:`, builder.callHandlerReturnAddresses.map(ra => (<any>ra).toString(16)));
+                    // FIXME: Clean this codegen up
+                    // Load ret_ip
+                    const clauseIndex = getArgU16(ip, 1),
+                        clauseDataOffset = get_imethod_clause_data_offset(frame, clauseIndex);
+                    builder.local("pLocals");
+                    builder.appendU8(WasmOpcode.i32_load);
+                    builder.appendMemarg(clauseDataOffset, 0);
+                    // Stash it in a variable because we're going to need to use it multiple times
+                    builder.local("math_lhs32", WasmOpcode.set_local);
+                    // Do a bunch of trivial comparisons to see if ret_ip is one of our expected return addresses,
+                    //  and if it is, generate a branch back to the dispatcher at the top
+                    for (let r = 0; r < builder.callHandlerReturnAddresses.length; r++) {
+                        const ra = builder.callHandlerReturnAddresses[r];
+                        builder.local("math_lhs32");
+                        builder.ptr_const(ra);
+                        builder.appendU8(WasmOpcode.i32_eq);
+                        builder.block(WasmValtype.void, WasmOpcode.if_);
+                        builder.cfg.branch(ra, ra < ip, true);
+                        builder.endBlock();
+                    }
+                    // If none of the comparisons succeeded we won't have branched anywhere, so bail out
+                    // This shouldn't happen during non-exception-handling execution unless the trace doesn't
+                    //  contain the CALL_HANDLER that led here
+                    append_bailout(builder, ip, BailoutReason.UnexpectedRetIp);
+                } else {
+                    ip = abort;
+                }
                 break;
+            }
 
             case MintOpcode.MINT_RETHROW:
             case MintOpcode.MINT_PROF_EXIT:
@@ -2444,7 +2474,8 @@ function append_call_handler_store_ret_ip (
     builder.appendU8(WasmOpcode.i32_store);
     builder.appendMemarg(clauseDataOffset, 0); // FIXME: 32-bit alignment?
 
-    // console.log(`call_handler clauseDataOffset=0x${clauseDataOffset.toString(16)} retIp=0x${retIp.toString(16)}`);
+    // console.log(`call_handler @0x${(<any>ip).toString(16)} retIp=0x${retIp.toString(16)}`);
+    builder.callHandlerReturnAddresses.push(retIp);
 }
 
 function emit_branch (
@@ -2496,10 +2527,14 @@ function emit_branch (
                     counters.backBranchesEmitted++;
                     return true;
                 } else {
-                    if ((traceBackBranches > 0) || (builder.cfg.trace > 0))
-                        console.log(`back branch target 0x${destination.toString(16)} not found in list ` +
+                    if (destination < builder.cfg.entryIp) {
+                        if ((traceBackBranches > 1) || (builder.cfg.trace > 1))
+                            console.log(`${info[0]} target 0x${destination.toString(16)} before start of trace`);
+                    } else if ((traceBackBranches > 0) || (builder.cfg.trace > 0))
+                        console.log(`0x${(<any>ip).toString(16)} ${info[0]} target 0x${destination.toString(16)} not found in list ` +
                             builder.backBranchOffsets.map(bbo => "0x" + (<any>bbo).toString(16)).join(", ")
                         );
+
                     cwraps.mono_jiterp_boost_back_branch_target(destination);
                     // FIXME: Should there be a safepoint here?
                     append_bailout(builder, destination, BailoutReason.BackwardBranch);
@@ -2586,8 +2621,11 @@ function emit_branch (
             builder.cfg.branch(destination, true, true);
             counters.backBranchesEmitted++;
         } else {
-            if ((traceBackBranches > 0) || (builder.cfg.trace > 0))
-                console.log(`back branch target 0x${destination.toString(16)} not found in list ` +
+            if (destination < builder.cfg.entryIp) {
+                if ((traceBackBranches > 1) || (builder.cfg.trace > 1))
+                    console.log(`${info[0]} target 0x${destination.toString(16)} before start of trace`);
+            } else if ((traceBackBranches > 0) || (builder.cfg.trace > 0))
+                console.log(`0x${(<any>ip).toString(16)} ${info[0]} target 0x${destination.toString(16)} not found in list ` +
                     builder.backBranchOffsets.map(bbo => "0x" + (<any>bbo).toString(16)).join(", ")
                 );
             // We didn't find a loop to branch to, so bail out
index 5a0e841..cd1ced5 100644 (file)
@@ -63,6 +63,13 @@ export const
     // Unproductive if we have backward branches enabled because it can stop us from jitting
     //  nested loops
     abortAtJittedLoopBodies = true,
+    // Enable generating conditional backward branches for ENDFINALLY opcodes if we saw some CALL_HANDLER
+    //  opcodes previously, up to this many potential return addresses. If a trace contains more potential
+    //  return addresses than this we will not emit code for the ENDFINALLY opcode
+    maxCallHandlerReturnAddresses = 3,
+    // Controls how many individual items (traces, bailouts, etc) are shown in the breakdown
+    //  at the end of a run when stats are enabled. The N highest ranking items will be shown.
+    summaryStatCount = 30,
     // Emit a wasm nop between each managed interpreter opcode
     emitPadding = false,
     // Generate compressed names for imports so that modules have more space for code
@@ -984,7 +991,7 @@ export function jiterpreter_dump_stats (b?: boolean, concise?: boolean) {
                 console.log(`// traces bailed out ${bailoutCount} time(s) due to ${BailoutReasonNames[i]}`);
         }
 
-        for (let i = 0, c = 0; i < traces.length && c < 30; i++) {
+        for (let i = 0, c = 0; i < traces.length && c < summaryStatCount; i++) {
             const trace = traces[i];
             if (!trace.bailoutCount)
                 continue;
@@ -1016,7 +1023,7 @@ export function jiterpreter_dump_stats (b?: boolean, concise?: boolean) {
             console.log("// hottest call targets:");
             const targetPointers = Object.keys(callTargetCounts);
             targetPointers.sort((l, r) => callTargetCounts[Number(r)] - callTargetCounts[Number(l)]);
-            for (let i = 0, c = Math.min(20, targetPointers.length); i < c; i++) {
+            for (let i = 0, c = Math.min(summaryStatCount, targetPointers.length); i < c; i++) {
                 const targetMethod = Number(targetPointers[i]) | 0;
                 const pMethodName = cwraps.mono_wasm_method_get_full_name(<any>targetMethod);
                 const targetMethodName = Module.UTF8ToString(pMethodName);
@@ -1028,7 +1035,7 @@ export function jiterpreter_dump_stats (b?: boolean, concise?: boolean) {
 
         traces.sort((l, r) => r.hitCount - l.hitCount);
         console.log("// hottest failed traces:");
-        for (let i = 0, c = 0; i < traces.length && c < 20; i++) {
+        for (let i = 0, c = 0; i < traces.length && c < summaryStatCount; i++) {
             // this means the trace has a low hit count and we don't know its identity. no value in
             //  logging it.
             if (!traces[i].name)
@@ -1064,7 +1071,6 @@ export function jiterpreter_dump_stats (b?: boolean, concise?: boolean) {
                     case "newobj_slow":
                     case "switch":
                     case "rethrow":
-                    case "endfinally":
                     case "end-of-body":
                     case "ret":
                         continue;
@@ -1072,7 +1078,6 @@ export function jiterpreter_dump_stats (b?: boolean, concise?: boolean) {
                     // not worth implementing / too difficult
                     case "intrins_marvin_block":
                     case "intrins_ascii_chars_to_uppercase":
-                    case "newarr":
                         continue;
                 }
             }