[Mono] Enable generating SIMD intrinsics for System.Numerics.Vector on Arm64 (#65486)
author Fan Yang <52458914+fanyang-mono@users.noreply.github.com>
Tue, 8 Mar 2022 15:20:59 +0000 (10:20 -0500)
committer GitHub <noreply@github.com>
Tue, 8 Mar 2022 15:20:59 +0000 (10:20 -0500)
* Enable SIMD intrinsics for System.Numerics.Vector on Arm64

* Minor adjustment

* Use the correct op code for BitwiseAnd, BitwiseOr and Xor

* Check if vector element type is a primitive type or not

* Remove dead code and fix constant format

* Add type checks for each method and refactor

* Remove extra condition check

* Loosen the type check for vector creation methods

* Remove type checks for Create* functions

* Remove some of the type checks

src/mono/mono/mini/llvm-intrinsics-types.h
src/mono/mono/mini/mini-llvm.c
src/mono/mono/mini/simd-intrinsics.c

index 401daab..9801d16 100644 (file)
@@ -18,10 +18,10 @@ typedef enum {
 } IntrinsicId;
 
 enum {
-       XBINOP_FORCEINT_and,
-       XBINOP_FORCEINT_or,
-       XBINOP_FORCEINT_ornot,
-       XBINOP_FORCEINT_xor,
+       XBINOP_FORCEINT_AND,
+       XBINOP_FORCEINT_OR,
+       XBINOP_FORCEINT_ORNOT,
+       XBINOP_FORCEINT_XOR,
 };
 
 #endif /* __MONO_MINI_LLVM_INTRINSICS_TYPES_H__ */
index 8d9204a..de64692 100644 (file)
@@ -7895,17 +7895,17 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
                        LLVMValueRef rhs_int = convert (ctx, rhs, intermediate_t);
                        LLVMValueRef result = NULL;
                        switch (ins->inst_c0) {
-                       case XBINOP_FORCEINT_and:
+                       case XBINOP_FORCEINT_AND:
                                result = LLVMBuildAnd (builder, lhs_int, rhs_int, "");
                                break;
-                       case XBINOP_FORCEINT_or:
+                       case XBINOP_FORCEINT_OR:
                                result = LLVMBuildOr (builder, lhs_int, rhs_int, "");
                                break;
-                       case XBINOP_FORCEINT_ornot:
+                       case XBINOP_FORCEINT_ORNOT:
                                result = LLVMBuildNot (builder, rhs_int, "");
                                result = LLVMBuildOr (builder, result, lhs_int, "");
                                break;
-                       case XBINOP_FORCEINT_xor:
+                       case XBINOP_FORCEINT_XOR:
                                result = LLVMBuildXor (builder, lhs_int, rhs_int, "");
                                break;
                        }
index 8e13c6f..4222b22 100644 (file)
@@ -600,13 +600,13 @@ static guint16 sri_vector_methods [] = {
        SN_AsUInt16,
        SN_AsUInt32,
        SN_AsUInt64,
-       SN_BitwiseAnd,
-       SN_BitwiseOr,
        SN_AsVector128,
        SN_AsVector2,
        SN_AsVector256,
        SN_AsVector3,
        SN_AsVector4,
+       SN_BitwiseAnd,
+       SN_BitwiseOr,
        SN_Ceiling,
        SN_ConditionalSelect,
        SN_ConvertToDouble,
@@ -669,24 +669,33 @@ is_create_from_half_vectors_overload (MonoMethodSignature *fsig)
        return mono_metadata_type_equal (fsig->params [0], fsig->params [1]);
 }
 
+static gboolean
+is_element_type_primitive (MonoType *vector_type)
+{
+       MonoType *element_type = get_vector_t_elem_type (vector_type);
+       return MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (element_type);
+}
+
 static MonoInst*
 emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsig, MonoInst **args)
 {
        if (!COMPILE_LLVM (cfg))
                return NULL;
 
-       MonoClass *klass = cmethod->klass;
        int id = lookup_intrins (sri_vector_methods, sizeof (sri_vector_methods), cmethod);
        if (id == -1)
                return NULL;
 
        if (!strcmp (m_class_get_name (cfg->method->klass), "Vector256"))
                return NULL; // TODO: Fix Vector256.WithUpper/WithLower
-
+       
+       MonoClass *klass = cmethod->klass;
        MonoTypeEnum arg0_type = fsig->param_count > 0 ? get_underlying_type (fsig->params [0]) : MONO_TYPE_VOID;
 
        switch (id) {
        case SN_Abs: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
 #ifdef TARGET_ARM64
                switch (arg0_type) {
                        case MONO_TYPE_U1:
@@ -704,16 +713,22 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
 #endif
 }
        case SN_Add:
+       case SN_Divide:
        case SN_Max:
        case SN_Min:
        case SN_Multiply:
        case SN_Subtract: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
                int instc0 = -1;
                if (arg0_type == MONO_TYPE_R4 || arg0_type == MONO_TYPE_R8) {
                        switch (id) {
                        case SN_Add:
                                instc0 = OP_FADD;
                                break;
+                       case SN_Divide:
+                               instc0 = OP_FDIV;
+                               break;
                        case SN_Max:
                                instc0 = OP_FMAX;
                                break;
@@ -734,6 +749,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
                        case SN_Add:
                                instc0 = OP_IADD;
                                break;
+                       case SN_Divide:
+                               return NULL;
                        case SN_Max:
                                instc0 = OP_IMAX;
                                break;
@@ -752,25 +769,34 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
                }
                return emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, arg0_type, fsig, args);
        }
-       case SN_Divide: {
-               if ((arg0_type != MONO_TYPE_R4) && (arg0_type != MONO_TYPE_R8))
-                       return NULL;
-               return emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FDIV, arg0_type, fsig, args);
-       }
        case SN_AndNot:
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
 #ifdef TARGET_ARM64
                return emit_simd_ins_for_sig (cfg, klass, OP_ARM64_BIC, -1, arg0_type, fsig, args);
 #else
                return NULL;
 #endif
        case SN_BitwiseAnd:
-               return emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_IAND, arg0_type, fsig, args);
        case SN_BitwiseOr:
-               return emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_IOR, arg0_type, fsig, args);
        case SN_Xor: {
-               if ((arg0_type == MONO_TYPE_R4) || (arg0_type == MONO_TYPE_R8))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
-               return emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_IXOR, arg0_type, fsig, args);
+               int instc0 = -1;
+               switch (id) {
+               case SN_BitwiseAnd:
+                       instc0 = XBINOP_FORCEINT_AND;
+                       break;
+               case SN_BitwiseOr:
+                       instc0 = XBINOP_FORCEINT_OR;
+                       break;
+               case SN_Xor:
+                       instc0 = XBINOP_FORCEINT_XOR;
+                       break;
+               default:
+                       g_assert_not_reached ();
+               }
+               return emit_simd_ins_for_sig (cfg, klass, OP_XBINOP_FORCEINT, instc0, arg0_type, fsig, args);
        }
        case SN_As:
        case SN_AsByte:
@@ -783,9 +809,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
        case SN_AsUInt16:
        case SN_AsUInt32:
        case SN_AsUInt64: {
-               MonoType *ret_type = get_vector_t_elem_type (fsig->ret);
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (ret_type) || !MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->ret) || !is_element_type_primitive (fsig->params [0]))
                        return NULL;
                return emit_simd_ins (cfg, klass, OP_XCAST, args [0]->dreg, -1);
        }
@@ -801,6 +825,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
 #endif
        }
        case SN_ConditionalSelect: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
 #ifdef TARGET_ARM64
                return emit_simd_ins_for_sig (cfg, klass, OP_ARM64_BSL, -1, arg0_type, fsig, args);
 #else
@@ -851,10 +877,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
        case SN_Equals:
        case SN_EqualsAll:
        case SN_EqualsAny: {
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
-
                switch (id) {
                        case SN_Equals:
                                return emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
@@ -870,10 +894,10 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
                }
        }
        case SN_GetElement: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
                MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
                MonoType *etype = mono_class_get_context (arg_class)->class_inst->type_argv [0];
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (etype))
-                       return NULL;
                int size = mono_class_value_size (arg_class, NULL);
                int esize = mono_class_value_size (mono_class_from_mono_type_internal (etype), NULL);
                int elems = size / esize;
@@ -884,8 +908,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
        }
        case SN_GetLower:
        case SN_GetUpper: {
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
                int op = id == SN_GetLower ? OP_XLOWER : OP_XUPPER;
                return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args);
@@ -894,10 +917,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
        case SN_GreaterThanOrEqual:
        case SN_LessThan:
        case SN_LessThanOrEqual: {
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
-
                gboolean is_unsigned = type_is_unsigned (fsig->params [0]);
                MonoInst *ins = emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
                switch (id) {
@@ -920,6 +941,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
        }
        case SN_Negate:
        case SN_OnesComplement: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
 #ifdef TARGET_ARM64
                int op = id == SN_Negate ? OP_ARM64_XNEG : OP_ARM64_MVN;
                return emit_simd_ins_for_sig (cfg, klass, op, -1, arg0_type, fsig, args);
@@ -928,6 +951,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
 #endif
        }
        case SN_Sqrt: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
 #ifdef TARGET_ARM64
                if ((arg0_type != MONO_TYPE_R4) && (arg0_type != MONO_TYPE_R8))
                        return NULL;
@@ -937,25 +962,23 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
 #endif
        }
        case SN_ToScalar: {
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
                int extract_op = type_to_extract_op (arg0_type);
                return emit_simd_ins_for_sig (cfg, klass, extract_op, 0, arg0_type, fsig, args);
        }
        case SN_ToVector128:
        case SN_ToVector128Unsafe: {
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
                int op = id == SN_ToVector128 ? OP_XWIDEN : OP_XWIDEN_UNSAFE;
                return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args);
        }
        case SN_WithElement: {
+               if (!is_element_type_primitive (fsig->params [0]))
+                       return NULL;
                MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
                MonoType *etype = mono_class_get_context (arg_class)->class_inst->type_argv [0];
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (etype))
-                       return NULL;
                int size = mono_class_value_size (arg_class, NULL);
                int esize = mono_class_value_size (mono_class_from_mono_type_internal (etype), NULL);
                int elems = size / esize;
@@ -969,8 +992,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
        }
        case SN_WithLower:
        case SN_WithUpper: {
-               MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
-               if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
+               if (!is_element_type_primitive (fsig->params [0]))
                        return NULL;
                int op = id == SN_GetLower ? OP_XINSERT_LOWER : OP_XINSERT_UPPER;
                return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args);
@@ -1559,7 +1581,7 @@ static SimdIntrinsic advsimd_methods [] = {
        {SN_AddScalar, OP_XBINOP_SCALAR, OP_IADD, None, None, OP_XBINOP_SCALAR, OP_FADD},
        {SN_AddWideningLower, OP_ARM64_SADD, None, OP_ARM64_UADD},
        {SN_AddWideningUpper, OP_ARM64_SADD2, None, OP_ARM64_UADD2},
-       {SN_And, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_and},
+       {SN_And, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_AND},
        {SN_BitwiseClear, OP_ARM64_BIC},
        {SN_BitwiseSelect, OP_ARM64_BSL},
        {SN_Ceiling, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FRINTP},
@@ -1762,8 +1784,8 @@ static SimdIntrinsic advsimd_methods [] = {
        {SN_NegateSaturateScalar, OP_XOP_OVR_SCALAR_X_X, INTRINS_AARCH64_ADV_SIMD_SQNEG},
        {SN_NegateScalar, OP_ARM64_XNEG_SCALAR},
        {SN_Not, OP_ARM64_MVN},
-       {SN_Or, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_or},
-       {SN_OrNot, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_ornot},
+       {SN_Or, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_OR},
+       {SN_OrNot, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_ORNOT},
        {SN_PolynomialMultiply, OP_XOP_OVR_X_X_X, INTRINS_AARCH64_ADV_SIMD_PMUL},
        {SN_PolynomialMultiplyWideningLower, OP_ARM64_PMULL},
        {SN_PolynomialMultiplyWideningUpper, OP_ARM64_PMULL2},
@@ -1883,7 +1905,7 @@ static SimdIntrinsic advsimd_methods [] = {
        {SN_UnzipOdd, OP_ARM64_UZP2},
        {SN_VectorTableLookup, OP_XOP_OVR_X_X_X, INTRINS_AARCH64_ADV_SIMD_TBL1},
        {SN_VectorTableLookupExtension, OP_XOP_OVR_X_X_X_X, INTRINS_AARCH64_ADV_SIMD_TBX1},
-       {SN_Xor, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_xor},
+       {SN_Xor, OP_XBINOP_FORCEINT, XBINOP_FORCEINT_XOR},
        {SN_ZeroExtendWideningLower, OP_ARM64_UXTL},
        {SN_ZeroExtendWideningUpper, OP_ARM64_UXTL2},
        {SN_ZipHigh, OP_ARM64_ZIP2},
@@ -3351,6 +3373,12 @@ mono_emit_simd_intrinsics (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
        }
 #endif // defined(TARGET_ARM64) || defined(TARGET_AMD64)
 
+#if defined(TARGET_ARM64)
+       if (!strcmp (class_ns, "System.Numerics") && !strcmp (class_name, "Vector")){
+               return emit_sri_vector (cfg, cmethod, fsig, args);
+       }
+#endif // defined(TARGET_ARM64)
+
        return emit_simd_intrinsics (class_ns, class_name, cfg, cmethod, fsig, args);
 }