[wasm] Optimize Vector128<float>/<double>.Equals in interp/jiterp (#88064)
authorKatelyn Gadd <kg@luminance.org>
Wed, 12 Jul 2023 07:15:00 +0000 (00:15 -0700)
committerGitHub <noreply@github.com>
Wed, 12 Jul 2023 07:15:00 +0000 (00:15 -0700)
* Add browser-bench measurement for int32 and float equals
* Add interp intrinsics for Vector128 float and double Equals methods
* Implement Vector128 float and double Equals methods in jiterp
* Add jiterp validation to make sure we never appendSimd(0) by accident

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Vector128_1.cs
src/mono/mono/mini/interp/interp-simd-intrins.def
src/mono/mono/mini/interp/interp-simd.c
src/mono/mono/mini/interp/simd-methods.def
src/mono/mono/mini/interp/transform-simd.c
src/mono/sample/wasm/browser-bench/Vector.cs
src/mono/wasm/runtime/jiterpreter-support.ts
src/mono/wasm/runtime/jiterpreter-trace-generator.ts
src/mono/wasm/runtime/jiterpreter.ts

index f07e898..576dd31 100644 (file)
@@ -388,6 +388,16 @@ namespace System.Runtime.Intrinsics
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public override bool Equals([NotNullWhen(true)] object? obj) => (obj is Vector128<T> other) && Equals(other);
 
+        // Account for floating-point equality around NaN
+        // This is in a separate method so it can be optimized by the mono interpreter/jiterpreter
+        [Intrinsic]
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        internal static bool EqualsFloatingPoint (Vector128<T> lhs, Vector128<T> rhs)
+        {
+            Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
+            return result.AsInt32() == Vector128<int>.AllBitsSet;
+        }
+
         /// <summary>Determines whether the specified <see cref="Vector128{T}" /> is equal to the current instance.</summary>
         /// <param name="other">The <see cref="Vector128{T}" /> to compare with the current instance.</param>
         /// <returns><c>true</c> if <paramref name="other" /> is equal to the current instance; otherwise, <c>false</c>.</returns>
@@ -401,8 +411,7 @@ namespace System.Runtime.Intrinsics
             {
                 if ((typeof(T) == typeof(double)) || (typeof(T) == typeof(float)))
                 {
-                    Vector128<T> result = Vector128.Equals(this, other) | ~(Vector128.Equals(this, this) | Vector128.Equals(other, other));
-                    return result.AsInt32() == Vector128<int>.AllBitsSet;
+                    return EqualsFloatingPoint(this, other);
                 }
                 else
                 {
index c593c7e..33b7449 100644 (file)
@@ -58,6 +58,9 @@ INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_OR, interp_v128_o
 INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY, interp_v128_op_bitwise_equality, -1)
 INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_INEQUALITY, interp_v128_op_bitwise_inequality, -1)
 
+INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY, interp_v128_r4_float_equality, -1)
+INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY, interp_v128_r8_float_equality, -1)
+
 INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR, interp_v128_op_exclusive_or, 81)
 
 INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_I1_MULTIPLY, interp_v128_i1_op_multiply, -1)
index 2076680..65e60b4 100644 (file)
@@ -17,6 +17,7 @@ typedef guint16 v128_u2 __attribute__ ((vector_size (SIZEOF_V128)));
 typedef gint8 v128_i1 __attribute__ ((vector_size (SIZEOF_V128)));
 typedef guint8 v128_u1 __attribute__ ((vector_size (SIZEOF_V128)));
 typedef float v128_r4 __attribute__ ((vector_size (SIZEOF_V128)));
+typedef double v128_r8 __attribute__ ((vector_size (SIZEOF_V128)));
 
 // get_AllBitsSet
 static void
@@ -122,7 +123,30 @@ interp_v128_op_bitwise_inequality (gpointer res, gpointer v1, gpointer v2)
                *(gint32*)res = 1;
 }
 
-// op_Addition
+// Vector128<float>EqualsFloatingPoint
+static void
+interp_v128_r4_float_equality (gpointer res, gpointer v1, gpointer v2)
+{
+       v128_r4 v1_cast = *(v128_r4*)v1;
+       v128_r4 v2_cast = *(v128_r4*)v2;
+       v128_r4 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
+       memset (&v1_cast, 0xff, SIZEOF_V128);
+
+       *(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0;
+}
+
+static void
+interp_v128_r8_float_equality (gpointer res, gpointer v1, gpointer v2)
+{
+       v128_r8 v1_cast = *(v128_r8*)v1;
+       v128_r8 v2_cast = *(v128_r8*)v2;
+       v128_r8 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
+       memset (&v1_cast, 0xff, SIZEOF_V128);
+
+       *(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0;
+}
+
+// op_Multiply
 static void
 interp_v128_i1_op_multiply (gpointer res, gpointer v1, gpointer v2)
 {
@@ -147,6 +171,7 @@ interp_v128_r4_op_multiply (gpointer res, gpointer v1, gpointer v2)
        *(v128_r4*)res = *(v128_r4*)v1 * *(v128_r4*)v2;
 }
 
+// op_Division
 static void
 interp_v128_r4_op_division (gpointer res, gpointer v1, gpointer v2)
 {
index f89785c..8c6fdc3 100644 (file)
@@ -29,6 +29,7 @@ SIMD_METHOD(CreateScalar)
 SIMD_METHOD(CreateScalarUnsafe)
 
 SIMD_METHOD(Equals)
+SIMD_METHOD(EqualsFloatingPoint)
 SIMD_METHOD(ExtractMostSignificantBits)
 SIMD_METHOD(GreaterThan)
 SIMD_METHOD(LessThan)
index 41e4940..255a2ab 100644 (file)
@@ -75,6 +75,7 @@ static guint16 sri_vector128_methods [] = {
 };
 
 static guint16 sri_vector128_t_methods [] = {
+       SN_EqualsFloatingPoint,
        SN_get_AllBitsSet,
        SN_get_Count,
        SN_get_One,
@@ -196,6 +197,13 @@ emit_common_simd_operations (TransformData *td, int id, int atype, int vector_si
                                *simd_intrins = INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY;
                        }
                        break;
+               case SN_EqualsFloatingPoint:
+                       *simd_opcode = MINT_SIMD_INTRINS_P_PP;
+                       if (atype == MONO_TYPE_R4)
+                               *simd_intrins = INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY;
+                       else if (atype == MONO_TYPE_R8)
+                               *simd_intrins = INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY;
+                       break;
                case SN_op_ExclusiveOr:
                        *simd_opcode = MINT_SIMD_INTRINS_P_PP;
                        *simd_intrins = INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR;
index cb04d36..c0ec294 100644 (file)
@@ -32,6 +32,8 @@ namespace Sample
                 new MinDouble(),
                 new MaxDouble(),
                 new Normalize(),
+                new EqualsInt32(),
+                new EqualsFloat(),
             };
         }
 
@@ -344,5 +346,41 @@ namespace Sample
                 result = vector / (float)Math.Sqrt(Vector128.Dot(vector, vector));
             }
         }
+
+        class EqualsInt32 : VectorMeasurement
+        {
+            Vector128<Int32> vector1, vector2;
+            bool result;
+
+            public override string Name => "Equals Int32";
+
+            public EqualsInt32()
+            {
+                vector1 = Vector128.Create(1, 2, 3, 4);
+                vector2 = Vector128.Create(4, 3, 2, 1);
+            }
+
+            public override void RunStep() {
+                result = vector1.Equals(vector2);
+            }
+        }
+
+        class EqualsFloat : VectorMeasurement
+        {
+            Vector128<float> vector1, vector2;
+            bool result;
+
+            public override string Name => "Equals Float";
+
+            public EqualsFloat()
+            {
+                vector1 = Vector128.Create(1f, 2f, 3f, 4f);
+                vector2 = Vector128.Create(4f, 3f, 2f, 1f);
+            }
+
+            public override void RunStep() {
+                result = vector1.Equals(vector2);
+            }
+        }
     }
 }
index 0b86607..aec7214 100644 (file)
@@ -297,9 +297,10 @@ export class WasmBuilder {
         return this.current.appendU8(value);
     }
 
-    appendSimd(value: WasmSimdOpcode) {
+    appendSimd(value: WasmSimdOpcode, allowLoad?: boolean) {
         this.current.appendU8(WasmOpcode.PREFIX_simd);
         // Yes that's right. We're using LEB128 to encode 8-bit opcodes. Why? I don't know
+        mono_assert(((value | 0) !== 0) || ((value === WasmSimdOpcode.v128_load) && (allowLoad === true)), "Expected non-v128_load simd opcode or allowLoad==true");
         return this.current.appendULeb(value);
     }
 
@@ -993,6 +994,7 @@ export class BlobBuilder {
     }
 
     appendULeb(value: number) {
+        mono_assert(typeof (value) === "number", () => `appendULeb expected number but got ${value}`);
         mono_assert(value >= 0, "cannot pass negative value to appendULeb");
         if (value < 0x7F) {
             if (this.size + 1 >= this.capacity)
@@ -1013,6 +1015,7 @@ export class BlobBuilder {
     }
 
     appendLeb(value: number) {
+        mono_assert(typeof (value) === "number", () => `appendLeb expected number but got ${value}`);
         if (this.size + 8 >= this.capacity)
             throw new Error("Buffer full");
 
@@ -1721,7 +1724,7 @@ export function try_append_memmove_fast(
         while (count >= sizeofV128) {
             builder.local(destLocal);
             builder.local(srcLocal);
-            builder.appendSimd(WasmSimdOpcode.v128_load);
+            builder.appendSimd(WasmSimdOpcode.v128_load, true);
             builder.appendMemarg(srcOffset, 0);
             builder.appendSimd(WasmSimdOpcode.v128_store);
             builder.appendMemarg(destOffset, 0);
index f795858..4b42c47 100644 (file)
@@ -1789,6 +1789,8 @@ function append_ldloc(builder: WasmBuilder, offset: number, opcodeOrPrefix: Wasm
     if (simdOpcode !== undefined) {
         // This looks wrong but I assure you it's correct.
         builder.appendULeb(simdOpcode);
+    } else if (opcodeOrPrefix === WasmOpcode.PREFIX_simd) {
+        throw new Error("PREFIX_simd ldloc without a simdOpcode");
     }
     const alignment = computeMemoryAlignment(offset, opcodeOrPrefix, simdOpcode);
     builder.appendMemarg(offset, alignment);
@@ -3493,7 +3495,7 @@ function emit_simd_2(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
             // Indirect load, so v1 is T** and res is Vector128*
             builder.local("pLocals");
             append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load);
-            builder.appendSimd(simple);
+            builder.appendSimd(simple, true);
             builder.appendMemarg(0, 0);
             append_simd_store(builder, ip);
         } else {
@@ -3609,6 +3611,33 @@ function emit_simd_3(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
                 builder.appendU8(WasmOpcode.i32_eqz);
             append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
             return true;
+        case SimdIntrinsic3.V128_R4_FLOAT_EQUALITY:
+        case SimdIntrinsic3.V128_R8_FLOAT_EQUALITY: {
+            /*
+            Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
+            return result.AsInt32() == Vector128<int>.AllBitsSet;
+            */
+            const isR8 = index === SimdIntrinsic3.V128_R8_FLOAT_EQUALITY,
+                eqOpcode = isR8 ? WasmSimdOpcode.f64x2_eq : WasmSimdOpcode.f32x4_eq;
+            builder.local("pLocals");
+            append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
+            builder.local("math_lhs128", WasmOpcode.tee_local);
+            append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
+            builder.local("math_rhs128", WasmOpcode.tee_local);
+            builder.appendSimd(eqOpcode);
+            builder.local("math_lhs128");
+            builder.local("math_lhs128");
+            builder.appendSimd(eqOpcode);
+            builder.local("math_rhs128");
+            builder.local("math_rhs128");
+            builder.appendSimd(eqOpcode);
+            builder.appendSimd(WasmSimdOpcode.v128_or);
+            builder.appendSimd(WasmSimdOpcode.v128_not);
+            builder.appendSimd(WasmSimdOpcode.v128_or);
+            builder.appendSimd(isR8 ? WasmSimdOpcode.i64x2_all_true : WasmSimdOpcode.i32x4_all_true);
+            append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
+            return true;
+        }
         case SimdIntrinsic3.V128_I1_SHUFFLE: {
             // Detect a constant indices vector and turn it into a const. This allows
             //  v8 to use a more optimized implementation of the swizzle opcode
index dfe1d41..f04f0b8 100644 (file)
@@ -795,8 +795,11 @@ function generate_wasm(
             "temp_f64": WasmValtype.f64,
             "backbranched": WasmValtype.i32,
         };
-        if (builder.options.enableSimd)
+        if (builder.options.enableSimd) {
             traceLocals["v128_zero"] = WasmValtype.v128;
+            traceLocals["math_lhs128"] = WasmValtype.v128;
+            traceLocals["math_rhs128"] = WasmValtype.v128;
+        }
 
         let keep = true,
             traceValue = 0;