[wasm] [jiterp] Use wasm if opcode for null checks and conditional branches (#88114)
authorKatelyn Gadd <kg@luminance.org>
Mon, 10 Jul 2023 05:20:40 +0000 (22:20 -0700)
committerGitHub <noreply@github.com>
Mon, 10 Jul 2023 05:20:40 +0000 (22:20 -0700)
Use wasm if opcode for null checks and conditional branches where necessary; use if and br_if directly where possible
Refactor branching implementation and move more into Cfg

src/mono/wasm/runtime/jiterpreter-support.ts
src/mono/wasm/runtime/jiterpreter-trace-generator.ts

index 6260174..0b86607 100644 (file)
@@ -1120,11 +1120,18 @@ type CfgBranch = {
     from: MintOpcodePtr;
     target: MintOpcodePtr;
     isBackward: boolean; // FIXME: This should be inferred automatically
-    isConditional: boolean;
+    branchType: CfgBranchType;
 }
 
 type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch;
 
+export const enum CfgBranchType {
+    Unconditional,
+    Conditional,
+    SafepointUnconditional,
+    SafepointConditional,
+}
+
 class Cfg {
     builder: WasmBuilder;
     startOfBody!: MintOpcodePtr;
@@ -1204,7 +1211,7 @@ class Cfg {
         this.overheadBytes += 1; // each branch block just costs us an end
     }
 
-    branch(target: MintOpcodePtr, isBackward: boolean, isConditional: boolean) {
+    branch(target: MintOpcodePtr, isBackward: boolean, branchType: CfgBranchType) {
         this.observedBranchTargets.add(target);
         this.appendBlob();
         this.segments.push({
@@ -1212,7 +1219,7 @@ class Cfg {
             from: this.ip,
             target,
             isBackward,
-            isConditional,
+            branchType: branchType,
         });
         // some branches will generate bailouts instead so we allocate 4 bytes per branch
         //  to try and balance this out and avoid underestimating too much
@@ -1225,6 +1232,14 @@ class Cfg {
             // set_local <disp>
             this.overheadBytes += 11;
         }
+
+        // Account for the size of the safepoint
+        if (
+            (branchType === CfgBranchType.SafepointConditional) ||
+            (branchType === CfgBranchType.SafepointUnconditional)
+        ) {
+            this.overheadBytes += 17;
+        }
     }
 
     emitBlob(segment: CfgBlob, source: Uint8Array) {
@@ -1387,10 +1402,32 @@ class Cfg {
                     }
 
                     if ((indexInStack >= 0) || successfulBackBranch) {
-                        // Conditional branches are nested in an extra block, so the depth is +1
-                        const offset = segment.isConditional ? 1 : 0;
-                        this.builder.appendU8(WasmOpcode.br);
+                        let offset = 0;
+                        switch (segment.branchType) {
+                            case CfgBranchType.SafepointUnconditional:
+                                append_safepoint(this.builder, segment.from);
+                                this.builder.appendU8(WasmOpcode.br);
+                                break;
+                            case CfgBranchType.SafepointConditional:
+                                // Wrap the safepoint + branch in an if
+                                this.builder.block(WasmValtype.void, WasmOpcode.if_);
+                                append_safepoint(this.builder, segment.from);
+                                this.builder.appendU8(WasmOpcode.br);
+                                offset = 1;
+                                break;
+                            case CfgBranchType.Unconditional:
+                                this.builder.appendU8(WasmOpcode.br);
+                                break;
+                            case CfgBranchType.Conditional:
+                                this.builder.appendU8(WasmOpcode.br_if);
+                                break;
+                            default:
+                                throw new Error("Unimplemented branch type");
+                        }
+
                         this.builder.appendULeb(offset + indexInStack);
+                        if (offset) // close the if
+                            this.builder.endBlock();
                         if (this.trace > 1)
                             mono_log_info(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} breaking out ${offset + indexInStack + 1} level(s)`);
                     } else {
@@ -1401,7 +1438,14 @@ class Cfg {
                             else if (this.trace > 1)
                                 mono_log_info(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed (outside of trace 0x${base.toString(16)} - 0x${(<any>this.exitIp).toString(16)})`);
                         }
+
+                        const isConditional = (segment.branchType === CfgBranchType.Conditional) ||
+                            (segment.branchType === CfgBranchType.SafepointConditional);
+                        if (isConditional)
+                            this.builder.block(WasmValtype.void, WasmOpcode.if_);
                         append_bailout(this.builder, segment.target, BailoutReason.Branch);
+                        if (isConditional)
+                            this.builder.endBlock();
                     }
                     break;
                 }
@@ -1475,6 +1519,20 @@ export const _now = (globalThis.performance && globalThis.performance.now)
 
 let scratchBuffer: NativePointer = <any>0;
 
+export function append_safepoint(builder: WasmBuilder, ip: MintOpcodePtr) {
+    // Check whether a safepoint is required
+    builder.ptr_const(cwraps.mono_jiterp_get_polling_required_address());
+    builder.appendU8(WasmOpcode.i32_load);
+    builder.appendMemarg(0, 2);
+    // If the polling flag is set we call mono_jiterp_do_safepoint()
+    builder.block(WasmValtype.void, WasmOpcode.if_);
+    builder.local("frame");
+    // Not ip_const, because we can't pass relative IP to do_safepoint
+    builder.i32_const(ip);
+    builder.callImport("safepoint");
+    builder.endBlock();
+}
+
 export function append_bailout(builder: WasmBuilder, ip: MintOpcodePtr, reason: BailoutReason) {
     builder.ip_const(ip);
     if (builder.options.countBailouts) {
index 0254a48..6711907 100644 (file)
@@ -22,7 +22,7 @@ import {
     append_memmove_dest_src, try_append_memset_fast,
     try_append_memmove_fast, counters, getOpcodeTableValue,
     getMemberOffset, JiterpMember, BailoutReason,
-    isZeroPageReserved
+    isZeroPageReserved, CfgBranchType, append_safepoint
 } from "./jiterpreter-support";
 import { compileSimdFeatureDetect } from "./jiterpreter-feature-detect";
 import {
@@ -1271,9 +1271,7 @@ export function generateWasmBody(
                         builder.local("index");
                         builder.ptr_const(ra);
                         builder.appendU8(WasmOpcode.i32_eq);
-                        builder.block(WasmValtype.void, WasmOpcode.if_);
-                        builder.cfg.branch(ra, ra < ip, true);
-                        builder.endBlock();
+                        builder.cfg.branch(ra, ra < ip, CfgBranchType.Conditional);
                     }
                     // If none of the comparisons succeeded we won't have branched anywhere, so bail out
                     // This shouldn't happen during non-exception-handling execution unless the trace doesn't
@@ -1889,11 +1887,10 @@ function append_ldloc_cknull(builder: WasmBuilder, localOffset: number, ip: Mint
         return;
     }
 
-    builder.block();
     append_ldloc(builder, localOffset, WasmOpcode.i32_load);
     builder.local("cknull_ptr", WasmOpcode.tee_local);
-    builder.appendU8(WasmOpcode.br_if);
-    builder.appendULeb(0);
+    builder.appendU8(WasmOpcode.i32_eqz);
+    builder.block(WasmValtype.void, WasmOpcode.if_);
     append_bailout(builder, ip, BailoutReason.NullCheck);
     builder.endBlock();
     if (leaveOnStack)
@@ -2697,7 +2694,7 @@ function emit_branch(
                         mono_log_info(`performing backward branch to 0x${destination.toString(16)}`);
                     if (isCallHandler)
                         append_call_handler_store_ret_ip(builder, ip, frame, opcode);
-                    builder.cfg.branch(destination, true, false);
+                    builder.cfg.branch(destination, true, CfgBranchType.Unconditional);
                     counters.backBranchesEmitted++;
                     return true;
                 } else {
@@ -2722,7 +2719,7 @@ function emit_branch(
                 builder.branchTargets.add(destination);
                 if (isCallHandler)
                     append_call_handler_store_ret_ip(builder, ip, frame, opcode);
-                builder.cfg.branch(destination, false, false);
+                builder.cfg.branch(destination, false, CfgBranchType.Unconditional);
                 return true;
             }
         }
@@ -2735,20 +2732,19 @@ function emit_branch(
         case MintOpcode.MINT_BRFALSE_I8_S: {
             const is64 = (opcode === MintOpcode.MINT_BRTRUE_I8_S) ||
                 (opcode === MintOpcode.MINT_BRFALSE_I8_S);
-            // Wrap the conditional branch in a block so we can skip the
-            //  actual branch at the end of it
-            builder.block();
+
+            // Load the condition
 
             displacement = getArgI16(ip, 2);
             append_ldloc(builder, getArgU16(ip, 1), is64 ? WasmOpcode.i64_load : WasmOpcode.i32_load);
             if (
-                (opcode === MintOpcode.MINT_BRTRUE_I4_S) ||
-                (opcode === MintOpcode.MINT_BRTRUE_I4_SP)
+                (opcode === MintOpcode.MINT_BRFALSE_I4_S) ||
+                (opcode === MintOpcode.MINT_BRFALSE_I4_SP)
             )
                 builder.appendU8(WasmOpcode.i32_eqz);
-            else if (opcode === MintOpcode.MINT_BRTRUE_I8_S)
-                builder.appendU8(WasmOpcode.i64_eqz);
             else if (opcode === MintOpcode.MINT_BRFALSE_I8_S) {
+                builder.appendU8(WasmOpcode.i64_eqz);
+            } else if (opcode === MintOpcode.MINT_BRTRUE_I8_S) {
                 // do (i64 == 0) == 0 because br_if can only branch on an i32 operand
                 builder.appendU8(WasmOpcode.i64_eqz);
                 builder.appendU8(WasmOpcode.i32_eqz);
@@ -2766,7 +2762,6 @@ function emit_branch(
             if (cwraps.mono_jiterp_get_opcode_info(opcode, OpcodeInfoType.Length) !== 4)
                 throw new Error(`Unsupported long branch opcode: ${getOpcodeName(opcode)}`);
 
-            builder.appendU8(WasmOpcode.i32_eqz);
             break;
         }
     }
@@ -2778,21 +2773,13 @@ function emit_branch(
 
     const destination = <any>ip + (displacement * 2);
 
-    // We generate a conditional branch that will skip past the rest of this
-    //  tiny branch dispatch block to avoid performing the branch
-    builder.appendU8(WasmOpcode.br_if);
-    builder.appendULeb(0);
-
     if (displacement < 0) {
-        if (isSafepoint)
-            append_safepoint(builder, ip);
-
         if (builder.backBranchOffsets.indexOf(destination) >= 0) {
             // We found a backwards branch target we can reach via our outer trace loop, so
             //  we update eip and branch out to the top of the loop block
             if (traceBackBranches > 1)
                 mono_log_info(`performing conditional backward branch to 0x${destination.toString(16)}`);
-            builder.cfg.branch(destination, true, true);
+            builder.cfg.branch(destination, true, isSafepoint ? CfgBranchType.SafepointConditional : CfgBranchType.Conditional);
             counters.backBranchesEmitted++;
         } else {
             if (destination < builder.cfg.entryIp) {
@@ -2804,19 +2791,17 @@ function emit_branch(
                 );
             // We didn't find a loop to branch to, so bail out
             cwraps.mono_jiterp_boost_back_branch_target(destination);
+            builder.block(WasmValtype.void, WasmOpcode.if_);
             append_bailout(builder, destination, BailoutReason.BackwardBranch);
+            builder.endBlock();
             counters.backBranchesNotEmitted++;
         }
     } else {
-        // Do a safepoint *before* changing our IP, if necessary
-        if (isSafepoint)
-            append_safepoint(builder, ip);
         // Branching is enabled, so set eip and exit the current branch block
         builder.branchTargets.add(destination);
-        builder.cfg.branch(destination, false, true);
+        builder.cfg.branch(destination, false, isSafepoint ? CfgBranchType.SafepointConditional : CfgBranchType.Conditional);
     }
 
-    builder.endBlock();
     return true;
 }
 
@@ -2838,10 +2823,6 @@ function emit_relop_branch(
     if (!relopInfo && !intrinsicFpBinop)
         return false;
 
-    // We have to wrap the computation of the branch condition inside the
-    //  branch block because opening blocks destroys the contents of the
-    //  wasm execution stack for some reason
-    builder.block();
     const displacement = getArgI16(ip, 3);
     if (traceBranchDisplacements)
         mono_log_info(`relop @${ip} displacement=${displacement}`);
@@ -3812,17 +3793,3 @@ function emit_simd_4(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
             return false;
     }
 }
-
-function append_safepoint(builder: WasmBuilder, ip: MintOpcodePtr) {
-    // Check whether a safepoint is required
-    builder.ptr_const(cwraps.mono_jiterp_get_polling_required_address());
-    builder.appendU8(WasmOpcode.i32_load);
-    builder.appendMemarg(0, 2);
-    // If the polling flag is set we call mono_jiterp_do_safepoint()
-    builder.block(WasmValtype.void, WasmOpcode.if_);
-    builder.local("frame");
-    // Not ip_const, because we can't pass relative IP to do_safepoint
-    builder.i32_const(ip);
-    builder.callImport("safepoint");
-    builder.endBlock();
-}