From afc2f118e7177107655c5f7bf5d02df9953a2742 Mon Sep 17 00:00:00 2001 From: Simon Rozsival Date: Fri, 25 Feb 2022 10:52:08 +0100 Subject: [PATCH] [Mono] Add SIMD intrinsic for Vector64/128 comparisons (#65128) * Add vector comparison intrinsics * Add EqualsAll and EqualsAny intrinsics * Remove broken EqualsAny * Fix EqualsAny * Enable xequal also for arm64 * Fix xzero type * Fix bad merge * Add guards for invalid types * Revert unrelated change * Extract duplicate code blocks to a new function * Fix EqualsAny * Fix typo + code improvements --- src/mono/mono/mini/mini-llvm.c | 149 ++++++++++++++++++----------------- src/mono/mono/mini/simd-intrinsics.c | 112 +++++++++++++++++++++++--- src/mono/mono/mini/simd-methods.h | 2 + 3 files changed, 178 insertions(+), 85 deletions(-) diff --git a/src/mono/mono/mini/mini-llvm.c b/src/mono/mono/mini/mini-llvm.c index 9486592..0a85c0c 100644 --- a/src/mono/mono/mini/mini-llvm.c +++ b/src/mono/mono/mini/mini-llvm.c @@ -9465,79 +9465,6 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb) values [ins->dreg] = LLVMBuildSExt (builder, cmp, LLVMTypeOf (lhs), ""); break; } - case OP_XEQUAL: { - LLVMTypeRef t; - LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle; - int nelems; - -#if defined(TARGET_WASM) - /* The wasm code generator doesn't understand the shuffle/and code sequence below */ - LLVMValueRef val; - if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) { - val = LLVMIsNull (lhs) ? 
rhs : lhs; - nelems = LLVMGetVectorSize (LLVMTypeOf (lhs)); - - IntrinsicId intrins = (IntrinsicId)0; - switch (nelems) { - case 16: - intrins = INTRINS_WASM_ANYTRUE_V16; - break; - case 8: - intrins = INTRINS_WASM_ANYTRUE_V8; - break; - case 4: - intrins = INTRINS_WASM_ANYTRUE_V4; - break; - case 2: - intrins = INTRINS_WASM_ANYTRUE_V2; - break; - default: - g_assert_not_reached (); - } - /* res = !wasm.anytrue (val) */ - values [ins->dreg] = call_intrins (ctx, intrins, &val, ""); - values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname); - break; - } -#endif - LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs)); - - //%c = icmp sgt <16 x i8> %a0, %a1 - if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ()) - cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, ""); - else - cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, ""); - nelems = LLVMGetVectorSize (LLVMTypeOf (cmp)); - - LLVMTypeRef elemt; - if (srcelemt == LLVMDoubleType ()) - elemt = LLVMInt64Type (); - else if (srcelemt == LLVMFloatType ()) - elemt = LLVMInt32Type (); - else - elemt = srcelemt; - - t = LLVMVectorType (elemt, nelems); - cmp = LLVMBuildSExt (builder, cmp, t, ""); - // cmp is a vector, each element is either 0xff... 
or 0 - int half = nelems / 2; - while (half >= 1) { - // AND the top and bottom halfes into the bottom half - for (int i = 0; i < half; ++i) - mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE); - for (int i = half; i < nelems; ++i) - mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE); - shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), ""); - cmp = LLVMBuildAnd (builder, cmp, shuffle, ""); - half = half / 2; - } - // Extract [0] - LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""); - // convert to 0/1 - LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), ""); - values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), ""); - break; - } case OP_POPCNT32: values [ins->dreg] = call_intrins (ctx, INTRINS_CTPOP_I32, &lhs, ""); break; @@ -9616,6 +9543,82 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb) } #endif +#if defined(TARGET_ARM64) || defined(TARGET_X86) || defined(TARGET_AMD64) || defined(TARGET_WASM) + case OP_XEQUAL: { + LLVMTypeRef t; + LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle; + int nelems; + +#if defined(TARGET_WASM) + /* The wasm code generator doesn't understand the shuffle/and code sequence below */ + LLVMValueRef val; + if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) { + val = LLVMIsNull (lhs) ? 
rhs : lhs; + nelems = LLVMGetVectorSize (LLVMTypeOf (lhs)); + + IntrinsicId intrins = (IntrinsicId)0; + switch (nelems) { + case 16: + intrins = INTRINS_WASM_ANYTRUE_V16; + break; + case 8: + intrins = INTRINS_WASM_ANYTRUE_V8; + break; + case 4: + intrins = INTRINS_WASM_ANYTRUE_V4; + break; + case 2: + intrins = INTRINS_WASM_ANYTRUE_V2; + break; + default: + g_assert_not_reached (); + } + /* res = !wasm.anytrue (val) */ + values [ins->dreg] = call_intrins (ctx, intrins, &val, ""); + values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname); + break; + } +#endif + LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs)); + + //%c = icmp eq <16 x i8> %a0, %a1 + if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ()) + cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, ""); + else + cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, ""); + nelems = LLVMGetVectorSize (LLVMTypeOf (cmp)); + + LLVMTypeRef elemt; + if (srcelemt == LLVMDoubleType ()) + elemt = LLVMInt64Type (); + else if (srcelemt == LLVMFloatType ()) + elemt = LLVMInt32Type (); + else + elemt = srcelemt; + + t = LLVMVectorType (elemt, nelems); + cmp = LLVMBuildSExt (builder, cmp, t, ""); + // cmp is a vector, each element is either 0xff... 
or 0 + int half = nelems / 2; + while (half >= 1) { + // AND the top and bottom halves into the bottom half + for (int i = 0; i < half; ++i) + mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE); + for (int i = half; i < nelems; ++i) + mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE); + shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), ""); + cmp = LLVMBuildAnd (builder, cmp, shuffle, ""); + half = half / 2; + } + // Extract [0] + LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""); + // convert to 0/1 + LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), ""); + values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), ""); + break; + } +#endif + #if defined(TARGET_ARM64) case OP_XOP_I4_I4: diff --git a/src/mono/mono/mini/simd-intrinsics.c b/src/mono/mono/mini/simd-intrinsics.c index 4bc0199..39342e8 100644 --- a/src/mono/mono/mini/simd-intrinsics.c +++ b/src/mono/mono/mini/simd-intrinsics.c @@ -260,6 +260,29 @@ emit_xcompare (MonoCompile *cfg, MonoClass *klass, MonoTypeEnum etype, MonoInst return ins; } +static MonoInst* +emit_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2) +{ + return emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg); +} + +static MonoInst* +emit_not_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2) +{ + MonoInst *ins = emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg); + int sreg = ins->dreg; + int dreg = alloc_ireg (cfg); + MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0); + EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1); + return ins; +} + +static MonoInst* +emit_xzero (MonoCompile *cfg, MonoClass *klass) +{ + return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1); +} + static gboolean is_intrinsics_vector_type (MonoType *vector_type) { @@ -492,7 +515,7 @@ 
emit_vector_create_elementwise ( { int op = type_to_insert_op (etype); MonoClass *vklass = mono_class_from_mono_type_internal (vtype); - MonoInst *ins = emit_simd_ins (cfg, vklass, OP_XZERO, -1, -1); + MonoInst *ins = emit_xzero (cfg, vklass); for (int i = 0; i < fsig->param_count; ++i) { ins = emit_simd_ins (cfg, vklass, op, ins->dreg, args [i]->dreg); ins->inst_c0 = i; @@ -590,10 +613,17 @@ static guint16 sri_vector_methods [] = { SN_CreateScalar, SN_CreateScalarUnsafe, SN_Divide, + SN_Equals, + SN_EqualsAll, + SN_EqualsAny, SN_Floor, SN_GetElement, SN_GetLower, SN_GetUpper, + SN_GreaterThan, + SN_GreaterThanOrEqual, + SN_LessThan, + SN_LessThanOrEqual, SN_Max, SN_Min, SN_Multiply, @@ -788,6 +818,27 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR, -1, arg0_type, fsig, args); case SN_CreateScalarUnsafe: return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR_UNSAFE, -1, arg0_type, fsig, args); + case SN_Equals: + case SN_EqualsAll: + case SN_EqualsAny: { + MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]); + if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type)) + return NULL; + + switch (id) { + case SN_Equals: + return emit_xcompare (cfg, klass, arg0_type, args [0], args [1]); + case SN_EqualsAll: + return emit_xequal (cfg, klass, args [0], args [1]); + case SN_EqualsAny: { + MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]); + MonoInst *cmp_eq = emit_xcompare (cfg, arg_class, arg0_type, args [0], args [1]); + MonoInst *zero = emit_xzero (cfg, arg_class); + return emit_not_xequal (cfg, arg_class, cmp_eq, zero); + } + default: g_assert_not_reached (); + } + } case SN_GetElement: { MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]); MonoType *etype = mono_class_get_context (arg_class)->class_inst->type_argv [0]; @@ -809,6 +860,34 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, 
MonoMethodSignature *fsi int op = id == SN_GetLower ? OP_XLOWER : OP_XUPPER; return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args); } + case SN_GreaterThan: + case SN_GreaterThanOrEqual: + case SN_LessThan: + case SN_LessThanOrEqual: { + MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]); + if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type)) + return NULL; + + gboolean is_unsigned = type_is_unsigned (fsig->params [0]); + MonoInst *ins = emit_xcompare (cfg, klass, arg0_type, args [0], args [1]); + switch (id) { + case SN_GreaterThan: + ins->inst_c0 = is_unsigned ? CMP_GT_UN : CMP_GT; + break; + case SN_GreaterThanOrEqual: + ins->inst_c0 = is_unsigned ? CMP_GE_UN : CMP_GE; + break; + case SN_LessThan: + ins->inst_c0 = is_unsigned ? CMP_LT_UN : CMP_LT; + break; + case SN_LessThanOrEqual: + ins->inst_c0 = is_unsigned ? CMP_LE_UN : CMP_LE; + break; + default: + g_assert_not_reached (); + } + return ins; + } case SN_Negate: case SN_OnesComplement: { #ifdef TARGET_ARM64 @@ -879,6 +958,8 @@ static guint16 vector64_vector128_t_methods [] = { SN_get_Count, SN_get_IsSupported, SN_get_Zero, + SN_op_Equality, + SN_op_Inequality, }; static MonoInst* @@ -928,10 +1009,10 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign return ins; } case SN_get_Zero: { - return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1); + return emit_xzero (cfg, klass); } case SN_get_AllBitsSet: { - MonoInst *ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1); + MonoInst *ins = emit_xzero (cfg, klass); return emit_xcompare (cfg, klass, etype->type, ins, ins); } case SN_Equals: { @@ -941,6 +1022,16 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign } break; } + case SN_op_Equality: + case SN_op_Inequality: + g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN && + mono_metadata_type_equal (fsig->params [0], type) && + mono_metadata_type_equal (fsig->params [1], type)); + switch (id) { + 
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]); + case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]); + default: g_assert_not_reached (); + } default: break; } @@ -1086,7 +1177,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig return ins; case SN_get_Zero: g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type)); - return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1); + return emit_xzero (cfg, klass); case SN_get_One: { g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type)); MonoInst *one = NULL; @@ -1115,7 +1206,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig } case SN_get_AllBitsSet: { /* Compare a zero vector with itself */ - ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1); + ins = emit_xzero (cfg, klass); return emit_xcompare (cfg, klass, etype->type, ins, ins); } case SN_get_Item: { @@ -1222,14 +1313,11 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type)); - ins = emit_simd_ins (cfg, klass, OP_XEQUAL, args [0]->dreg, args [1]->dreg); - if (id == SN_op_Inequality) { - int sreg = ins->dreg; - int dreg = alloc_ireg (cfg); - MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0); - EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1); + switch (id) { + case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]); + case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]); + default: g_assert_not_reached (); } - return ins; case SN_GreaterThan: case SN_GreaterThanOrEqual: case SN_LessThan: diff --git a/src/mono/mono/mini/simd-methods.h b/src/mono/mono/mini/simd-methods.h index b33a383..161d9f8 100644 --- a/src/mono/mono/mini/simd-methods.h +++ 
b/src/mono/mono/mini/simd-methods.h @@ -62,6 +62,8 @@ METHOD(Create) METHOD(CreateScalar) METHOD(CreateScalarUnsafe) METHOD(ConditionalSelect) +METHOD(EqualsAll) +METHOD(EqualsAny) METHOD(GetElement) METHOD(GetLower) METHOD(GetUpper) -- 2.7.4