[wasm] Exploit unallocated zero page to remove null checks in jiterpreter (#86403)
authorKatelyn Gadd <kg@luminance.org>
Wed, 31 May 2023 02:50:34 +0000 (19:50 -0700)
committerGitHub <noreply@github.com>
Wed, 31 May 2023 02:50:34 +0000 (19:50 -0700)
* Fuse null and length checks for getelema1 and getchr
* Sense whether zero page optimizations are safe based on the location of the emscripten stack and contents of memory near zero
* Improve jiterpreter stats formatting

src/mono/mono/utils/options-def.h
src/mono/wasm/runtime/cwraps.ts
src/mono/wasm/runtime/driver.c
src/mono/wasm/runtime/jiterpreter-support.ts
src/mono/wasm/runtime/jiterpreter-trace-generator.ts
src/mono/wasm/runtime/jiterpreter.ts

index 679fe9d..4217cbf 100644 (file)
@@ -112,6 +112,8 @@ DEFINE_BOOL(jiterpreter_eliminate_null_checks, "jiterpreter-eliminate-null-check
 DEFINE_BOOL(jiterpreter_backward_branches_enabled, "jiterpreter-backward-branches-enabled", TRUE, "Enable performing backward branches without exiting traces")
 // Attempt to use WASM v128 opcodes to implement SIMD interpreter opcodes
 DEFINE_BOOL(jiterpreter_enable_simd, "jiterpreter-simd-enabled", TRUE, "Attempt to use WebAssembly SIMD support")
+// Since the zero page is unallocated, loading array/string/span lengths from null ptrs will yield zero
+DEFINE_BOOL(jiterpreter_zero_page_optimization, "jiterpreter-zero-page-optimization", TRUE, "Exploit the zero page being unallocated to optimize out null checks")
 // When compiling a jit_call wrapper, bypass sharedvt wrappers if possible by inlining their
 //  logic into the compiled wrapper and calling the target AOTed function with native call convention
 DEFINE_BOOL(jiterpreter_direct_jit_call, "jiterpreter-direct-jit-calls", TRUE, "Bypass gsharedvt wrappers when compiling JIT call wrappers")
index f18a313..c8f33e2 100644 (file)
@@ -130,6 +130,7 @@ const fn_signatures: SigLine[] = [
     [true, "mono_jiterp_get_simd_opcode", "number", ["number", "number"]],
     [true, "mono_jiterp_get_arg_offset", "number", ["number", "number", "number"]],
     [true, "mono_jiterp_get_opcode_info", "number", ["number", "number"]],
+    [true, "mono_wasm_is_zero_page_reserved", "number", []],
     ...legacy_interop_cwraps
 ];
 
@@ -255,6 +256,7 @@ export interface t_Cwraps {
     mono_jiterp_get_simd_opcode(arity: number, index: number): number;
     mono_jiterp_get_arg_offset(imethod: number, sig: number, index: number): number;
     mono_jiterp_get_opcode_info(opcode: number, type: number): number;
+    mono_wasm_is_zero_page_reserved(): number;
 }
 
 const wrapped_c_functions: t_Cwraps = <any>{};
index a640e88..9dd130c 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 #include <emscripten.h>
+#include <emscripten/stack.h>
 #include <stdio.h>
 #include <stddef.h>
 #include <stdlib.h>
@@ -1413,3 +1414,14 @@ EMSCRIPTEN_KEEPALIVE double mono_wasm_get_f64_unaligned (const double *src) {
 EMSCRIPTEN_KEEPALIVE int32_t mono_wasm_get_i32_unaligned (const int32_t *src) {
        return *src;
 }
+
+EMSCRIPTEN_KEEPALIVE int mono_wasm_is_zero_page_reserved () {
+       // If the stack is above the first 512 bytes of memory this indicates that it is safe
+       //  to optimize out null checks for operations that also do a bounds check, like string
+       //  and array element loads. (We already know that Emscripten malloc will never allocate
+       //  data at 0.) This is the default behavior for Emscripten release builds and is
+       //  controlled by the emscripten GLOBAL_BASE option (default value 1024).
+       // clang/llvm may perform this optimization if --low-memory-unused is set.
+       // https://github.com/emscripten-core/emscripten/issues/19389
+       return (emscripten_stack_get_base() > 512) && (emscripten_stack_get_end() > 512);
+}
index e0b6283..2cfeab6 100644 (file)
@@ -1434,6 +1434,7 @@ export const counters = {
     failures: 0,
     bytesGenerated: 0,
     nullChecksEliminated: 0,
+    nullChecksFused: 0,
     backBranchesEmitted: 0,
     backBranchesNotEmitted: 0,
     simdFallback: simdFallbackCounters,
@@ -1777,6 +1778,22 @@ export function bytesFromHex(hex: string): Uint8Array {
     return bytes;
 }
 
+export function isZeroPageReserved(): boolean {
+    // FIXME: This check will always return true on worker threads.
+    // Right now the jiterpreter is disabled when threading is active, so that's not an issue.
+    if (!cwraps.mono_wasm_is_zero_page_reserved())
+        return false;
+
+    // Determine whether emscripten's stack checker or some other troublemaker has
+    //  written junk at the start of memory. The previous cwraps call will have
+    //  checked whether the stack starts at zero or not (on the main thread).
+    // We can't do this in the C helper because emcc/asan might be checking pointers.
+    return (Module.HEAPU32[0] === 0) &&
+        (Module.HEAPU32[1] === 0) &&
+        (Module.HEAPU32[2] === 0) &&
+        (Module.HEAPU32[3] === 0);
+}
+
 export type JiterpreterOptions = {
     enableAll?: boolean;
     enableTraces: boolean;
@@ -1786,6 +1803,7 @@ export type JiterpreterOptions = {
     enableCallResume: boolean;
     enableWasmEh: boolean;
     enableSimd: boolean;
+    zeroPageOptimization: boolean;
     // For locations where the jiterpreter heuristic says we will be unable to generate
     //  a trace, insert an entry point opcode anyway. This enables collecting accurate
     //  stats for options like estimateHeat, but raises overhead.
@@ -1828,6 +1846,7 @@ const optionNames: { [jsName: string]: string } = {
     "enableCallResume": "jiterpreter-call-resume-enabled",
     "enableWasmEh": "jiterpreter-wasm-eh-enabled",
     "enableSimd": "jiterpreter-simd-enabled",
+    "zeroPageOptimization": "jiterpreter-zero-page-optimization",
     "enableStats": "jiterpreter-stats-enabled",
     "disableHeuristic": "jiterpreter-disable-heuristic",
     "estimateHeat": "jiterpreter-estimate-heat",
index f486666..e286435 100644 (file)
@@ -22,6 +22,7 @@ import {
     append_memmove_dest_src, try_append_memset_fast,
     try_append_memmove_fast, counters, getOpcodeTableValue,
     getMemberOffset, JiterpMember, BailoutReason,
+    isZeroPageReserved
 } from "./jiterpreter-support";
 import { compileSimdFeatureDetect } from "./jiterpreter-feature-detect";
 import {
@@ -598,8 +599,25 @@ export function generateWasmBody(
                 append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.i32_load);
                 // stash it, we'll be using it multiple times
                 builder.local("math_lhs32", WasmOpcode.tee_local);
+
+                /*
+                const constantIndex = get_known_constant_value(getArgU16(ip, 3));
+                if (typeof (constantIndex) === "number")
+                    console.log(`getchr in ${builder.functions[0].name} with constant index ${constantIndex}`);
+                */
+
                 // str
-                append_ldloc_cknull(builder, getArgU16(ip, 2), ip, true);
+                const ptrLocal = builder.options.zeroPageOptimization ? "math_rhs32" : "cknull_ptr";
+                if (builder.options.zeroPageOptimization && isZeroPageReserved()) {
+                    // load string ptr and stash it
+                    // if the string ptr is null, the length check will fail and we will bail out,
+                    //  so the null check is not necessary
+                    counters.nullChecksFused++;
+                    append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load);
+                    builder.local(ptrLocal, WasmOpcode.tee_local);
+                } else
+                    append_ldloc_cknull(builder, getArgU16(ip, 2), ip, true);
+
                 // get string length
                 builder.appendU8(WasmOpcode.i32_load);
                 builder.appendMemarg(getMemberOffset(JiterpMember.StringLength), 2);
@@ -624,7 +642,7 @@ export function generateWasmBody(
                 builder.local("math_lhs32");
                 builder.i32_const(2);
                 builder.appendU8(WasmOpcode.i32_mul);
-                builder.local("cknull_ptr");
+                builder.local(ptrLocal);
                 builder.appendU8(WasmOpcode.i32_add);
                 // Load char
                 builder.appendU8(WasmOpcode.i32_load16_u);
@@ -1871,8 +1889,10 @@ function emit_fieldop(
             append_ldloca(builder, localOffset, sizeBytes);
             // src
             builder.local("cknull_ptr");
-            builder.i32_const(fieldOffset);
-            builder.appendU8(WasmOpcode.i32_add);
+            if (fieldOffset !== 0) {
+                builder.i32_const(fieldOffset);
+                builder.appendU8(WasmOpcode.i32_add);
+            }
             append_memmove_dest_src(builder, sizeBytes);
             return true;
         }
@@ -1880,8 +1900,10 @@ function emit_fieldop(
             const klass = get_imethod_data(frame, getArgU16(ip, 4));
             // dest = (char*)o + ip [3]
             builder.local("cknull_ptr");
-            builder.i32_const(fieldOffset);
-            builder.appendU8(WasmOpcode.i32_add);
+            if (fieldOffset !== 0) {
+                builder.i32_const(fieldOffset);
+                builder.appendU8(WasmOpcode.i32_add);
+            }
             // src = locals + ip [2]
             append_ldloca(builder, localOffset, 0);
             builder.ptr_const(klass);
@@ -1892,8 +1914,10 @@ function emit_fieldop(
             const sizeBytes = getArgU16(ip, 4);
             // dest
             builder.local("cknull_ptr");
-            builder.i32_const(fieldOffset);
-            builder.appendU8(WasmOpcode.i32_add);
+            if (fieldOffset !== 0) {
+                builder.i32_const(fieldOffset);
+                builder.appendU8(WasmOpcode.i32_add);
+            }
             // src
             append_ldloca(builder, localOffset, 0);
             append_memmove_dest_src(builder, sizeBytes);
@@ -1905,8 +1929,10 @@ function emit_fieldop(
             builder.local("pLocals");
             // cknull_ptr isn't always initialized here
             append_ldloc(builder, objectOffset, WasmOpcode.i32_load);
-            builder.i32_const(fieldOffset);
-            builder.appendU8(WasmOpcode.i32_add);
+            if (fieldOffset !== 0) {
+                builder.i32_const(fieldOffset);
+                builder.appendU8(WasmOpcode.i32_add);
+            }
             append_stloc_tail(builder, localOffset, setter);
             return true;
 
@@ -2812,18 +2838,35 @@ function append_getelema1(
 ) {
     builder.block();
 
+    /*
+    const constantIndex = get_known_constant_value(indexOffset);
+    if (typeof (constantIndex) === "number")
+        console.log(`getelema1 in ${builder.functions[0].name} with constant index ${constantIndex}`);
+    */
+
     // load index for check
     append_ldloc(builder, indexOffset, WasmOpcode.i32_load);
     // stash it since we need it twice
     builder.local("math_lhs32", WasmOpcode.tee_local);
-    // array null check
-    append_ldloc_cknull(builder, objectOffset, ip, true);
+
+    const ptrLocal = builder.options.zeroPageOptimization ? "math_rhs32" : "cknull_ptr";
+    if (builder.options.zeroPageOptimization && isZeroPageReserved()) {
+        // load array ptr and stash it
+        // if the array ptr is null, the length check will fail and we will bail out
+        counters.nullChecksFused++;
+        append_ldloc(builder, objectOffset, WasmOpcode.i32_load);
+        builder.local(ptrLocal, WasmOpcode.tee_local);
+    } else
+        // array null check
+        append_ldloc_cknull(builder, objectOffset, ip, true);
+
     // load array length
     builder.appendU8(WasmOpcode.i32_load);
     builder.appendMemarg(getMemberOffset(JiterpMember.ArrayLength), 2);
     // check index < array.length, unsigned. if index is negative it will be interpreted as
     //  a massive value which is naturally going to be bigger than array.length. interp.c
     //  exploits this property so we can too
+    // for a null array pointer array.length will also be zero thanks to the zero page optimization
     builder.appendU8(WasmOpcode.i32_lt_u);
     // bailout unless (index < array.length)
     builder.appendU8(WasmOpcode.br_if);
@@ -2832,7 +2875,7 @@ function append_getelema1(
     builder.endBlock();
 
     // We did a null check and bounds check so we can now compute the actual address
-    builder.local("cknull_ptr");
+    builder.local(ptrLocal);
     builder.i32_const(getMemberOffset(JiterpMember.ArrayData));
     builder.appendU8(WasmOpcode.i32_add);
 
@@ -2858,6 +2901,7 @@ function emit_arrayop(builder: WasmBuilder, frame: NativePointer, ip: MintOpcode
         case MintOpcode.MINT_LDLEN: {
             builder.local("pLocals");
             // array null check
+            // note: zero page optimization is not valid here since we want to throw on null
             append_ldloc_cknull(builder, objectOffset, ip, true);
             // load array length
             builder.appendU8(WasmOpcode.i32_load);
index caad542..8b79860 100644 (file)
@@ -10,7 +10,7 @@ import { MintOpcode } from "./mintops";
 import cwraps from "./cwraps";
 import {
     MintOpcodePtr, WasmValtype, WasmBuilder, addWasmFunctionPointer,
-    _now, elapsedTimes,
+    _now, elapsedTimes, isZeroPageReserved,
     counters, getRawCwrap, importDef,
     JiterpreterOptions, getOptions, recordFailure,
     JiterpMember, getMemberOffset,
@@ -1034,10 +1034,18 @@ export function jiterpreter_dump_stats(b?: boolean, concise?: boolean) {
     if (!mostRecentOptions.enableStats && (b !== undefined))
         return;
 
-    mono_log_info(`// jitted ${counters.bytesGenerated} bytes; ${counters.tracesCompiled} traces (${counters.traceCandidates} candidates, ${(counters.tracesCompiled / counters.traceCandidates * 100).toFixed(1)}%); ${counters.jitCallsCompiled} jit_calls (${(counters.directJitCallsCompiled / counters.jitCallsCompiled * 100).toFixed(1)}% direct); ${counters.entryWrappersCompiled} interp_entries`);
-    const backBranchHitRate = (counters.backBranchesEmitted / (counters.backBranchesEmitted + counters.backBranchesNotEmitted)) * 100;
-    const tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count();
-    mono_log_info(`// time: ${elapsedTimes.generation | 0}ms generating, ${elapsedTimes.compilation | 0}ms compiling wasm. ${counters.nullChecksEliminated} cknulls removed. ${counters.backBranchesEmitted} back-branches (${counters.backBranchesNotEmitted} failed, ${backBranchHitRate.toFixed(1)}%), ${tracesRejected} traces rejected`);
+    const backBranchHitRate = (counters.backBranchesEmitted / (counters.backBranchesEmitted + counters.backBranchesNotEmitted)) * 100,
+        tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count(),
+        nullChecksEliminatedText = mostRecentOptions.eliminateNullChecks ? counters.nullChecksEliminated.toString() : "off",
+        nullChecksFusedText = (mostRecentOptions.zeroPageOptimization ? counters.nullChecksFused.toString() + (isZeroPageReserved() ? "" : " (disabled)") : "off"),
+        backBranchesEmittedText = mostRecentOptions.enableBackwardBranches ? `emitted: ${counters.backBranchesEmitted}, failed: ${counters.backBranchesNotEmitted} (${backBranchHitRate.toFixed(1)}%)` : ": off",
+        directJitCallsText = counters.jitCallsCompiled ? (
+            mostRecentOptions.directJitCalls ? `direct jit calls: ${counters.directJitCallsCompiled} (${(counters.directJitCallsCompiled / counters.jitCallsCompiled * 100).toFixed(1)}%)` : "direct jit calls: off"
+        ) : "";
+
+    mono_log_info(`// jitted ${counters.bytesGenerated} bytes; ${counters.tracesCompiled} traces (${(counters.tracesCompiled / counters.traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${counters.jitCallsCompiled} jit_calls; ${counters.entryWrappersCompiled} interp_entries`);
+    mono_log_info(`// cknulls eliminated: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-branches ${backBranchesEmittedText}; ${directJitCallsText}`);
+    mono_log_info(`// time: ${elapsedTimes.generation | 0}ms generating, ${elapsedTimes.compilation | 0}ms compiling wasm.`);
     if (concise)
         return;