[wasm] Implement more jiterpreter opcodes (#82940)
authorKatelyn Gadd <kg@luminance.org>
Fri, 3 Mar 2023 16:59:02 +0000 (08:59 -0800)
committerGitHub <noreply@github.com>
Fri, 3 Mar 2023 16:59:02 +0000 (08:59 -0800)
* Implement a bunch of missing math opcodes like exp, log and pow
* Implement STELEM_VT

src/mono/mono/mini/interp/jiterpreter.c
src/mono/wasm/runtime/jiterpreter-trace-generator.ts
src/mono/wasm/runtime/jiterpreter.ts

index af7778f..89596d8 100644 (file)
@@ -786,7 +786,7 @@ jiterp_should_abort_trace (InterpInst *ins, gboolean *inside_branch_block)
                )
                        return TRACE_CONTINUE;
                else if (
-                       // math intrinsics
+                       // math intrinsics - we implement most but not all of these
                        (opcode >= MINT_ASIN) &&
                        (opcode <= MINT_MAXF)
                )
@@ -1172,6 +1172,11 @@ mono_jiterp_math_atan2 (double lhs, double rhs) {
 }
 
 EMSCRIPTEN_KEEPALIVE double
+mono_jiterp_math_pow (double lhs, double rhs) {
+       return pow(lhs, rhs);
+}
+
+EMSCRIPTEN_KEEPALIVE double
 mono_jiterp_math_acos (double value)
 {
        return acos(value);
@@ -1207,6 +1212,36 @@ mono_jiterp_math_tan (double value)
        return tan(value);
 }
 
+EMSCRIPTEN_KEEPALIVE double
+mono_jiterp_math_exp (double value)
+{
+       return exp(value);
+}
+
+EMSCRIPTEN_KEEPALIVE double
+mono_jiterp_math_log (double value)
+{
+       return log(value);
+}
+
+EMSCRIPTEN_KEEPALIVE double
+mono_jiterp_math_log2 (double value)
+{
+       return log2(value);
+}
+
+EMSCRIPTEN_KEEPALIVE double
+mono_jiterp_math_log10 (double value)
+{
+       return log10(value);
+}
+
+EMSCRIPTEN_KEEPALIVE double
+mono_jiterp_math_fma (double x, double y, double z)
+{
+       return fma(x, y, z);
+}
+
 EMSCRIPTEN_KEEPALIVE int
 mono_jiterp_stelem_ref (
        MonoArray *o, gint32 aindex, MonoObject *ref
index 3acaf10..8f0f438 100644 (file)
@@ -972,6 +972,7 @@ export function generate_wasm_body (
                 append_stloc_tail(builder, getArgU16(ip, 1), isI32 ? WasmOpcode.i32_store : WasmOpcode.i64_store);
                 break;
             }
+
             case MintOpcode.MINT_MONO_CMPXCHG_I4:
                 builder.local("pLocals");
                 append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load); // dest
@@ -992,6 +993,33 @@ export function generate_wasm_body (
                 builder.callImport("cmpxchg_i64");
                 break;
 
+            case MintOpcode.MINT_FMA:
+            case MintOpcode.MINT_FMAF: {
+                const isF32 = (opcode === MintOpcode.MINT_FMAF),
+                    loadOp = isF32 ? WasmOpcode.f32_load : WasmOpcode.f64_load,
+                    storeOp = isF32 ? WasmOpcode.f32_store : WasmOpcode.f64_store;
+
+                builder.local("pLocals");
+
+                // LOCAL_VAR (ip [1], double) = fma (LOCAL_VAR (ip [2], double), LOCAL_VAR (ip [3], double), LOCAL_VAR (ip [4], double));
+                append_ldloc(builder, getArgU16(ip, 2), loadOp);
+                if (isF32)
+                    builder.appendU8(WasmOpcode.f64_promote_f32);
+                append_ldloc(builder, getArgU16(ip, 3), loadOp);
+                if (isF32)
+                    builder.appendU8(WasmOpcode.f64_promote_f32);
+                append_ldloc(builder, getArgU16(ip, 4), loadOp);
+                if (isF32)
+                    builder.appendU8(WasmOpcode.f64_promote_f32);
+
+                builder.callImport("fma");
+
+                if (isF32)
+                    builder.appendU8(WasmOpcode.f32_demote_f64);
+                append_stloc_tail(builder, getArgU16(ip, 1), storeOp);
+                break;
+            }
+
             default:
                 if (
                     (
@@ -1064,7 +1092,7 @@ export function generate_wasm_body (
                     (opcode >= MintOpcode.MINT_LDELEM_I1) &&
                     (opcode <= MintOpcode.MINT_LDLEN)
                 ) {
-                    if (!emit_arrayop(builder, ip, opcode))
+                    if (!emit_arrayop(builder, frame, ip, opcode))
                         ip = abort;
                 } else if (
                     (opcode >= MintOpcode.MINT_BRFALSE_I4_SP) &&
@@ -2478,6 +2506,49 @@ function emit_relop_branch (builder: WasmBuilder, ip: MintOpcodePtr, opcode: Min
     return emit_branch(builder, ip, opcode, displacement);
 }
 
+const mathIntrinsicTable : { [opcode: number] : [isUnary: boolean, isF32: boolean, opcodeOrFuncName: WasmOpcode | string] } = {
+    [MintOpcode.MINT_SQRT]:     [true, false,  WasmOpcode.f64_sqrt],
+    [MintOpcode.MINT_SQRTF]:    [true, true,   WasmOpcode.f32_sqrt],
+    [MintOpcode.MINT_CEILING]:  [true, false,  WasmOpcode.f64_ceil],
+    [MintOpcode.MINT_CEILINGF]: [true, true,   WasmOpcode.f32_ceil],
+    [MintOpcode.MINT_FLOOR]:    [true, false,  WasmOpcode.f64_floor],
+    [MintOpcode.MINT_FLOORF]:   [true, true,   WasmOpcode.f32_floor],
+    [MintOpcode.MINT_ABS]:      [true, false,  WasmOpcode.f64_abs],
+    [MintOpcode.MINT_ABSF]:     [true, true,   WasmOpcode.f32_abs],
+    [MintOpcode.MINT_MIN]:      [true, false,  WasmOpcode.f64_min],
+    [MintOpcode.MINT_MINF]:     [true, true,   WasmOpcode.f32_min],
+    [MintOpcode.MINT_MAX]:      [true, false,  WasmOpcode.f64_max],
+    [MintOpcode.MINT_MAXF]:     [true, true,   WasmOpcode.f32_max],
+
+    [MintOpcode.MINT_ACOS]:     [true, false,  "acos"],
+    [MintOpcode.MINT_ACOSF]:    [true, true,   "acos"],
+    [MintOpcode.MINT_COS]:      [true, false,  "cos"],
+    [MintOpcode.MINT_COSF]:     [true, true,   "cos"],
+    [MintOpcode.MINT_ASIN]:     [true, false,  "asin"],
+    [MintOpcode.MINT_ASINF]:    [true, true,   "asin"],
+    [MintOpcode.MINT_SIN]:      [true, false,  "sin"],
+    [MintOpcode.MINT_SINF]:     [true, true,   "sin"],
+    [MintOpcode.MINT_ATAN]:     [true, false,  "atan"],
+    [MintOpcode.MINT_ATANF]:    [true, true,   "atan"],
+    [MintOpcode.MINT_TAN]:      [true, false,  "tan"],
+    [MintOpcode.MINT_TANF]:     [true, true,   "tan"],
+    [MintOpcode.MINT_EXP]:      [true, false,  "exp"],
+    [MintOpcode.MINT_EXPF]:     [true, true,   "exp"],
+    [MintOpcode.MINT_LOG]:      [true, false,  "log"],
+    [MintOpcode.MINT_LOGF]:     [true, true,   "log"],
+    [MintOpcode.MINT_LOG2]:     [true, false,  "log2"],
+    [MintOpcode.MINT_LOG2F]:    [true, true,   "log2"],
+    [MintOpcode.MINT_LOG10]:    [true, false,  "log10"],
+    [MintOpcode.MINT_LOG10F]:   [true, true,   "log10"],
+
+    [MintOpcode.MINT_ATAN2]:    [false, false, "atan2"],
+    [MintOpcode.MINT_ATAN2F]:   [false, true,  "atan2"],
+    [MintOpcode.MINT_POW]:      [false, false, "pow"],
+    [MintOpcode.MINT_POWF]:     [false, true,  "pow"],
+    [MintOpcode.MINT_REM_R4]:   [false, true,  "rem"],
+    [MintOpcode.MINT_REM_R8]:   [false, false, "rem"],
+};
+
 function emit_math_intrinsic (builder: WasmBuilder, ip: MintOpcodePtr, opcode: MintOpcode) : boolean {
     let isUnary : boolean, isF32 : boolean, name: string | undefined;
     let wasmOp : WasmOpcode | undefined;
@@ -2485,106 +2556,16 @@ function emit_math_intrinsic (builder: WasmBuilder, ip: MintOpcodePtr, opcode: M
         srcOffset = getArgU16(ip, 2),
         rhsOffset = getArgU16(ip, 3);
 
-    switch (opcode) {
-        // oddly the interpreter has no opcodes for abs!
-        case MintOpcode.MINT_SQRT:
-        case MintOpcode.MINT_SQRTF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_SQRTF);
-            wasmOp = isF32
-                ? WasmOpcode.f32_sqrt
-                : WasmOpcode.f64_sqrt;
-            break;
-        case MintOpcode.MINT_CEILING:
-        case MintOpcode.MINT_CEILINGF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_CEILINGF);
-            wasmOp = isF32
-                ? WasmOpcode.f32_ceil
-                : WasmOpcode.f64_ceil;
-            break;
-        case MintOpcode.MINT_FLOOR:
-        case MintOpcode.MINT_FLOORF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_FLOORF);
-            wasmOp = isF32
-                ? WasmOpcode.f32_floor
-                : WasmOpcode.f64_floor;
-            break;
-        case MintOpcode.MINT_ABS:
-        case MintOpcode.MINT_ABSF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_ABSF);
-            wasmOp = isF32
-                ? WasmOpcode.f32_abs
-                : WasmOpcode.f64_abs;
-            break;
-        case MintOpcode.MINT_REM_R4:
-        case MintOpcode.MINT_REM_R8:
-            isUnary = false;
-            isF32 = (opcode === MintOpcode.MINT_REM_R4);
-            name = "rem";
-            break;
-        case MintOpcode.MINT_ATAN2:
-        case MintOpcode.MINT_ATAN2F:
-            isUnary = false;
-            isF32 = (opcode === MintOpcode.MINT_ATAN2F);
-            name = "atan2";
-            break;
-        case MintOpcode.MINT_ACOS:
-        case MintOpcode.MINT_ACOSF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_ACOSF);
-            name = "acos";
-            break;
-        case MintOpcode.MINT_COS:
-        case MintOpcode.MINT_COSF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_COSF);
-            name = "cos";
-            break;
-        case MintOpcode.MINT_SIN:
-        case MintOpcode.MINT_SINF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_SINF);
-            name = "sin";
-            break;
-        case MintOpcode.MINT_ASIN:
-        case MintOpcode.MINT_ASINF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_ASINF);
-            name = "asin";
-            break;
-        case MintOpcode.MINT_TAN:
-        case MintOpcode.MINT_TANF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_TANF);
-            name = "tan";
-            break;
-        case MintOpcode.MINT_ATAN:
-        case MintOpcode.MINT_ATANF:
-            isUnary = true;
-            isF32 = (opcode === MintOpcode.MINT_ATANF);
-            name = "atan";
-            break;
-        case MintOpcode.MINT_MIN:
-        case MintOpcode.MINT_MINF:
-            isUnary = false;
-            isF32 = (opcode === MintOpcode.MINT_MINF);
-            wasmOp = isF32
-                ? WasmOpcode.f32_min
-                : WasmOpcode.f64_min;
-            break;
-        case MintOpcode.MINT_MAX:
-        case MintOpcode.MINT_MAXF:
-            isUnary = false;
-            isF32 = (opcode === MintOpcode.MINT_MAXF);
-            wasmOp = isF32
-                ? WasmOpcode.f32_max
-                : WasmOpcode.f64_max;
-            break;
-        default:
-            return false;
+    const tableEntry = mathIntrinsicTable[opcode];
+    if (tableEntry) {
+        isUnary = tableEntry[0];
+        isF32 = tableEntry[1];
+        if (typeof (tableEntry[2]) === "string")
+            name = tableEntry[2];
+        else
+            wasmOp = tableEntry[2];
+    } else {
+        return false;
     }
 
     // Pre-load locals for the stloc at the end
@@ -2859,7 +2840,7 @@ function append_getelema1 (
     // append_getelema1 leaves the address on the stack
 }
 
-function emit_arrayop (builder: WasmBuilder, ip: MintOpcodePtr, opcode: MintOpcode) : boolean {
+function emit_arrayop (builder: WasmBuilder, frame: NativePointer, ip: MintOpcodePtr, opcode: MintOpcode) : boolean {
     const isLoad = (
             (opcode <= MintOpcode.MINT_LDELEMA_TC) &&
             (opcode >= MintOpcode.MINT_LDELEM_I1)
@@ -2977,6 +2958,17 @@ function emit_arrayop (builder: WasmBuilder, ip: MintOpcodePtr, opcode: MintOpco
             invalidate_local_range(getArgU16(ip, 1), elementSize);
             return true;
         }
+        case MintOpcode.MINT_STELEM_VT: {
+            const elementSize = getArgU16(ip, 5),
+                klass = get_imethod_data(frame, getArgU16(ip, 4));
+            // dest
+            append_getelema1(builder, ip, objectOffset, indexOffset, elementSize);
+            // src
+            append_ldloca(builder, valueOffset, 0);
+            builder.ptr_const(klass);
+            builder.callImport("value_copy");
+            return true;
+        }
         default:
             return false;
     }
index 35f5400..1ba27cf 100644 (file)
@@ -231,17 +231,22 @@ export let traceImports : Array<[string, string, Function]> | undefined;
 export let _wrap_trace_function: Function;
 
 const mathOps1 = [
+    "asin",
     "acos",
+    "atan",
     "cos",
+    "exp",
+    "log",
+    "log2",
+    "log10",
     "sin",
-    "asin",
     "tan",
-    "atan"
 ];
 
 const mathOps2 = [
     "rem",
     "atan2",
+    "pow",
 ];
 
 function getTraceImports () {
@@ -278,6 +283,7 @@ function getTraceImports () {
         importDef("cmpxchg_i32", getRawCwrap("mono_jiterp_cas_i32")),
         importDef("cmpxchg_i64", getRawCwrap("mono_jiterp_cas_i64")),
         importDef("stelem_ref", getRawCwrap("mono_jiterp_stelem_ref")),
+        importDef("fma", getRawCwrap("mono_jiterp_math_fma")),
     ];
 
     if (instrumentedMethodNames.length > 0) {
@@ -412,6 +418,13 @@ function initialize_builder (builder: WasmBuilder) {
         }, WasmValtype.f64, true
     );
     builder.defineType(
+        "fma", {
+            "x": WasmValtype.f64,
+            "y": WasmValtype.f64,
+            "z": WasmValtype.f64,
+        }, WasmValtype.f64, true
+    );
+    builder.defineType(
         "trace_eip", {
             "traceId": WasmValtype.i32,
             "eip": WasmValtype.i32,