[mono] Implement Rdm and Dp (#49737)
authorimhameed <imhameed@microsoft.com>
Thu, 24 Jun 2021 20:56:27 +0000 (13:56 -0700)
committerGitHub <noreply@github.com>
Thu, 24 Jun 2021 20:56:27 +0000 (13:56 -0700)
See https://github.com/dotnet/runtime/issues/42322 and
https://github.com/dotnet/runtime/issues/42280.

Tested manually on an arm64 Linux VM running on an M1 Mac Mini. Does not
enable RDM or DP when AOT-compiling the associated runtime tests; our CI
hardware doesn't support these extensions yet.

src/mono/mono/mini/aot-compiler.c
src/mono/mono/mini/llvm-intrinsics.h
src/mono/mono/mini/mini-llvm.c
src/mono/mono/mini/mini-ops.h
src/mono/mono/mini/simd-intrinsics.c
src/mono/mono/mini/simd-methods.h

index b36fea0..3df6d4b 100644 (file)
@@ -8240,7 +8240,7 @@ parse_cpu_features (const gchar *attr)
        // if we disable a feature from the SSE-AVX tree we also need to disable all dependencies
        if (!enabled && (feature & MONO_CPU_X86_FULL_SSEAVX_COMBINED))
                feature = (MonoCPUFeatures) (MONO_CPU_X86_FULL_SSEAVX_COMBINED & ~feature);
-       
+
 #elif defined(TARGET_ARM64)
        // MONO_CPU_ARM64_BASE is unconditionally set in mini_get_cpu_features.
        if (!strcmp (attr + prefix, "crc"))
@@ -8249,6 +8249,10 @@ parse_cpu_features (const gchar *attr)
                feature = MONO_CPU_ARM64_CRYPTO;
        else if (!strcmp (attr + prefix, "neon"))
                feature = MONO_CPU_ARM64_NEON;
+       else if (!strcmp (attr + prefix, "rdm"))
+               feature = MONO_CPU_ARM64_RDM;
+       else if (!strcmp (attr + prefix, "dotprod"))
+               feature = MONO_CPU_ARM64_DP;
 #elif defined(TARGET_WASM)
        if (!strcmp (attr + prefix, "simd"))
                feature = MONO_CPU_WASM_SIMD;
index f12bb53..042f107 100644 (file)
@@ -25,6 +25,7 @@
 #define Widen INTRIN_kind_widen
 #define WidenAcross INTRIN_kind_widen_across
 #define Across INTRIN_kind_across
+#define Arm64DotProd INTRIN_kind_arm64_dot_prod
 #if !defined(Generic)
 #define Generic
 #endif
@@ -466,6 +467,10 @@ INTRINS_OVR_TAG(AARCH64_ADV_SIMD_SRI, aarch64_neon_vsri, Arm64, V64 | V128 | I1
 
 INTRINS_OVR_TAG(AARCH64_ADV_SIMD_TBX1, aarch64_neon_tbx1, Arm64, V64 | V128 | I1)
 INTRINS_OVR_TAG(AARCH64_ADV_SIMD_TBL1, aarch64_neon_tbl1, Arm64, V64 | V128 | I1)
+
+INTRINS_OVR_TAG_KIND(AARCH64_ADV_SIMD_SDOT, aarch64_neon_sdot, Arm64, Arm64DotProd, V64 | V128 | I4)
+INTRINS_OVR_TAG_KIND(AARCH64_ADV_SIMD_UDOT, aarch64_neon_udot, Arm64, Arm64DotProd, V64 | V128 | I4)
+
 #endif
 
 #undef INTRINS
@@ -486,6 +491,7 @@ INTRINS_OVR_TAG(AARCH64_ADV_SIMD_TBL1, aarch64_neon_tbl1, Arm64, V64 | V128 | I1
 #undef Ftoi
 #undef WidenAcross
 #undef Across
+#undef Arm64DotProd
 #undef Generic
 #undef X86
 #undef Arm64
index b19cf2f..7d56384 100644 (file)
@@ -360,6 +360,7 @@ enum {
        INTRIN_kind_widen,
        INTRIN_kind_widen_across,
        INTRIN_kind_across,
+       INTRIN_kind_arm64_dot_prod,
 };
 
 static const uint8_t intrin_kind [] = {
@@ -9660,6 +9661,21 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
                        values [ins->dreg] = result;
                        break;
                }
+               case OP_ARM64_SELECT_QUAD: {
+                       LLVMTypeRef src_type = simd_class_to_llvm_type (ctx, ins->data.op [1].klass);
+                       LLVMTypeRef ret_type = simd_class_to_llvm_type (ctx, ins->klass);
+                       unsigned int src_type_bits = mono_llvm_get_prim_size_bits (src_type);
+                       unsigned int ret_type_bits = mono_llvm_get_prim_size_bits (ret_type);
+                       unsigned int src_intermediate_elems = src_type_bits / 32;
+                       unsigned int ret_intermediate_elems = ret_type_bits / 32;
+                       LLVMTypeRef intermediate_type = LLVMVectorType (i4_t, src_intermediate_elems);
+                       LLVMValueRef result = LLVMBuildBitCast (builder, lhs, intermediate_type, "arm64_select_quad");
+                       result = LLVMBuildExtractElement (builder, result, rhs, "arm64_select_quad");
+                       result = broadcast_element (ctx, result, ret_intermediate_elems);
+                       result = LLVMBuildBitCast (builder, result, ret_type, "arm64_select_quad");
+                       values [ins->dreg] = result;
+                       break;
+               }
                case OP_LSCNT32:
                case OP_LSCNT64: {
                        // %shr = ashr i32 %x, 31
@@ -9683,6 +9699,43 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
                        values [ins->dreg] = LLVMBuildCall (builder, get_intrins (ctx, ins->opcode == OP_LSCNT32 ? INTRINS_CTLZ_I32 : INTRINS_CTLZ_I64), args, 2, "");
                        break;
                }
+               case OP_ARM64_SQRDMLAH:
+               case OP_ARM64_SQRDMLAH_BYSCALAR:
+               case OP_ARM64_SQRDMLAH_SCALAR:
+               case OP_ARM64_SQRDMLSH:
+               case OP_ARM64_SQRDMLSH_BYSCALAR:
+               case OP_ARM64_SQRDMLSH_SCALAR: {
+                       gboolean byscalar = FALSE;
+                       gboolean scalar = FALSE;
+                       gboolean subtract = FALSE;
+                       switch (ins->opcode) {
+                       case OP_ARM64_SQRDMLAH_BYSCALAR: byscalar = TRUE; break;
+                       case OP_ARM64_SQRDMLAH_SCALAR: scalar = TRUE; break;
+                       case OP_ARM64_SQRDMLSH: subtract = TRUE; break;
+                       case OP_ARM64_SQRDMLSH_BYSCALAR: subtract = TRUE; byscalar = TRUE; break;
+                       case OP_ARM64_SQRDMLSH_SCALAR: subtract = TRUE; scalar = TRUE; break;
+                       }
+                       int acc_iid = subtract ? INTRINS_AARCH64_ADV_SIMD_SQSUB : INTRINS_AARCH64_ADV_SIMD_SQADD;
+                       LLVMTypeRef ret_t = simd_class_to_llvm_type (ctx, ins->klass);
+                       llvm_ovr_tag_t ovr_tag = ovr_tag_from_llvm_type (ret_t);
+                       ScalarOpFromVectorOpCtx sctx = scalar_op_from_vector_op (ctx, ret_t, ins);
+                       LLVMValueRef args [] = { lhs, rhs, arg3 };
+                       if (byscalar) {
+                               unsigned int elems = LLVMGetVectorSize (ret_t);
+                               args [2] = broadcast_element (ctx, scalar_from_vector (ctx, args [2]), elems);
+                       }
+                       if (scalar) {
+                               ovr_tag = sctx.ovr_tag;
+                               scalar_op_from_vector_op_process_args (&sctx, args, 3);
+                       }
+                       LLVMValueRef result = call_overloaded_intrins (ctx, INTRINS_AARCH64_ADV_SIMD_SQRDMULH, ovr_tag, &args [1], "arm64_sqrdmlxh");
+                       args [1] = result;
+                       result = call_overloaded_intrins (ctx, acc_iid, ovr_tag, &args [0], "arm64_sqrdmlxh");
+                       if (scalar)
+                               result = scalar_op_from_vector_op_process_result (&sctx, result);
+                       values [ins->dreg] = result;
+                       break;
+               }
                case OP_ARM64_SMULH:
                case OP_ARM64_UMULH: {
                        LLVMValueRef op1, op2;
@@ -12136,6 +12189,13 @@ add_intrinsic (LLVMModuleRef module, int id)
                                                int associated_prim = MAX(ew, 2);
                                                LLVMTypeRef associated_scalar_type = intrin_types [0][associated_prim];
                                                intrins = add_intrins2 (module, id, associated_scalar_type, distinguishing_type);
+                                       } else if (kind == INTRIN_kind_arm64_dot_prod) {
+                                               /*
+                                                * @llvm.aarch64.neon.sdot.v2i32.v8i8
+                                                * @llvm.aarch64.neon.sdot.v4i32.v16i8
+                                                */
+                                                LLVMTypeRef associated_type = intrin_types [vw][0];
+                                                intrins = add_intrins2 (module, id, distinguishing_type, associated_type);
                                        } else
                                                intrins = add_intrins1 (module, id, distinguishing_type);
                                        int key = key_from_id_and_tag (id, test);
@@ -13530,9 +13590,11 @@ MonoCPUFeatures mono_llvm_get_cpu_features (void)
                { "bmi2",       MONO_CPU_X86_BMI2 },
 #endif
 #if defined(TARGET_ARM64)
-               { "crc",        MONO_CPU_ARM64_CRC },
-               { "crypto",     MONO_CPU_ARM64_CRYPTO },
-               { "neon",       MONO_CPU_ARM64_NEON }
+               { "crc",     MONO_CPU_ARM64_CRC },
+               { "crypto",  MONO_CPU_ARM64_CRYPTO },
+               { "neon",    MONO_CPU_ARM64_NEON },
+               { "rdm",     MONO_CPU_ARM64_RDM },
+               { "dotprod", MONO_CPU_ARM64_DP },
 #endif
 #if defined(TARGET_WASM)
                { "simd",       MONO_CPU_WASM_SIMD },
index 5fadd70..25dd6d8 100644 (file)
@@ -1740,6 +1740,7 @@ MINI_OP(OP_ARM64_UQXTN2, "arm64_uqxtn2", XREG, XREG, XREG)
 MINI_OP(OP_ARM64_SQXTUN2, "arm64_sqxtun2", XREG, XREG, XREG)
 
 MINI_OP(OP_ARM64_SELECT_SCALAR, "arm64_select_scalar", XREG, XREG, IREG)
+MINI_OP(OP_ARM64_SELECT_QUAD, "arm64_select_quad", XREG, XREG, IREG)
 
 MINI_OP(OP_ARM64_FCVTZU, "arm64_fcvtzu", XREG, XREG, NONE)
 MINI_OP(OP_ARM64_FCVTZS, "arm64_fcvtzs", XREG, XREG, NONE)
@@ -1807,4 +1808,11 @@ MINI_OP(OP_ARM64_XNARROW_SCALAR, "arm64_xnarrow_scalar", XREG, XREG, NONE)
 
 MINI_OP3(OP_ARM64_EXT, "arm64_ext", XREG, XREG, XREG, IREG)
 
+MINI_OP3(OP_ARM64_SQRDMLAH, "arm64_sqrdmlah", XREG, XREG, XREG, XREG)
+MINI_OP3(OP_ARM64_SQRDMLAH_BYSCALAR, "arm64_sqrdmlah_byscalar", XREG, XREG, XREG, XREG)
+MINI_OP3(OP_ARM64_SQRDMLAH_SCALAR, "arm64_sqrdmlah_scalar", XREG, XREG, XREG, XREG)
+MINI_OP3(OP_ARM64_SQRDMLSH, "arm64_sqrdmlsh", XREG, XREG, XREG, XREG)
+MINI_OP3(OP_ARM64_SQRDMLSH_BYSCALAR, "arm64_sqrdmlsh_byscalar", XREG, XREG, XREG, XREG)
+MINI_OP3(OP_ARM64_SQRDMLSH_SCALAR, "arm64_sqrdmlsh_scalar", XREG, XREG, XREG, XREG)
+
 #endif // TARGET_ARM64
index b65d8d8..9b4fce5 100644 (file)
@@ -1428,13 +1428,31 @@ static SimdIntrinsic advsimd_methods [] = {
        {SN_get_IsSupported},
 };
 
+static const SimdIntrinsic rdm_methods [] = {
+       {SN_MultiplyRoundedDoublingAndAddSaturateHigh, OP_ARM64_SQRDMLAH},
+       {SN_MultiplyRoundedDoublingAndAddSaturateHighScalar, OP_ARM64_SQRDMLAH_SCALAR},
+       {SN_MultiplyRoundedDoublingAndSubtractSaturateHigh, OP_ARM64_SQRDMLSH},
+       {SN_MultiplyRoundedDoublingAndSubtractSaturateHighScalar, OP_ARM64_SQRDMLSH_SCALAR},
+       {SN_MultiplyRoundedDoublingBySelectedScalarAndAddSaturateHigh},
+       {SN_MultiplyRoundedDoublingBySelectedScalarAndSubtractSaturateHigh},
+       {SN_MultiplyRoundedDoublingScalarBySelectedScalarAndAddSaturateHigh},
+       {SN_MultiplyRoundedDoublingScalarBySelectedScalarAndSubtractSaturateHigh},
+       {SN_get_IsSupported},
+};
+
+static const SimdIntrinsic dp_methods [] = {
+       {SN_DotProduct, OP_XOP_OVR_X_X_X_X, INTRINS_AARCH64_ADV_SIMD_SDOT, OP_XOP_OVR_X_X_X_X, INTRINS_AARCH64_ADV_SIMD_UDOT},
+       {SN_DotProductBySelectedQuadruplet},
+       {SN_get_IsSupported},
+};
+
 static const IntrinGroup supported_arm_intrinsics [] = {
        { "AdvSimd", MONO_CPU_ARM64_NEON, advsimd_methods, sizeof (advsimd_methods) },
        { "Aes", MONO_CPU_ARM64_CRYPTO, crypto_aes_methods, sizeof (crypto_aes_methods) },
        { "ArmBase", MONO_CPU_ARM64_BASE, armbase_methods, sizeof (armbase_methods) },
        { "Crc32", MONO_CPU_ARM64_CRC, crc32_methods, sizeof (crc32_methods) },
-       { "Dp", MONO_CPU_ARM64_DP, unsupported, sizeof (unsupported) },
-       { "Rdm", MONO_CPU_ARM64_RDM, unsupported, sizeof (unsupported) },
+       { "Dp", MONO_CPU_ARM64_DP, dp_methods, sizeof (dp_methods) },
+       { "Rdm", MONO_CPU_ARM64_RDM, rdm_methods, sizeof (rdm_methods) },
        { "Sha1", MONO_CPU_ARM64_CRYPTO, sha1_methods, sizeof (sha1_methods) },
        { "Sha256", MONO_CPU_ARM64_CRYPTO, sha256_methods, sizeof (sha256_methods) },
 };
@@ -1740,6 +1758,51 @@ emit_arm64_intrinsics (
                }
        }
 
+       if (feature == MONO_CPU_ARM64_RDM) {
+               switch (id) {
+               case SN_MultiplyRoundedDoublingBySelectedScalarAndAddSaturateHigh:
+               case SN_MultiplyRoundedDoublingBySelectedScalarAndSubtractSaturateHigh:
+               case SN_MultiplyRoundedDoublingScalarBySelectedScalarAndAddSaturateHigh:
+               case SN_MultiplyRoundedDoublingScalarBySelectedScalarAndSubtractSaturateHigh: {
+                       MonoClass *ret_klass = mono_class_from_mono_type_internal (fsig->ret);
+                       int opcode = 0;
+                       switch (id) {
+                       case SN_MultiplyRoundedDoublingBySelectedScalarAndAddSaturateHigh: opcode = OP_ARM64_SQRDMLAH_BYSCALAR; break;
+                       case SN_MultiplyRoundedDoublingBySelectedScalarAndSubtractSaturateHigh: opcode = OP_ARM64_SQRDMLSH_BYSCALAR; break;
+                       case SN_MultiplyRoundedDoublingScalarBySelectedScalarAndAddSaturateHigh: opcode = OP_ARM64_SQRDMLAH_SCALAR; break;
+                       case SN_MultiplyRoundedDoublingScalarBySelectedScalarAndSubtractSaturateHigh: opcode = OP_ARM64_SQRDMLSH_SCALAR; break;
+                       }
+                       MonoInst *scalar = emit_simd_ins (cfg, ret_klass, OP_ARM64_SELECT_SCALAR, args [2]->dreg, args [3]->dreg);
+                       MonoInst *ret = emit_simd_ins (cfg, ret_klass, opcode, args [0]->dreg, args [1]->dreg);
+                       ret->inst_c1 = arg0_type;
+                       ret->sreg3 = scalar->dreg;
+                       return ret;
+               }
+               default:
+                       g_assert_not_reached ();
+               }
+       }
+
+       if (feature == MONO_CPU_ARM64_DP) {
+               switch (id) {
+               case SN_DotProductBySelectedQuadruplet: {
+                       MonoClass *ret_klass = mono_class_from_mono_type_internal (fsig->ret);
+                       MonoClass *arg_klass = mono_class_from_mono_type_internal (fsig->params [1]);
+                       MonoClass *quad_klass = mono_class_from_mono_type_internal (fsig->params [2]);
+                       gboolean is_unsigned = type_is_unsigned (fsig->ret);
+                       int iid = is_unsigned ? INTRINS_AARCH64_ADV_SIMD_UDOT : INTRINS_AARCH64_ADV_SIMD_SDOT;
+                       MonoInst *quad = emit_simd_ins (cfg, arg_klass, OP_ARM64_SELECT_QUAD, args [2]->dreg, args [3]->dreg);
+                       quad->data.op [1].klass = quad_klass;
+                       MonoInst *ret = emit_simd_ins (cfg, ret_klass, OP_XOP_OVR_X_X_X_X, args [0]->dreg, args [1]->dreg);
+                       ret->sreg3 = quad->dreg;
+                       ret->inst_c0 = iid;
+                       return ret;
+               }
+               default:
+                       g_assert_not_reached ();
+               }
+       }
+
        return NULL;
 }
 #endif // TARGET_ARM64
index 673543d..4abd0eb 100644 (file)
@@ -248,7 +248,7 @@ METHOD(HashUpdate2)
 METHOD(ScheduleUpdate0)
 METHOD(ScheduleUpdate1)
 METHOD(MixColumns)
-//AdvSimd
+// AdvSimd
 METHOD(AbsSaturate)
 METHOD(AbsSaturateScalar)
 METHOD(AbsScalar)
@@ -559,3 +559,15 @@ METHOD(ZeroExtendWideningLower)
 METHOD(ZeroExtendWideningUpper)
 METHOD(ZipHigh)
 METHOD(ZipLow)
+// Arm.Rdm
+METHOD(MultiplyRoundedDoublingAndAddSaturateHigh)
+METHOD(MultiplyRoundedDoublingAndSubtractSaturateHigh)
+METHOD(MultiplyRoundedDoublingBySelectedScalarAndAddSaturateHigh)
+METHOD(MultiplyRoundedDoublingBySelectedScalarAndSubtractSaturateHigh)
+// Arm.Rdm.Arm64
+METHOD(MultiplyRoundedDoublingAndAddSaturateHighScalar)
+METHOD(MultiplyRoundedDoublingAndSubtractSaturateHighScalar)
+METHOD(MultiplyRoundedDoublingScalarBySelectedScalarAndAddSaturateHigh)
+METHOD(MultiplyRoundedDoublingScalarBySelectedScalarAndSubtractSaturateHigh)
+// Arm.Dp
+METHOD(DotProductBySelectedQuadruplet)