Intrinsify Array GetArrayDataReference for SZ arrays (#87374)
authorMichał Petryka <35800402+MichalPetryka@users.noreply.github.com>
Mon, 17 Jul 2023 18:06:43 +0000 (20:06 +0200)
committerGitHub <noreply@github.com>
Mon, 17 Jul 2023 18:06:43 +0000 (20:06 +0200)
src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/MemoryMarshal.CoreCLR.cs
src/coreclr/jit/importercalls.cpp
src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/MemoryMarshal.NativeAot.cs
src/coreclr/vm/jitinterface.cpp
src/tests/JIT/Intrinsics/MemoryMarshalGetArrayDataReference.cs

index f4a87e5..fe0fed4 100644 (file)
@@ -33,6 +33,7 @@ namespace System.Runtime.InteropServices
         /// This technique does not perform array variance checks. The caller must manually perform any array variance checks
         /// if the caller wishes to write to the returned reference.
         /// </remarks>
+        [Intrinsic]
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public static ref byte GetArrayDataReference(Array array)
         {
index 7a304b1..df1df1c 100644 (file)
@@ -2762,14 +2762,33 @@ GenTree* Compiler::impIntrinsic(GenTree*                newobjThis,
             case NI_System_Runtime_InteropService_MemoryMarshal_GetArrayDataReference:
             {
                 assert(sig->numArgs == 1);
-                assert(sig->sigInst.methInstCount == 1);
 
-                GenTree*             array    = impPopStack().val;
-                CORINFO_CLASS_HANDLE elemHnd  = sig->sigInst.methInst[0];
-                CorInfoType          jitType  = info.compCompHnd->asCorInfoType(elemHnd);
-                var_types            elemType = JITtype2varType(jitType);
+                GenTree*             array   = impStackTop().val;
+                bool                 notNull = false;
+                CORINFO_CLASS_HANDLE elemHnd = NO_CLASS_HANDLE;
+                CorInfoType          jitType;
+                if (sig->sigInst.methInstCount == 1)
+                {
+                    elemHnd = sig->sigInst.methInst[0];
+                    jitType = info.compCompHnd->asCorInfoType(elemHnd);
+                }
+                else
+                {
+                    bool                 isExact  = false;
+                    CORINFO_CLASS_HANDLE arrayHnd = gtGetClassHandle(array, &isExact, &notNull);
+                    if ((arrayHnd == NO_CLASS_HANDLE) || !info.compCompHnd->isSDArray(arrayHnd))
+                    {
+                        return nullptr;
+                    }
+                    jitType = info.compCompHnd->getChildType(arrayHnd, &elemHnd);
+                }
+
+                array = impPopStack().val;
+
+                assert(jitType != CORINFO_TYPE_UNDEF);
+                assert((jitType != CORINFO_TYPE_VALUECLASS) || (elemHnd != NO_CLASS_HANDLE));
 
-                if (fgAddrCouldBeNull(array))
+                if (!notNull && fgAddrCouldBeNull(array))
                 {
                     GenTree* arrayClone;
                     array = impCloneExpr(array, &arrayClone, CHECK_SPILL_ALL,
@@ -2780,7 +2799,7 @@ GenTree* Compiler::impIntrinsic(GenTree*                newobjThis,
                 }
 
                 GenTree*          index     = gtNewIconNode(0, TYP_I_IMPL);
-                GenTreeIndexAddr* indexAddr = gtNewArrayIndexAddr(array, index, elemType, elemHnd);
+                GenTreeIndexAddr* indexAddr = gtNewArrayIndexAddr(array, index, JITtype2varType(jitType), elemHnd);
                 indexAddr->gtFlags &= ~GTF_INX_RNGCHK;
                 indexAddr->gtFlags |= GTF_INX_ADDR_NONNULL;
                 retNode = indexAddr;
index 19e4a2a..6169369 100644 (file)
@@ -34,6 +34,7 @@ namespace System.Runtime.InteropServices
         /// This technique does not perform array variance checks. The caller must manually perform any array variance checks
         /// if the caller wishes to write to the returned reference.
         /// </remarks>
+        [Intrinsic]
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public static ref byte GetArrayDataReference(Array array)
         {
index 840b2b9..b7dcf9e 100644 (file)
@@ -4737,7 +4737,7 @@ CorInfoType CEEInfo::getChildType (
 }
 
 /*********************************************************************/
-// Check if this is a single dimensional array type
+// Check if this is a single dimensional, zero based array type
 bool CEEInfo::isSDArray(CORINFO_CLASS_HANDLE  cls)
 {
     CONTRACTL {
index c59af97..d0d66ee 100644 (file)
@@ -114,6 +114,98 @@ namespace MemoryMarshalGetArrayDataReferenceTest
             ThrowsNRE(() => ref ptrByte(NoInline<byte[]>(null)));
             ThrowsNRE(() => ref ptrString(NoInline<string[]>(null)));
 
+            // use no inline methods to avoid indirect call inlining in the future
+            [MethodImpl(MethodImplOptions.NoInlining)]
+            static delegate*<Array, ref byte> GetMdPtr() => &MemoryMarshal.GetArrayDataReference;
+            delegate*<Array, ref byte> ptrMd = GetMdPtr();
+
+            IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference((Array)testByteArray), ref testByteArray[0]));
+            IsTrue(Unsafe.AreSame(ref ptrMd(testByteArray), ref testByteArray[0]));
+
+            IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference((Array)NoInline(testByteArray)), ref testByteArray[0]));
+            IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(testByteArray)), ref testByteArray[0]));
+            IsTrue(Unsafe.AreSame(ref ptrMd(NoInline(testByteArray)), ref testByteArray[0]));
+
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference((Array)testStringArray)), ref testStringArray[0]));
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref ptrMd(testStringArray)), ref testStringArray[0]));
+
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference((Array)NoInline(testStringArray))), ref testStringArray[0]));
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(testStringArray))), ref testStringArray[0]));
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref ptrMd(NoInline(testStringArray))), ref testStringArray[0]));
+
+            byte[,] testByteMdArray = new byte[1, 1];
+            IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference(testByteMdArray), ref testByteMdArray[0, 0]));
+            IsTrue(Unsafe.AreSame(ref ptrMd(testByteMdArray), ref testByteMdArray[0, 0]));
+
+            IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference(NoInline(testByteMdArray)), ref testByteMdArray[0, 0]));
+            IsTrue(Unsafe.AreSame(ref ptrMd(NoInline(testByteMdArray)), ref testByteMdArray[0, 0]));
+
+            string[,] testStringMdArray = new string[1, 1];
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference(testStringMdArray)), ref testStringMdArray[0, 0]));
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref ptrMd(testStringMdArray)), ref testStringMdArray[0, 0]));
+
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference(NoInline(testStringMdArray))), ref testStringMdArray[0, 0]));
+            IsTrue(Unsafe.AreSame(ref Unsafe.As<byte, string>(ref ptrMd(NoInline(testStringMdArray))), ref testStringMdArray[0, 0]));
+
+            Array nonZeroArray = Array.CreateInstance(typeof(string), new [] { 1 }, new [] { -1 });
+            string test = "test";
+            nonZeroArray.SetValue(test, -1);
+            IsTrue(ReferenceEquals(Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference(nonZeroArray)), test));
+            IsTrue(ReferenceEquals(Unsafe.As<byte, string>(ref ptrMd(nonZeroArray)), test));
+
+            IsTrue(ReferenceEquals(Unsafe.As<byte, string>(ref MemoryMarshal.GetArrayDataReference(NoInline(nonZeroArray))), test));
+            IsTrue(ReferenceEquals(Unsafe.As<byte, string>(ref ptrMd(NoInline(nonZeroArray))), test));
+
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference((Array)new byte[0])));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference((Array)new string[0])));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(new byte[0, 0])));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(new string[0, 0])));
+
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(new byte[0])));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(new string[0])));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(new byte[0, 0])));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(new string[0, 0])));
+
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference((Array)NoInline(new byte[0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference((Array)NoInline(new string[0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(NoInline(new byte[0, 0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(NoInline(new string[0, 0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(new byte[0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(new string[0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(new byte[0, 0]))));
+            IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(new string[0, 0]))));
+
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline(new byte[0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline(new string[0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline(new byte[0, 0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline(new string[0, 0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline<Array>(new byte[0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline<Array>(new string[0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline<Array>(new byte[0, 0]))));
+            IsFalse(Unsafe.IsNullRef(ref ptrMd(NoInline<Array>(new string[0, 0]))));
+
+            ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference((Array)null); });
+            ThrowsNRE(() => { _ = ref ptrMd(null); });
+
+            ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference((Array)NoInline<byte[]>(null)); });
+            ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference((Array)NoInline<string[]>(null)); });
+            ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(null)); });
+
+            ThrowsNRE(() => { _ = ref ptrMd(NoInline<byte[]>(null)); });
+            ThrowsNRE(() => { _ = ref ptrMd(NoInline<string[]>(null)); });
+            ThrowsNRE(() => { _ = ref ptrMd(NoInline<Array>(null)); });
+
+            ThrowsNRE(() => ref MemoryMarshal.GetArrayDataReference((Array)null));
+            ThrowsNRE(() => ref ptrMd(null));
+
+            ThrowsNRE(() => ref MemoryMarshal.GetArrayDataReference((Array)NoInline<byte[]>(null)));
+            ThrowsNRE(() => ref MemoryMarshal.GetArrayDataReference((Array)NoInline<string[]>(null)));
+            ThrowsNRE(() => ref MemoryMarshal.GetArrayDataReference(NoInline<Array>(null)));
+
+            ThrowsNRE(() => ref ptrMd(NoInline<byte[]>(null)));
+            ThrowsNRE(() => ref ptrMd(NoInline<string[]>(null)));
+            ThrowsNRE(() => ref ptrMd(NoInline<Array>(null)));
+
             // from https://github.com/dotnet/runtime/issues/58312#issuecomment-993491291
             [MethodImpl(MethodImplOptions.NoInlining)]
             static int Problem1(StructWithByte[] a)