From: Katelyn Gadd Date: Wed, 12 Jul 2023 07:15:00 +0000 (-0700) Subject: [wasm] Optimize Vector128/.Equals in interp/jiterp (#88064) X-Git-Tag: accepted/tizen/unified/riscv/20231226.055536~1123 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d59af2cf097acb100ad6a9cba652be34be6f4a2e;p=platform%2Fupstream%2Fdotnet%2Fruntime.git [wasm] Optimize Vector128/.Equals in interp/jiterp (#88064) * Add browser-bench measurement for int32 and float equals * Add interp intrinsics for Vector128 float and double Equals methods * Implement Vector128 float and double Equals methods in jiterp * Add jiterp validation to make sure we never appendSimd(0) by accident --- diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Vector128_1.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Vector128_1.cs index f07e898..576dd31 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Vector128_1.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Vector128_1.cs @@ -388,6 +388,16 @@ namespace System.Runtime.Intrinsics [MethodImpl(MethodImplOptions.AggressiveInlining)] public override bool Equals([NotNullWhen(true)] object? obj) => (obj is Vector128 other) && Equals(other); + // Account for floating-point equality around NaN + // This is in a separate method so it can be optimized by the mono interpreter/jiterpreter + [Intrinsic] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static bool EqualsFloatingPoint (Vector128 lhs, Vector128 rhs) + { + Vector128 result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs)); + return result.AsInt32() == Vector128.AllBitsSet; + } + /// Determines whether the specified is equal to the current instance. /// The to compare with the current instance. /// true if is equal to the current instance; otherwise, false. @@ -401,8 +411,7 @@ namespace System.Runtime.Intrinsics { if ((typeof(T) == typeof(double)) || (typeof(T) == typeof(float))) { - Vector128 result = Vector128.Equals(this, other) | ~(Vector128.Equals(this, this) | Vector128.Equals(other, other)); - return result.AsInt32() == Vector128.AllBitsSet; + return EqualsFloatingPoint(this, other); } else { diff --git a/src/mono/mono/mini/interp/interp-simd-intrins.def b/src/mono/mono/mini/interp/interp-simd-intrins.def index c593c7e..33b7449 100644 --- a/src/mono/mono/mini/interp/interp-simd-intrins.def +++ b/src/mono/mono/mini/interp/interp-simd-intrins.def @@ -58,6 +58,9 @@ INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_OR, interp_v128_o INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY, interp_v128_op_bitwise_equality, -1) INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_INEQUALITY, interp_v128_op_bitwise_inequality, -1) +INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY, interp_v128_r4_float_equality, -1) +INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY, interp_v128_r8_float_equality, -1) + INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR, interp_v128_op_exclusive_or, 81) INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_I1_MULTIPLY, interp_v128_i1_op_multiply, -1) diff --git a/src/mono/mono/mini/interp/interp-simd.c b/src/mono/mono/mini/interp/interp-simd.c index 2076680..65e60b4 100644 --- a/src/mono/mono/mini/interp/interp-simd.c +++ b/src/mono/mono/mini/interp/interp-simd.c @@ -17,6 +17,7 @@ typedef guint16 v128_u2 __attribute__ ((vector_size (SIZEOF_V128))); typedef gint8 v128_i1 __attribute__ ((vector_size (SIZEOF_V128))); typedef guint8 v128_u1 __attribute__ ((vector_size (SIZEOF_V128))); typedef float v128_r4 __attribute__ ((vector_size (SIZEOF_V128))); +typedef double v128_r8 __attribute__ ((vector_size (SIZEOF_V128))); // get_AllBitsSet static void @@ -122,7 +123,30 @@ interp_v128_op_bitwise_inequality (gpointer res, gpointer v1, gpointer v2) *(gint32*)res = 1; } -// op_Addition +// Vector128EqualsFloatingPoint +static void +interp_v128_r4_float_equality (gpointer res, gpointer v1, gpointer v2) +{ + v128_r4 v1_cast = *(v128_r4*)v1; + v128_r4 v2_cast = *(v128_r4*)v2; + v128_r4 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast)); + memset (&v1_cast, 0xff, SIZEOF_V128); + + *(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0; +} + +static void +interp_v128_r8_float_equality (gpointer res, gpointer v1, gpointer v2) +{ + v128_r8 v1_cast = *(v128_r8*)v1; + v128_r8 v2_cast = *(v128_r8*)v2; + v128_r8 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast)); + memset (&v1_cast, 0xff, SIZEOF_V128); + + *(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0; +} + +// op_Multiply static void interp_v128_i1_op_multiply (gpointer res, gpointer v1, gpointer v2) { @@ -147,6 +171,7 @@ interp_v128_r4_op_multiply (gpointer res, gpointer v1, gpointer v2) *(v128_r4*)res = *(v128_r4*)v1 * *(v128_r4*)v2; } +// op_Division static void interp_v128_r4_op_division (gpointer res, gpointer v1, gpointer v2) { diff --git a/src/mono/mono/mini/interp/simd-methods.def b/src/mono/mono/mini/interp/simd-methods.def index f89785c..8c6fdc3 100644 --- a/src/mono/mono/mini/interp/simd-methods.def +++ b/src/mono/mono/mini/interp/simd-methods.def @@ -29,6 +29,7 @@ SIMD_METHOD(CreateScalar) SIMD_METHOD(CreateScalarUnsafe) SIMD_METHOD(Equals) +SIMD_METHOD(EqualsFloatingPoint) SIMD_METHOD(ExtractMostSignificantBits) SIMD_METHOD(GreaterThan) SIMD_METHOD(LessThan) diff --git a/src/mono/mono/mini/interp/transform-simd.c b/src/mono/mono/mini/interp/transform-simd.c index 41e4940..255a2ab 100644 --- a/src/mono/mono/mini/interp/transform-simd.c +++ b/src/mono/mono/mini/interp/transform-simd.c @@ -75,6 +75,7 @@ static guint16 sri_vector128_methods [] = { }; static guint16 sri_vector128_t_methods [] = { + SN_EqualsFloatingPoint, SN_get_AllBitsSet, SN_get_Count, SN_get_One, @@ -196,6 +197,13 @@ emit_common_simd_operations (TransformData *td, int id, int atype, int vector_si *simd_intrins = INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY; } break; + case SN_EqualsFloatingPoint: + *simd_opcode = MINT_SIMD_INTRINS_P_PP; + if (atype == MONO_TYPE_R4) + *simd_intrins = INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY; + else if (atype == MONO_TYPE_R8) + *simd_intrins = INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY; + break; case SN_op_ExclusiveOr: *simd_opcode = MINT_SIMD_INTRINS_P_PP; *simd_intrins = INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR; diff --git a/src/mono/sample/wasm/browser-bench/Vector.cs b/src/mono/sample/wasm/browser-bench/Vector.cs index cb04d36..c0ec294 100644 --- a/src/mono/sample/wasm/browser-bench/Vector.cs +++ b/src/mono/sample/wasm/browser-bench/Vector.cs @@ -32,6 +32,8 @@ namespace Sample new MinDouble(), new MaxDouble(), new Normalize(), + new EqualsInt32(), + new EqualsFloat(), }; } @@ -344,5 +346,41 @@ namespace Sample result = vector / (float)Math.Sqrt(Vector128.Dot(vector, vector)); } } + + class EqualsInt32 : VectorMeasurement + { + Vector128 vector1, vector2; + bool result; + + public override string Name => "Equals Int32"; + + public EqualsInt32() + { + vector1 = Vector128.Create(1, 2, 3, 4); + vector2 = Vector128.Create(4, 3, 2, 1); + } + + public override void RunStep() { + result = vector1.Equals(vector2); + } + } + + class EqualsFloat : VectorMeasurement + { + Vector128 vector1, vector2; + bool result; + + public override string Name => "Equals Float"; + + public EqualsFloat() + { + vector1 = Vector128.Create(1f, 2f, 3f, 4f); + vector2 = Vector128.Create(4f, 3f, 2f, 1f); + } + + public override void RunStep() { + result = vector1.Equals(vector2); + } + } } } diff --git a/src/mono/wasm/runtime/jiterpreter-support.ts b/src/mono/wasm/runtime/jiterpreter-support.ts index 0b86607..aec7214 100644 --- a/src/mono/wasm/runtime/jiterpreter-support.ts +++ b/src/mono/wasm/runtime/jiterpreter-support.ts @@ -297,9 +297,10 @@ export class WasmBuilder { return this.current.appendU8(value); } - appendSimd(value: WasmSimdOpcode) { + appendSimd(value: WasmSimdOpcode, allowLoad?: boolean) { this.current.appendU8(WasmOpcode.PREFIX_simd); // Yes that's right. We're using LEB128 to encode 8-bit opcodes. Why? I don't know + mono_assert(((value | 0) !== 0) || ((value === WasmSimdOpcode.v128_load) && (allowLoad === true)), "Expected non-v128_load simd opcode or allowLoad==true"); return this.current.appendULeb(value); } @@ -993,6 +994,7 @@ export class BlobBuilder { } appendULeb(value: number) { + mono_assert(typeof (value) === "number", () => `appendULeb expected number but got ${value}`); mono_assert(value >= 0, "cannot pass negative value to appendULeb"); if (value < 0x7F) { if (this.size + 1 >= this.capacity) @@ -1013,6 +1015,7 @@ export class BlobBuilder { } appendLeb(value: number) { + mono_assert(typeof (value) === "number", () => `appendLeb expected number but got ${value}`); if (this.size + 8 >= this.capacity) throw new Error("Buffer full"); @@ -1721,7 +1724,7 @@ export function try_append_memmove_fast( while (count >= sizeofV128) { builder.local(destLocal); builder.local(srcLocal); - builder.appendSimd(WasmSimdOpcode.v128_load); + builder.appendSimd(WasmSimdOpcode.v128_load, true); builder.appendMemarg(srcOffset, 0); builder.appendSimd(WasmSimdOpcode.v128_store); builder.appendMemarg(destOffset, 0); diff --git a/src/mono/wasm/runtime/jiterpreter-trace-generator.ts b/src/mono/wasm/runtime/jiterpreter-trace-generator.ts index f795858..4b42c47 100644 --- a/src/mono/wasm/runtime/jiterpreter-trace-generator.ts +++ b/src/mono/wasm/runtime/jiterpreter-trace-generator.ts @@ -1789,6 +1789,8 @@ function append_ldloc(builder: WasmBuilder, offset: number, opcodeOrPrefix: Wasm if (simdOpcode !== undefined) { // This looks wrong but I assure you it's correct. builder.appendULeb(simdOpcode); + } else if (opcodeOrPrefix === WasmOpcode.PREFIX_simd) { + throw new Error("PREFIX_simd ldloc without a simdOpcode"); } const alignment = computeMemoryAlignment(offset, opcodeOrPrefix, simdOpcode); builder.appendMemarg(offset, alignment); @@ -3493,7 +3495,7 @@ function emit_simd_2(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins // Indirect load, so v1 is T** and res is Vector128* builder.local("pLocals"); append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load); - builder.appendSimd(simple); + builder.appendSimd(simple, true); builder.appendMemarg(0, 0); append_simd_store(builder, ip); } else { @@ -3609,6 +3611,33 @@ function emit_simd_3(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins builder.appendU8(WasmOpcode.i32_eqz); append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store); return true; + case SimdIntrinsic3.V128_R4_FLOAT_EQUALITY: + case SimdIntrinsic3.V128_R8_FLOAT_EQUALITY: { + /* + Vector128 result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs)); + return result.AsInt32() == Vector128.AllBitsSet; + */ + const isR8 = index === SimdIntrinsic3.V128_R8_FLOAT_EQUALITY, + eqOpcode = isR8 ? WasmSimdOpcode.f64x2_eq : WasmSimdOpcode.f32x4_eq; + builder.local("pLocals"); + append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load); + builder.local("math_lhs128", WasmOpcode.tee_local); + append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load); + builder.local("math_rhs128", WasmOpcode.tee_local); + builder.appendSimd(eqOpcode); + builder.local("math_lhs128"); + builder.local("math_lhs128"); + builder.appendSimd(eqOpcode); + builder.local("math_rhs128"); + builder.local("math_rhs128"); + builder.appendSimd(eqOpcode); + builder.appendSimd(WasmSimdOpcode.v128_or); + builder.appendSimd(WasmSimdOpcode.v128_not); + builder.appendSimd(WasmSimdOpcode.v128_or); + builder.appendSimd(isR8 ? WasmSimdOpcode.i64x2_all_true : WasmSimdOpcode.i32x4_all_true); + append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store); + return true; + } case SimdIntrinsic3.V128_I1_SHUFFLE: { // Detect a constant indices vector and turn it into a const. This allows // v8 to use a more optimized implementation of the swizzle opcode diff --git a/src/mono/wasm/runtime/jiterpreter.ts b/src/mono/wasm/runtime/jiterpreter.ts index dfe1d41..f04f0b8 100644 --- a/src/mono/wasm/runtime/jiterpreter.ts +++ b/src/mono/wasm/runtime/jiterpreter.ts @@ -795,8 +795,11 @@ function generate_wasm( "temp_f64": WasmValtype.f64, "backbranched": WasmValtype.i32, }; - if (builder.options.enableSimd) + if (builder.options.enableSimd) { traceLocals["v128_zero"] = WasmValtype.v128; + traceLocals["math_lhs128"] = WasmValtype.v128; + traceLocals["math_rhs128"] = WasmValtype.v128; + } let keep = true, traceValue = 0;