Adjust managed type system for new function pointer handling (#84819)
authorMichal Strehovský <MichalStrehovsky@users.noreply.github.com>
Tue, 18 Apr 2023 03:11:52 +0000 (12:11 +0900)
committerGitHub <noreply@github.com>
Tue, 18 Apr 2023 03:11:52 +0000 (12:11 +0900)
After #81006, the calling convention is no longer part of the type system identity of a function pointer type - it serves more like a modopt as far as the type system is concerned. The type system only cares whether the pointer is managed or not. Adjust the managed type system accordingly:

* If we're reading/representing a standalone method signature, read it as usual. Calling convention is available in flags/modopts.
* If we're reading/representing a function pointer type, collapse the calling convention information into the managed/unmanaged bit only.

src/coreclr/tools/Common/TypeSystem/Common/MethodDesc.cs
src/coreclr/tools/Common/TypeSystem/Common/TypeSystemContext.cs
src/coreclr/tools/Common/TypeSystem/Ecma/EcmaSignatureParser.cs
src/coreclr/tools/Common/TypeSystem/MetadataEmitter/TypeSystemMetadataEmitter.cs
src/coreclr/tools/aot/ILCompiler.TypeSystem.Tests/CoreTestAssembly/CoreTestAssembly.csproj
src/coreclr/tools/aot/ILCompiler.TypeSystem.Tests/CoreTestAssembly/Platform.cs
src/coreclr/tools/aot/ILCompiler.TypeSystem.Tests/CoreTestAssembly/VirtualFunctionOverride.cs
src/coreclr/tools/aot/ILCompiler.TypeSystem.Tests/VirtualFunctionOverrideTests.cs

index 7802c85d56e6a353a19c41a4455770a2a612a5bd..c0ba423ca927ca08ca0a5668eb3632d4a4eaf8a1 100644 (file)
@@ -27,7 +27,8 @@ namespace Internal.TypeSystem
     {
         RequiredCustomModifier = 0,
         OptionalCustomModifier = 1,
-        ArrayShape = 2
+        ArrayShape = 2,
+        UnmanagedCallConv = 3,
     }
 
     public struct EmbeddedSignatureData
index 7d2f6cfffd0e7863dbb7cd66992bbb3a4604d9b5..233fedde339368ec115d0a4c8aedaa38cd486b00 100644 (file)
@@ -242,6 +242,10 @@ namespace Internal.TypeSystem
 
         public FunctionPointerType GetFunctionPointerType(MethodSignature signature)
         {
+            // The type system only distinguishes between unmanaged and managed signatures.
+            // The caller should have normalized the signature by modifying flags and stripping modopts.
+            Debug.Assert((signature.Flags & MethodSignatureFlags.UnmanagedCallingConventionMask) is 0 or MethodSignatureFlags.UnmanagedCallingConvention);
+            Debug.Assert(!signature.HasEmbeddedSignatureData);
             return _functionPointerTypes.GetOrCreateValue(signature);
         }
 
index e49ce2768d0f76aba482d7dcd6645e4c681d358d..318ca38f800214b7abb366a404c4ca64ff23fddb 100644 (file)
@@ -371,7 +371,19 @@ namespace Internal.TypeSystem.Ecma
                 Debug.Assert((int)MethodSignatureFlags.CallingConventionVarargs == (int)SignatureCallingConvention.VarArgs);
                 Debug.Assert((int)MethodSignatureFlags.UnmanagedCallingConvention == (int)SignatureCallingConvention.Unmanaged);
 
-                flags = (MethodSignatureFlags)signatureCallConv;
+                // If skipEmbeddedSignatureData is true, we're building the signature for the purposes of building a type.
+                // We normalize unmanaged calling convention into a single value - "unmanaged".
+                if (skipEmbeddedSignatureData)
+                {
+                    flags = MethodSignatureFlags.UnmanagedCallingConvention;
+
+                    // But we still need to remember this signature is different, so add this to the EmbeddedSignatureData of the owner signature.
+                    _embeddedSignatureDataList?.Add(new EmbeddedSignatureData { index = string.Join(".", _indexStack) + "|" + ((int)signatureCallConv).ToString(), kind = EmbeddedSignatureDataKind.UnmanagedCallConv, type = null });
+                }
+                else
+                {
+                    flags = (MethodSignatureFlags)signatureCallConv;
+                }
             }
 
             if (!header.IsInstance)
index 217bdbe13a7e56f844f787bdccca95c68a2d5efb..7f465939fc3fc71f9d6c9d2b069262ff24678c04 100644 (file)
@@ -514,6 +514,27 @@ namespace Internal.TypeSystem
                 }
             }
 
+            public void UpdateSignatureCallingConventionAtCurrentIndexStack(ref SignatureCallingConvention callConv)
+            {
+                if (!Complete)
+                {
+                    if (_embeddedDataIndex < _embeddedData.Length)
+                    {
+                        if (_embeddedData[_embeddedDataIndex].kind == EmbeddedSignatureDataKind.UnmanagedCallConv)
+                        {
+                            string indexData = string.Join(".", _indexStack);
+
+                            var unmanagedCallConvPossibility = _embeddedData[_embeddedDataIndex].index.Split('|');
+                            if (unmanagedCallConvPossibility[0] == indexData)
+                            {
+                                callConv = (SignatureCallingConvention)int.Parse(unmanagedCallConvPossibility[1]);
+                                _embeddedDataIndex++;
+                            }
+                        }
+                    }
+                }
+            }
+
             public void EmitArrayShapeAtCurrentIndexStack(BlobBuilder signatureBuilder, int rank)
             {
                 var shapeEncoder = new ArrayShapeEncoder(signatureBuilder);
@@ -665,6 +686,9 @@ namespace Internal.TypeSystem
                     break;
             }
 
+            if (sigCallingConvention != SignatureCallingConvention.Default)
+                signatureDataEmitter.UpdateSignatureCallingConventionAtCurrentIndexStack(ref sigCallingConvention);
+
             signatureEncoder.MethodSignature(sigCallingConvention, genericParameterCount, isInstanceMethod);
             signatureBuilder.WriteCompressedInteger(sig.Length);
             EncodeType(signatureBuilder, sig.ReturnType, signatureDataEmitter);
index 871ffbcb5a204e809d31d46acc929967506a5a0e..90a8efbc32facb22d6a1fb25ce0f1c4224baf7fa 100644 (file)
@@ -4,6 +4,7 @@
     <AssemblyName>CoreTestAssembly</AssemblyName>
     <GenerateAssemblyInfo>false</GenerateAssemblyInfo>
     <IsCoreAssembly>true</IsCoreAssembly>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
     <SkipTestRun>true</SkipTestRun>
     <TargetFramework>netstandard2.0</TargetFramework>
     <!-- Don't add references to the netstandard platform since this is a core assembly -->
index 75ade965f9318a301beeceb3593e25fc214d512b..c357cd1fb786696c826a55f9561077238c4e69b5 100644 (file)
@@ -60,6 +60,11 @@ namespace System
     public struct RuntimeMethodHandle { }
     public struct RuntimeFieldHandle { }
 
+    public class Type
+    {
+        public static Type GetTypeFromHandle(RuntimeTypeHandle handle) => null;
+    }
+
     public class Attribute { }
     public class AttributeUsageAttribute : Attribute
     {
@@ -159,9 +164,14 @@ namespace System.Runtime.CompilerServices
     {
     }
 
+    public class CallConvCdecl { }
+    public class CallConvStdcall { }
+    public class CallConvSuppressGCTransition { }
+
     public static class RuntimeFeature
     {
         public const string ByRefFields = nameof(ByRefFields);
+        public const string UnmanagedSignatureCallingConvention = nameof(UnmanagedSignatureCallingConvention);
         public const string VirtualStaticsInInterfaces = nameof(VirtualStaticsInInterfaces);
     }
 
index c237f561a266fc3b38a978c8723167e2df1f8e2d..8440dfc3cd00e44bcbe34db903b252e7724fcaa3 100644 (file)
@@ -42,4 +42,22 @@ namespace VirtualFunctionOverride
 
         }
     }
+
+    unsafe class FunctionPointerOverloadBase
+    {
+        // Do not reorder these, the test assumes this order
+        public virtual Type Method(delegate* unmanaged[Cdecl]<void> p) => typeof(delegate* unmanaged[Cdecl]<void>);
+        public virtual Type Method(delegate* unmanaged[Stdcall]<void> p) => typeof(delegate* unmanaged[Stdcall]<void>);
+        public virtual Type Method(delegate* unmanaged[Stdcall, SuppressGCTransition]<void> p) => typeof(delegate* unmanaged[Stdcall, SuppressGCTransition]<void>);
+        public virtual Type Method(delegate*<void> p) => typeof(delegate*<void>);
+    }
+
+    unsafe class FunctionPointerOverloadDerived : FunctionPointerOverloadBase
+    {
+        // Do not reorder these, the test assumes this order
+        public override Type Method(delegate* unmanaged[Cdecl]<void> p) => null;
+        public override Type Method(delegate* unmanaged[Stdcall]<void> p) => null;
+        public override Type Method(delegate* unmanaged[Stdcall, SuppressGCTransition]<void> p) => null;
+        public override Type Method(delegate*<void> p) => null;
+    }
 }
index e228993296ff59d534d8bb118688483aab9be5b1..d2f9d9576c63141c2d255ad40c088a54d95ce1bb 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Collections.Generic;
 using System.Linq;
 using Internal.TypeSystem;
 
@@ -268,5 +269,25 @@ namespace TypeSystemTests
             Assert.Contains("!0,!1", md1.Name);
             Assert.Contains("!1,!0", md2.Name);
         }
+
+        [Fact]
+        public void TestFunctionPointerOverloads()
+        {
+            MetadataType baseClass = _testModule.GetType("VirtualFunctionOverride", "FunctionPointerOverloadBase");
+            MetadataType derivedClass = _testModule.GetType("VirtualFunctionOverride", "FunctionPointerOverloadDerived");
+
+            var resolvedMethods = new List<MethodDesc>();
+            foreach (MethodDesc baseMethod in baseClass.GetVirtualMethods())
+                resolvedMethods.Add(derivedClass.FindVirtualFunctionTargetMethodOnObjectType(baseMethod));
+
+            var expectedMethods = new List<MethodDesc>();
+            foreach (MethodDesc derivedMethod in derivedClass.GetVirtualMethods())
+                expectedMethods.Add(derivedMethod);
+
+            Assert.Equal(expectedMethods, resolvedMethods);
+
+            Assert.Equal(expectedMethods[0].Signature[0], expectedMethods[1].Signature[0]);
+            Assert.NotEqual(expectedMethods[0].Signature[0], expectedMethods[3].Signature[0]);
+        }
     }
 }