[mono][jit] Adding support for Vector128::ExtractMostSignificantBits intrinsic on arm64
authorIvan Povazan <55002338+ivanpovazan@users.noreply.github.com>
Fri, 14 Apr 2023 15:59:36 +0000 (17:59 +0200)
committerGitHub <noreply@github.com>
Fri, 14 Apr 2023 15:59:36 +0000 (17:59 +0200)
Contributes to https://github.com/dotnet/runtime/issues/76025

src/mono/mono/arch/arm64/arm64-codegen.h
src/mono/mono/mini/cpu-arm64.mdesc
src/mono/mono/mini/mini-arm64.c
src/mono/mono/mini/mini-ops.h
src/mono/mono/mini/simd-intrinsics.c

index 2079685..2687b9f 100644 (file)
@@ -1804,6 +1804,7 @@ arm_encode_arith_imm (int imm, guint32 *shift)
 #define arm_neon_cmhi(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b00110, (rd), (rn), (rm))
 #define arm_neon_cmhs(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b00111, (rd), (rn), (rm))
 #define arm_neon_addp(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b0, (type), 0b10111, (rd), (rn), (rm))
+#define arm_neon_ushl(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b01000, (rd), (rn), (rm))
 
 // Generalized macros for float ops:
 //   width - determines if full register or its lower half is used one of {VREG_LOW, VREG_FULL}
index 7176dd6..8640097 100644 (file)
@@ -533,6 +533,8 @@ create_scalar_unsafe_int: dest:x src1:i len:4
 create_scalar_unsafe_float: dest:x src1:f len:4
 arm64_bic: dest:x src1:x src2:x len:4
 bitwise_select: dest:x src1:x src2:x src3:x len:12
+arm64_ushl: dest:x src1:x src2:x len:4
+arm64_ext_imm: dest:x src1:x src2:x len:4
 
 generic_class_init: src1:a len:44 clob:c
 gc_safe_point: src1:i len:12 clob:c
index a97237a..b4dfb33 100644 (file)
@@ -3920,8 +3920,18 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb)
                // case OP_XCONCAT:
                //      arm_neon_ext_16b(code, dreg, sreg1, sreg2, 8);
                //      break;
-               
-                       /* BRANCH */
+               case OP_ARM64_USHL: {
+                       arm_neon_ushl (code, get_vector_size_macro (ins), get_type_size_macro (ins->inst_c1), dreg, sreg1, sreg2);
+                       break;
+               }
+               case OP_ARM64_EXT_IMM: {
+                       if (get_vector_size_macro (ins) == VREG_LOW)
+                               arm_neon_ext_8b (code, dreg, sreg1, sreg2, ins->inst_c0);
+                       else
+                               arm_neon_ext_16b (code, dreg, sreg1, sreg2, ins->inst_c0);
+                       break;
+               }
+               /* BRANCH */
                case OP_BR:
                        mono_add_patch_info_rel (cfg, offset, MONO_PATCH_INFO_BB, ins->inst_target_bb, MONO_R_ARM64_B);
                        arm_b (code, code);
index 4f45bc3..9b70bdc 100644 (file)
@@ -1759,6 +1759,7 @@ MINI_OP(OP_ARM64_ABSCOMPARE, "arm64_abscompare", XREG, XREG, XREG)
 MINI_OP(OP_ARM64_XNARROW_SCALAR, "arm64_xnarrow_scalar", XREG, XREG, NONE)
 
 MINI_OP3(OP_ARM64_EXT, "arm64_ext", XREG, XREG, XREG, IREG)
+MINI_OP(OP_ARM64_EXT_IMM, "arm64_ext_imm", XREG, XREG, XREG)
 
 MINI_OP3(OP_ARM64_SQRDMLAH, "arm64_sqrdmlah", XREG, XREG, XREG, XREG)
 MINI_OP3(OP_ARM64_SQRDMLAH_BYSCALAR, "arm64_sqrdmlah_byscalar", XREG, XREG, XREG, XREG)
@@ -1775,6 +1776,8 @@ MINI_OP3(OP_ARM64_SQRDMLSH_SCALAR, "arm64_sqrdmlsh_scalar", XREG, XREG, XREG, XR
 MINI_OP(OP_ARM64_TBL_INDIRECT, "arm64_tbl_indirect", XREG, IREG, XREG)
 MINI_OP3(OP_ARM64_TBX_INDIRECT, "arm64_tbx_indirect", XREG, IREG, XREG, XREG)
 
+MINI_OP(OP_ARM64_USHL, "arm64_ushl", XREG, XREG, XREG)
+
 #endif // TARGET_ARM64
 
 MINI_OP(OP_SIMD_FCVTL, "simd_convert_to_higher_precision", XREG, XREG, NONE)
index cd4d498..f88c861 100644 (file)
@@ -1204,6 +1204,97 @@ is_element_type_primitive (MonoType *vector_type)
 }
 
+/*
+ * emit_msb_vector_mask:
+ *
+ *   Emit a 128-bit vector constant in which only the most-significant bit of
+ *   each element is set; the element width is selected by ARG_TYPE. ANDing an
+ *   input vector with this mask isolates the MSB of every element, which is
+ *   the first step of lowering Vector128.ExtractMostSignificantBits.
+ *   Returns the constant-load instruction, with its klass set to ARG_CLASS.
+ */
 static MonoInst*
+emit_msb_vector_mask (MonoCompile *cfg, MonoClass *arg_class, MonoTypeEnum arg_type)
+{
+       guint64 msb_mask_value[2];
+
+       switch (arg_type) {
+               case MONO_TYPE_I1:
+               case MONO_TYPE_U1:
+                       msb_mask_value[0] = 0x8080808080808080;
+                       msb_mask_value[1] = 0x8080808080808080;
+                       break;
+               case MONO_TYPE_I2:
+               case MONO_TYPE_U2:
+                       msb_mask_value[0] = 0x8000800080008000;
+                       msb_mask_value[1] = 0x8000800080008000;
+                       break;
+// native int/uint share the mask of the pointer-sized integer type
+#if TARGET_SIZEOF_VOID_P == 4
+               case MONO_TYPE_I:
+               case MONO_TYPE_U:
+#endif
+               case MONO_TYPE_I4:
+               case MONO_TYPE_U4:
+               case MONO_TYPE_R4:
+                       msb_mask_value[0] = 0x8000000080000000;
+                       msb_mask_value[1] = 0x8000000080000000;
+                       break;
+#if TARGET_SIZEOF_VOID_P == 8
+               case MONO_TYPE_I:
+               case MONO_TYPE_U:
+#endif
+               case MONO_TYPE_I8:
+               case MONO_TYPE_U8:
+               case MONO_TYPE_R8:
+                       msb_mask_value[0] = 0x8000000000000000;
+                       msb_mask_value[1] = 0x8000000000000000;
+                       break;
+               default:
+                       g_assert_not_reached ();
+       }
+
+       MonoInst* msb_mask_vec = emit_xconst_v128 (cfg, arg_class, (guint8*)msb_mask_value);
+       msb_mask_vec->klass = arg_class;
+       return msb_mask_vec;
+}
+
+/*
+ * emit_msb_shift_vector_constant:
+ *
+ *   Emit a 128-bit vector of per-element shift amounts for a USHL operation.
+ *   All amounts are negative (right shifts — see NOTE below), chosen so that
+ *   after shifting an MSB-masked vector, the most-significant bit of element i
+ *   lands in bit position i of that element. The shifted elements can then be
+ *   summed to form the ExtractMostSignificantBits result. For byte elements
+ *   the amounts only cover positions 0-7, so the caller sums the low and high
+ *   64-bit halves separately and combines them.
+ *   Returns the constant-load instruction, with its klass set to ARG_CLASS.
+ */
+static MonoInst*
+emit_msb_shift_vector_constant (MonoCompile *cfg, MonoClass *arg_class, MonoTypeEnum arg_type)
+{
+       guint64 msb_shift_value[2];
+
+       // NOTE: On ARM64 ushl shifts a vector left or right depending on the sign of the shift constant
+       switch (arg_type) {
+               case MONO_TYPE_I1:
+               case MONO_TYPE_U1:
+                       // bytes 0..7 of each half shift by -7..0: bit 7 of byte i moves to bit i
+                       msb_shift_value[0] = 0x00FFFEFDFCFBFAF9;
+                       msb_shift_value[1] = 0x00FFFEFDFCFBFAF9;
+                       break;
+               case MONO_TYPE_I2:
+               case MONO_TYPE_U2:
+                       // halfwords shift by -15..-8: bit 15 of element i moves to bit i
+                       msb_shift_value[0] = 0xFFF4FFF3FFF2FFF1;
+                       msb_shift_value[1] = 0xFFF8FFF7FFF6FFF5;
+                       break;
+#if TARGET_SIZEOF_VOID_P == 4
+               case MONO_TYPE_I:
+               case MONO_TYPE_U:
+#endif
+               case MONO_TYPE_I4:
+               case MONO_TYPE_U4:
+               case MONO_TYPE_R4:
+                       // words shift by -31..-28: bit 31 of element i moves to bit i
+                       msb_shift_value[0] = 0xFFFFFFE2FFFFFFE1;
+                       msb_shift_value[1] = 0xFFFFFFE4FFFFFFE3;
+                       break;
+#if TARGET_SIZEOF_VOID_P == 8
+               case MONO_TYPE_I:
+               case MONO_TYPE_U:
+#endif
+               case MONO_TYPE_I8:
+               case MONO_TYPE_U8:
+               case MONO_TYPE_R8:
+                       // doublewords shift by -63 and -62: bit 63 of element i moves to bit i
+                       msb_shift_value[0] = 0xFFFFFFFFFFFFFFC1;
+                       msb_shift_value[1] = 0xFFFFFFFFFFFFFFC2;
+                       break;
+               default:
+                       g_assert_not_reached ();
+       }
+
+       MonoInst* msb_shift_vec = emit_xconst_v128 (cfg, arg_class, (guint8*)msb_shift_value);
+       msb_shift_vec->klass = arg_class;
+       return msb_shift_vec;
+}
+
+static MonoInst*
 emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsig, MonoInst **args)
 {      
 #if defined(TARGET_AMD64) || defined(TARGET_WASM)
@@ -1234,7 +1325,6 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
                case SN_ConvertToUInt64:
                case SN_Create:
                case SN_Dot:
-               case SN_ExtractMostSignificantBits:
                case SN_GetElement:
                case SN_GetLower:
                case SN_GetUpper:
@@ -1542,7 +1632,49 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
                        return NULL;
 #ifdef TARGET_WASM
                return emit_simd_ins_for_sig (cfg, klass, OP_WASM_SIMD_BITMASK, -1, -1, fsig, args);
-#else
+#elif defined(TARGET_ARM64)
+               if (COMPILE_LLVM (cfg))
+                       return NULL;
+
+               MonoInst* result_ins = NULL;
+               MonoClass* arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
+               int size = mono_class_value_size (arg_class, NULL);
+               if (size != 16)
+                       return NULL;
+
+               MonoInst* msb_mask_vec = emit_msb_vector_mask (cfg, arg_class, arg0_type);
+               MonoInst* and_res_vec = emit_simd_ins_for_binary_op (cfg, arg_class, fsig, args, arg0_type, SN_BitwiseAnd);
+               and_res_vec->sreg2 = msb_mask_vec->dreg;
+
+               MonoInst* msb_shift_vec = emit_msb_shift_vector_constant (cfg, arg_class, arg0_type);
+               MonoInst* shift_res_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_USHL, and_res_vec->dreg, msb_shift_vec->dreg);
+               shift_res_vec->inst_c1 = arg0_type;
+
+               if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1) {
+                       // Always perform unsigned operations as vector sum and extract operations could sign-extend the result into the GP register
+                       // making the final result invalid. This is not needed for wider types as the maximum sum of extracted MSBs cannot be larger than 8 bits
+                       arg0_type = MONO_TYPE_U1;
+
+                       // In order to sum the high and low 64 bits of the shifted vector separately, we use a zeroed vector and the extract operation
+                       MonoInst* zero_vec = emit_xzero(cfg, arg_class);
+
+                       MonoInst* ext_low_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_EXT_IMM, zero_vec->dreg, shift_res_vec->dreg);
+                       ext_low_vec->inst_c0 = 8;
+                       ext_low_vec->inst_c1 = arg0_type;
+                       MonoInst* sum_low_vec = emit_sum_vector (cfg, fsig->params [0], arg0_type, ext_low_vec);
+
+                       MonoInst* ext_high_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_EXT_IMM, shift_res_vec->dreg, zero_vec->dreg);
+                       ext_high_vec->inst_c0 = 8;
+                       ext_high_vec->inst_c1 = arg0_type;
+                       MonoInst* sum_high_vec = emit_sum_vector (cfg, fsig->params [0], arg0_type, ext_high_vec);
+
+                       MONO_EMIT_NEW_BIALU_IMM (cfg, OP_SHL_IMM, sum_high_vec->dreg, sum_high_vec->dreg, 8);
+                       EMIT_NEW_BIALU (cfg, result_ins, OP_IOR, sum_high_vec->dreg, sum_high_vec->dreg, sum_low_vec->dreg);
+               } else {
+                       result_ins = emit_sum_vector (cfg, fsig->params [0], arg0_type, shift_res_vec);
+               }
+               return result_ins;
+#elif defined(TARGET_AMD64)
                return NULL;
 #endif
        }