Clean up some code after the interface inheritance work (#86347)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Tue, 16 May 2023 21:40:58 +0000 (14:40 -0700)
committerGitHub <noreply@github.com>
Tue, 16 May 2023 21:40:58 +0000 (14:40 -0700)
20 files changed:
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs [deleted file]
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs [deleted file]
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ComInterfaceDispatchMarshallerFactory.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnreachableException.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGeneratorHelpers.cs

diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs
deleted file mode 100644 (file)
index 8443100..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System.Linq;
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-
-namespace Microsoft.Interop
-{
-    /// <summary>
-    /// Provides the info necessary for copying an attribute from user code to generated code.
-    /// </summary>
-    internal sealed record AttributeInfo(ManagedTypeInfo Type, SequenceEqualImmutableArray<string> Arguments)
-    {
-        internal static AttributeInfo From(AttributeData attribute)
-        {
-            var type = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(attribute.AttributeClass);
-            var args = attribute.ConstructorArguments.Select(ca => ca.ToCSharpString());
-            return new(type, args.ToSequenceEqualImmutableArray());
-        }
-    }
-}
index 82f1e2a..dd6797f 100644 (file)
@@ -7,26 +7,23 @@ using Microsoft.CodeAnalysis;
 
 namespace Microsoft.Interop
 {
-    public sealed partial class ComInterfaceGenerator
+    /// <summary>
+    /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces).
+    /// </summary>
+    internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray<ComMethodContext> Methods)
     {
+        // Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext
+        // Have a step that runs CalculateMethodStub on each of them.
+        // Call GroupMethodsByInterfaceForGeneration
+
         /// <summary>
-        /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces).
+        /// COM methods that are declared on the attributed interface declaration.
         /// </summary>
-        private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray<ComMethodContext> Methods)
-        {
-            // Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext
-            // Have a step that runs CalculateMethodStub on each of them.
-            // Call GroupMethodsByInterfaceForGeneration
-
-            /// <summary>
-            /// COM methods that are declared on the attributed interface declaration.
-            /// </summary>
-            public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);
+        public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);
 
-            /// <summary>
-            /// COM methods that are declared on an interface the interface inherits from.
-            /// </summary>
-            public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
-        }
+        /// <summary>
+        /// COM methods that are declared on an interface the interface inherits from.
+        /// </summary>
+        public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
     }
 }
index 5b570c1..0eee960 100644 (file)
@@ -7,51 +7,48 @@ using System.Threading;
 
 namespace Microsoft.Interop
 {
-    public sealed partial class ComInterfaceGenerator
+    internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base)
     {
-        private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base)
+        /// <summary>
+        /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
+        /// </summary>
+        public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
         {
-            /// <summary>
-            /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
-            /// </summary>
-            public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
+            Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
+            var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
+            foreach (var iface in data)
             {
-                Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
-                var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
-                foreach (var iface in data)
+                symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
+            }
+            Dictionary<string, ComInterfaceContext> symbolToContextMap = new();
+
+            foreach (var iface in data)
+            {
+                accumulator.Add(AddContext(iface));
+            }
+            return accumulator.MoveToImmutable();
+
+            ComInterfaceContext AddContext(ComInterfaceInfo iface)
+            {
+                if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
                 {
-                    symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
+                    return cachedValue;
                 }
-                Dictionary<string, ComInterfaceContext> symbolToContextMap = new();
 
-                foreach (var iface in data)
+                if (iface.BaseInterfaceKey is null)
                 {
-                    accumulator.Add(AddContext(iface));
+                    var baselessCtx = new ComInterfaceContext(iface, null);
+                    symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
+                    return baselessCtx;
                 }
-                return accumulator.MoveToImmutable();
 
-                ComInterfaceContext AddContext(ComInterfaceInfo iface)
+                if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
                 {
-                    if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
-                    {
-                        return cachedValue;
-                    }
-
-                    if (iface.BaseInterfaceKey is null)
-                    {
-                        var baselessCtx = new ComInterfaceContext(iface, null);
-                        symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
-                        return baselessCtx;
-                    }
-
-                    if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
-                    {
-                        baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
-                    }
-                    var ctx = new ComInterfaceContext(iface, baseContext);
-                    symbolToContextMap[iface.ThisInterfaceKey] = ctx;
-                    return ctx;
+                    baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
                 }
+                var ctx = new ComInterfaceContext(iface, baseContext);
+                symbolToContextMap[iface.ThisInterfaceKey] = ctx;
+                return ctx;
             }
         }
     }
index 85f01c4..f2d6c30 100644 (file)
@@ -2,9 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Collections.Generic;
 using System.Collections.Immutable;
-using System.Collections.Specialized;
 using System.IO;
 using System.Linq;
 using System.Reflection;
@@ -19,14 +17,6 @@ namespace Microsoft.Interop
     [Generator]
     public sealed partial class ComInterfaceGenerator : IIncrementalGenerator
     {
-        private sealed record class GeneratedStubCodeContext(
-            ManagedTypeInfo OriginalDefiningType,
-            ContainingSyntaxContext ContainingSyntaxContext,
-            SyntaxEquivalentNode<MethodDeclarationSyntax> Stub,
-            SequenceEqualImmutableArray<Diagnostic> Diagnostics) : GeneratedMethodContextBase(OriginalDefiningType, Diagnostics);
-
-        private sealed record SkippedStubContext(ManagedTypeInfo OriginalDefiningType) : GeneratedMethodContextBase(OriginalDefiningType, new(ImmutableArray<Diagnostic>.Empty));
-
         public static class StepNames
         {
             public const string CalculateStubInformation = nameof(CalculateStubInformation);
@@ -103,11 +93,9 @@ namespace Microsoft.Interop
                 {
                     var ((data, symbolMap), env) = param;
                     return new ComMethodContext(
-                        data.Method.OriginalDeclaringInterface,
-                        data.TypeKeyOwner,
-                        data.Method.MethodInfo,
-                        data.Method.Index,
-                        CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct));
+                        data.Method,
+                        data.OwningInterface,
+                        CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info.Type, ct));
                 }).WithTrackingName(StepNames.CalculateStubInformation);
 
             var interfaceAndMethodsContexts = comMethodContexts
@@ -117,7 +105,7 @@ namespace Microsoft.Interop
 
             // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation.
             context.RegisterDiagnostics(interfaceAndMethodsContexts
-                .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetManagedToUnmanagedStub().Diagnostics)));
+                .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics)));
             var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
                 .Select(GenerateImplementationInterface)
                 .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation)
@@ -126,7 +114,7 @@ namespace Microsoft.Interop
 
             // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation.
             context.RegisterDiagnostics(interfaceAndMethodsContexts
-                .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetNativeToManagedStub().Diagnostics)));
+                .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics)));
             var nativeToManagedVtableMethods = interfaceAndMethodsContexts
                 .Select(GenerateImplementationVTableMethods)
                 .WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods)
@@ -145,11 +133,11 @@ namespace Microsoft.Interop
                 .Select((data, ct) =>
                 {
                     var context = data.Interface.Info;
-                    var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow());
+                    var methods = data.ShadowingMethods.Select(m => m.Shadow);
                     var typeDecl = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier)
                         .WithModifiers(context.ContainingSyntax.Modifiers)
                         .WithTypeParameterList(context.ContainingSyntax.TypeParameters)
-                        .WithMembers(List(methods));
+                        .WithMembers(List<MemberDeclarationSyntax>(methods));
                     return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(typeDecl);
                 })
                 .WithTrackingName(StepNames.GenerateShadowingMethods)
@@ -211,33 +199,6 @@ namespace Microsoft.Interop
             });
         }
 
-        private static string GenerateMarkerInterfaceSource(ComInterfaceInfo iface) => $$"""
-            file unsafe class InterfaceInformation : global::System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType
-            {
-                public static global::System.Guid Iid => new(new global::System.ReadOnlySpan<byte>(new byte[] { {{string.Join(",", iface.InterfaceId.ToByteArray())}} }));
-
-                private static void** m_vtable;
-
-                public static void** ManagedVirtualMethodTable
-                {
-                    get
-                    {
-                        if (m_vtable == null)
-                        {
-                            nint* vtable = (nint*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({{iface.Type.FullTypeName}}), sizeof(nint) * 3);
-                            global::System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out vtable[0], out vtable[1], out vtable[2]);
-                            m_vtable = (void**)vtable;
-                        }
-                        return m_vtable;
-                    }
-                }
-            }
-
-            [global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementation]
-            file interface InterfaceImplementation : {{iface.Type.FullTypeName}}
-            {}
-            """;
-
         private static readonly AttributeSyntax s_iUnknownDerivedAttributeTemplate =
             Attribute(
                 GenericName(TypeNames.IUnknownDerivedAttribute)
@@ -252,8 +213,7 @@ namespace Microsoft.Interop
                     .WithTypeParameterList(context.ContainingSyntax.TypeParameters)
                     .AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate))));
 
-        // Todo: extract info needed from the IMethodSymbol into MethodInfo and only pass a MethodInfo here
-        private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo typeKeyOwner, CancellationToken ct)
+        private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo owningInterface, CancellationToken ct)
         {
             ct.ThrowIfCancellationRequested();
             INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
@@ -366,7 +326,7 @@ namespace Microsoft.Interop
                 new ComExceptionMarshalling(),
                 ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
                 ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
-                typeKeyOwner,
+                owningInterface,
                 declaringType,
                 generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(),
                 ComInterfaceDispatchMarshallingInfo.Instance);
@@ -413,31 +373,32 @@ namespace Microsoft.Interop
         private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
         {
             var definingType = interfaceGroup.Interface.Info.Type;
-            var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.GetManagedToUnmanagedStub()))
+            var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
                 .Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
                 .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
                 .WithExplicitInterfaceSpecifier(
                     ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
-            var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.GenerateUnreachableExceptionStub());
+            var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
             return ImplementationInterfaceTemplate
                 .AddBaseListTypes(SimpleBaseType(definingType.Syntax))
                 .WithMembers(
                     List<MemberDeclarationSyntax>(
                         interfaceGroup.DeclaredMethods
-                        .Select(m => m.GetManagedToUnmanagedStub())
+                        .Select(m => m.ManagedToUnmanagedStub)
                         .OfType<GeneratedStubCodeContext>()
                         .Select(ctx => ctx.Stub.Node)
                         .Concat(shadowImplementations)
                         .Concat(inheritedStubs)))
                 .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute)))));
         }
+
         private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _)
         {
             return ImplementationInterfaceTemplate
                 .WithMembers(
                     List<MemberDeclarationSyntax>(
                         comInterfaceAndMethods.DeclaredMethods
-                            .Select(m => m.GetNativeToManagedStub())
+                            .Select(m => m.UnmanagedToManagedStub)
                             .OfType<GeneratedStubCodeContext>()
                             .Select(context => context.Stub.Node)));
         }
@@ -448,6 +409,7 @@ namespace Microsoft.Interop
 
         private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName)
             .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword));
+
         private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _)
         {
             const string vtableLocalName = "vtable";
index e192fcd..8853d90 100644 (file)
@@ -2,9 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Collections.Generic;
 using System.Linq;
-using System.Text;
 using Microsoft.CodeAnalysis;
 
 namespace Microsoft.Interop
index cd747bc..dd81de2 100644 (file)
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
 using System.Linq;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
-using Roslyn.Utilities;
 
 namespace Microsoft.Interop
 {
-    public sealed partial class ComInterfaceGenerator
+    /// <summary>
+    /// Information about a Com interface, but not its methods.
+    /// </summary>
+    internal sealed record ComInterfaceInfo(
+        ManagedTypeInfo Type,
+        string ThisInterfaceKey, // For associating interfaces to its base
+        string? BaseInterfaceKey, // For associating interfaces to its base
+        InterfaceDeclarationSyntax Declaration,
+        ContainingSyntaxContext TypeDefinitionContext,
+        ContainingSyntax ContainingSyntax,
+        Guid InterfaceId)
     {
-        /// <summary>
-        /// Information about a Com interface, but not its methods.
-        /// </summary>
-        private sealed record ComInterfaceInfo(
-            ManagedTypeInfo Type,
-            string ThisInterfaceKey, // For associating interfaces to its base
-            string? BaseInterfaceKey, // For associating interfaces to its base
-            InterfaceDeclarationSyntax Declaration,
-            ContainingSyntaxContext TypeDefinitionContext,
-            ContainingSyntax ContainingSyntax,
-            Guid InterfaceId)
+        public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax)
         {
-            public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax)
+            // Verify the method has no generic types or defined implementation
+            // and is not marked static or sealed
+            if (syntax.TypeParameterList is not null)
+            {
+                return (null, Diagnostic.Create(
+                    GeneratorDiagnostics.InvalidAttributedMethodSignature,
+                    syntax.Identifier.GetLocation(),
+                    symbol.Name));
+            }
+
+            // Verify that the types the method is declared in are marked partial.
+            for (SyntaxNode? parentNode = syntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent)
             {
-                // Verify the method has no generic types or defined implementation
-                // and is not marked static or sealed
-                if (syntax.TypeParameterList is not null)
+                if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
                 {
                     return (null, Diagnostic.Create(
-                        GeneratorDiagnostics.InvalidAttributedMethodSignature,
+                        GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers,
                         syntax.Identifier.GetLocation(),
-                        symbol.Name));
-                }
-
-                // Verify that the types the method is declared in are marked partial.
-                for (SyntaxNode? parentNode = syntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent)
-                {
-                    if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
-                    {
-                        return (null, Diagnostic.Create(
-                            GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers,
-                            syntax.Identifier.GetLocation(),
-                            symbol.Name,
-                            typeDecl.Identifier));
-                    }
+                        symbol.Name,
+                        typeDecl.Identifier));
                 }
+            }
 
-                if (!TryGetGuid(symbol, syntax, out Guid? guid, out Diagnostic? guidDiagnostic))
-                    return (null, guidDiagnostic);
+            if (!TryGetGuid(symbol, syntax, out Guid? guid, out Diagnostic? guidDiagnostic))
+                return (null, guidDiagnostic);
 
-                if (!TryGetBaseComInterface(symbol, syntax, out INamedTypeSymbol? baseSymbol, out Diagnostic? baseDiagnostic))
-                    return (null, baseDiagnostic);
+            if (!TryGetBaseComInterface(symbol, syntax, out INamedTypeSymbol? baseSymbol, out Diagnostic? baseDiagnostic))
+                return (null, baseDiagnostic);
 
-                return (new ComInterfaceInfo(
-                    ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol),
-                    symbol.ToDisplayString(),
-                    baseSymbol?.ToDisplayString(),
-                    syntax,
-                    new ContainingSyntaxContext(syntax),
-                    new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
-                    guid ?? Guid.Empty), null);
-            }
+            return (new ComInterfaceInfo(
+                ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol),
+                symbol.ToDisplayString(),
+                baseSymbol?.ToDisplayString(),
+                syntax,
+                new ContainingSyntaxContext(syntax),
+                new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
+                guid ?? Guid.Empty), null);
+        }
 
-            /// <summary>
-            /// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets <paramref name="diagnostic"/>.
-            /// </summary>
-            private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic)
+        /// <summary>
+        /// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets <paramref name="diagnostic"/>.
+        /// </summary>
+        private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic)
+        {
+            baseComIface = null;
+            foreach (var implemented in comIface.Interfaces)
             {
-                baseComIface = null;
-                foreach (var implemented in comIface.Interfaces)
+                foreach (var attr in implemented.GetAttributes())
                 {
-                    foreach (var attr in implemented.GetAttributes())
+                    if (attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
                     {
-                        if (attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
+                        if (baseComIface is not null)
                         {
-                            if (baseComIface is not null)
-                            {
-                                diagnostic = Diagnostic.Create(
-                                    GeneratorDiagnostics.MultipleComInterfaceBaseTypes,
-                                    syntax.Identifier.GetLocation(),
-                                    comIface.ToDisplayString());
-                                return false;
-                            }
-                            baseComIface = implemented;
+                            diagnostic = Diagnostic.Create(
+                                GeneratorDiagnostics.MultipleComInterfaceBaseTypes,
+                                syntax.Identifier.GetLocation(),
+                                comIface.ToDisplayString());
+                            return false;
                         }
+                        baseComIface = implemented;
                     }
                 }
-                diagnostic = null;
-                return true;
             }
+            diagnostic = null;
+            return true;
+        }
 
-            /// <summary>
-            /// Returns true and sets <paramref name="guid"/> if the guid is present. Returns false and sets diagnostic if the guid is not present or is invalid.
-            /// </summary>
-            private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out Guid? guid, [NotNullWhen(false)] out Diagnostic? diagnostic)
+        /// <summary>
+        /// Returns true and sets <paramref name="guid"/> if the guid is present. Returns false and sets diagnostic if the guid is not present or is invalid.
+        /// </summary>
+        private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out Guid? guid, [NotNullWhen(false)] out Diagnostic? diagnostic)
+        {
+            guid = null;
+            AttributeData? guidAttr = null;
+            AttributeData? _ = null; // Interface Attribute Type. We'll always assume IUnkown for now.
+            foreach (var attr in interfaceSymbol.GetAttributes())
             {
-                guid = null;
-                AttributeData? guidAttr = null;
-                AttributeData? _ = null; // Interface Attribute Type. We'll always assume IUnkown for now.
-                foreach (var attr in interfaceSymbol.GetAttributes())
-                {
-                    var attrDisplayString = attr.AttributeClass?.ToDisplayString();
-                    if (attrDisplayString is TypeNames.System_Runtime_InteropServices_GuidAttribute)
-                        guidAttr = attr;
-                    else if (attrDisplayString is TypeNames.InterfaceTypeAttribute)
-                        _ = attr;
-                }
-
-                if (guidAttr is not null
-                    && guidAttr.ConstructorArguments.Length == 1
-                    && guidAttr.ConstructorArguments[0].Value is string guidStr
-                    && Guid.TryParse(guidStr, out var result))
-                {
-                    guid = result;
-                }
-
-                // Assume interfaceType is IUnknown for now
-                if (guid is null)
-                {
-                    diagnostic = Diagnostic.Create(
-                        GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute,
-                        syntax.Identifier.GetLocation(),
-                        interfaceSymbol.ToDisplayString());
-                    return false;
-                }
-                diagnostic = null;
-                return true;
+                var attrDisplayString = attr.AttributeClass?.ToDisplayString();
+                if (attrDisplayString is TypeNames.System_Runtime_InteropServices_GuidAttribute)
+                    guidAttr = attr;
+                else if (attrDisplayString is TypeNames.InterfaceTypeAttribute)
+                    _ = attr;
             }
 
-            public override int GetHashCode()
+            if (guidAttr is not null
+                && guidAttr.ConstructorArguments.Length == 1
+                && guidAttr.ConstructorArguments[0].Value is string guidStr
+                && Guid.TryParse(guidStr, out var result))
             {
-                // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode
-                return HashCode.Combine(Type, TypeDefinitionContext, InterfaceId);
+                guid = result;
             }
 
-            public bool Equals(ComInterfaceInfo other)
+            // Assume interfaceType is IUnknown for now
+            if (guid is null)
             {
-                // ContainingSyntax and ContainingSyntaxContext are not used in the hash code
-                return Type == other.Type
-                    && TypeDefinitionContext == other.TypeDefinitionContext
-                    && InterfaceId == other.InterfaceId;
+                diagnostic = Diagnostic.Create(
+                    GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute,
+                    syntax.Identifier.GetLocation(),
+                    interfaceSymbol.ToDisplayString());
+                return false;
             }
+            diagnostic = null;
+            return true;
+        }
+
+        public override int GetHashCode()
+        {
+            // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode
+            return HashCode.Combine(Type, TypeDefinitionContext, InterfaceId);
+        }
+
+        public bool Equals(ComInterfaceInfo other)
+        {
+            // ContainingSyntax and ContainingSyntaxContext are not used in the hash code
+            return Type == other.Type
+                && TypeDefinitionContext == other.TypeDefinitionContext
+                && InterfaceId == other.InterfaceId;
         }
     }
 }
index 31708cb..daf7c14 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System;
 using System.Collections.Generic;
 using System.Collections.Immutable;
 using System.Linq;
@@ -12,155 +13,196 @@ using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
 
 namespace Microsoft.Interop
 {
-    public sealed partial class ComInterfaceGenerator
+
+    /// <summary>
+    /// Represents a method, its declaring interface, and its index in the interface's vtable.
+    /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator
+    /// </summary>
+    internal sealed class ComMethodContext : IEquatable<ComMethodContext>
     {
         /// <summary>
-        /// Represents a method, its declaring interface, and its index in the interface's vtable.
-        /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator
+        /// A partially constructed <see cref="ComMethodContext"/> that does not have a <see cref="IncrementalMethodStubGenerationContext"/> generated for it yet.
+        /// <see cref="Builder"/> can be constructed without a reference to an ISymbol, whereas the <see cref="IncrementalMethodStubGenerationContext"/> requires an ISymbol
         /// </summary>
         /// <param name="OriginalDeclaringInterface">
         /// The interface that originally declared the method in user code
         /// </param>
-        /// <param name="OwningInterface">
-        /// The interface that this methods is being generated for (may be different that OriginalDeclaringInterface if it is an inherited method)
-        /// </param>
         /// <param name="MethodInfo">The basic information about the method.</param>
-        /// <param name="Index">The index on the interface vtable that points to this method</param>
-        /// <param name="GenerationContext"></param>
-        private sealed record ComMethodContext(
+        /// <param name="Index">The vtable index for the method.</param>
+        public sealed record Builder(ComInterfaceContext OriginalDeclaringInterface, ComMethodInfo MethodInfo, int Index);
+
+        /// <summary>
+        /// The fully-constructed immutable state for a <see cref="ComMethodContext"/>.
+        /// </summary>
+        private record struct State(
             ComInterfaceContext OriginalDeclaringInterface,
             ComInterfaceContext OwningInterface,
             ComMethodInfo MethodInfo,
-            int Index,
-            IncrementalMethodStubGenerationContext GenerationContext)
+            IncrementalMethodStubGenerationContext GenerationContext);
+
+        private readonly State _state;
+
+        /// <summary>
+        /// Construct a full method context from the <paramref name="builder"/>, context, and additional information.
+        /// </summary>
+        /// <param name="builder">The partially constructed context</param>
+        /// <param name="owningInterface">The final owning interface of this method context</param>
+        /// <param name="generationContext">The generation context for this method</param>
+        public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, IncrementalMethodStubGenerationContext generationContext)
         {
-            /// <summary>
-            /// A partially constructed <see cref="ComMethodContext"/> that does not have a <see cref="IncrementalMethodStubGenerationContext"/> generated for it yet.
-            /// <see cref="Builder"/> can be constructed without a reference to an ISymbol, whereas the <see cref="IncrementalMethodStubGenerationContext"/> requires an ISymbol
-            /// </summary>
-            public sealed record Builder(ComInterfaceContext OriginalDeclaringInterface, ComMethodInfo MethodInfo, int Index);
+            _state = new State(builder.OriginalDeclaringInterface, owningInterface, builder.MethodInfo, generationContext);
+        }
 
-            public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;
+        public override bool Equals(object obj) => obj is ComMethodContext other && Equals(other);
 
-            public GeneratedMethodContextBase GetManagedToUnmanagedStub()
-            {
-                if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional))
-                {
-                    return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
-                }
-                var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext);
-                return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
-            }
+        public override int GetHashCode() => _state.GetHashCode();
+
+        public bool Equals(ComMethodContext other) => _state.Equals(other);
+
+        public ComInterfaceContext OriginalDeclaringInterface => _state.OriginalDeclaringInterface;
+
+        public ComInterfaceContext OwningInterface => _state.OwningInterface;
+
+        public ComMethodInfo MethodInfo => _state.MethodInfo;
+
+        public IncrementalMethodStubGenerationContext GenerationContext => _state.GenerationContext;
 
-            public GeneratedMethodContextBase GetNativeToManagedStub()
+        public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;
+
+        private GeneratedMethodContextBase? _managedToUnmanagedStub;
+
+        public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub();
+
+        private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
+        {
+            if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional))
             {
-                if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional))
-                {
-                    return new SkippedStubContext(GenerationContext.OriginalDefiningType);
-                }
-                var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext);
-                return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
+                return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
             }
+            var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext);
+            return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
+        }
 
-            public MethodDeclarationSyntax GenerateUnreachableExceptionStub()
+        private GeneratedMethodContextBase? _unmanagedToManagedStub;
+
+        public GeneratedMethodContextBase UnmanagedToManagedStub => _unmanagedToManagedStub ??= CreateUnmanagedToManagedStub();
+
+        private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
+        {
+            if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional))
             {
-                // DeclarationCopiedFromBaseDeclaration(<Arguments>) => throw new UnreachableException("This method should not be reached");
-                return MethodInfo.Syntax
-                    .WithModifiers(TokenList())
-                    .WithAttributeLists(List<AttributeListSyntax>())
-                    .WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier(
-                        ParseName(OriginalDeclaringInterface.Info.Type.FullTypeName)))
-                    .WithExpressionBody(ArrowExpressionClause(
-                        ThrowExpression(
-                            ObjectCreationExpression(
-                                ParseTypeName(TypeNames.UnreachableException))
-                                .WithArgumentList(ArgumentList()))));
+                return new SkippedStubContext(GenerationContext.OriginalDefiningType);
             }
+            var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext);
+            return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
+        }
+
+        private MethodDeclarationSyntax? _unreachableExceptionStub;
+
+        public MethodDeclarationSyntax UnreachableExceptionStub => _unreachableExceptionStub ??= CreateUnreachableExceptionStub();
+
+        private MethodDeclarationSyntax CreateUnreachableExceptionStub()
+        {
+            // DeclarationCopiedFromBaseDeclaration(<Arguments>) => throw new UnreachableException("This method should not be reached");
+            return MethodInfo.Syntax
+                .WithModifiers(TokenList())
+                .WithAttributeLists(List<AttributeListSyntax>())
+                .WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier(
+                    ParseName(OriginalDeclaringInterface.Info.Type.FullTypeName)))
+                .WithExpressionBody(ArrowExpressionClause(
+                    ThrowExpression(
+                        ObjectCreationExpression(
+                            ParseTypeName(TypeNames.UnreachableException))
+                            .WithArgumentList(ArgumentList()))));
+        }
+
+        private MethodDeclarationSyntax? _shadow;
+
+        public MethodDeclarationSyntax Shadow => _shadow ??= GenerateShadow();
+
+        private MethodDeclarationSyntax GenerateShadow()
+        {
+            // DeclarationCopiedFromBaseDeclaration(<Arguments>)
+            // {
+            //    return ((<baseInterfaceType>)this).<MethodName>(<Arguments>);
+            // }
+            var forwarder = new Forwarder();
+            return MethodDeclaration(GenerationContext.SignatureContext.StubReturnType, MethodInfo.MethodName)
+                .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword)))
+                .WithParameterList(ParameterList(SeparatedList(GenerationContext.SignatureContext.StubParameters)))
+                .WithExpressionBody(
+                    ArrowExpressionClause(
+                        InvocationExpression(
+                            MemberAccessExpression(
+                                SyntaxKind.SimpleMemberAccessExpression,
+                                ParenthesizedExpression(
+                                    CastExpression(OriginalDeclaringInterface.Info.Type.Syntax, IdentifierName("this"))),
+                                IdentifierName(MethodInfo.MethodName)),
+                            ArgumentList(
+                                SeparatedList(GenerationContext.SignatureContext.ManagedParameters.Select(p => forwarder.AsArgument(p, new ManagedStubCodeContext())))))))
+                .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+        }
 
-            public MethodDeclarationSyntax GenerateShadow()
+        /// <summary>
+        /// Returns a flat list of <see cref="Builder"/> and its owning interface that represents all declared methods and inherited methods.
+        /// Guarantees the output will be sorted by order of interface input order, then by vtable order.
+        /// </summary>
+        public static List<(ComInterfaceContext OwningInterface, Builder Method)> CalculateAllMethods(IEnumerable<(ComInterfaceContext, SequenceEqualImmutableArray<ComMethodInfo>)> ifaceAndDeclaredMethods, CancellationToken _)
+        {
+            // Optimization : This step technically only needs a single interface inheritance hierarchy.
+            // We can calculate all inheritance chains in a previous step and only pass a single inheritance chain to this method.
+            // This way, when a single method changes, we would only need to recalculate this for the inheritance chain in which that method exists.
+
+            var ifaceToDeclaredMethodsMap = ifaceAndDeclaredMethods.ToDictionary(static pair => pair.Item1, static pair => pair.Item2);
+            var allMethodsCache = new Dictionary<ComInterfaceContext, ImmutableArray<Builder>>();
+            var accumulator = new List<(ComInterfaceContext OwningInterface, Builder Method)>();
+            foreach (var kvp in ifaceAndDeclaredMethods)
             {
-                // DeclarationCopiedFromBaseDeclaration(<Arguments>)
-                // {
-                //    return ((<baseInterfaceType>)this).<MethodName>(<Arguments>);
-                // }
-                var forwarder = new Forwarder();
-                return MethodDeclaration(GenerationContext.SignatureContext.StubReturnType, MethodInfo.MethodName)
-                    .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword)))
-                    .WithParameterList(ParameterList(SeparatedList(GenerationContext.SignatureContext.StubParameters)))
-                    .WithExpressionBody(
-                        ArrowExpressionClause(
-                            InvocationExpression(
-                                MemberAccessExpression(
-                                    SyntaxKind.SimpleMemberAccessExpression,
-                                    ParenthesizedExpression(
-                                        CastExpression(OriginalDeclaringInterface.Info.Type.Syntax, IdentifierName("this"))),
-                                    IdentifierName(MethodInfo.MethodName)),
-                                ArgumentList(
-                                    SeparatedList(GenerationContext.SignatureContext.ManagedParameters.Select(p => forwarder.AsArgument(p, new ManagedStubCodeContext())))))))
-                    .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+                var methods = AddMethods(kvp.Item1, kvp.Item2);
+                foreach (var method in methods)
+                {
+                    accumulator.Add((kvp.Item1, method));
+                }
             }
+            return accumulator;
 
             /// <summary>
-            /// Returns a flat list of <see cref="ComMethodContext.Builder"/> and it's type key owner that represents all declared methods, and inherited methods.
-            /// Guarantees the output will be sorted by order of interface input order, then by vtable order.
+            /// Adds methods to a cache and returns inherited and declared methods for the interface in vtable order
             /// </summary>
-            public static List<(ComInterfaceContext TypeKeyOwner, Builder Method)> CalculateAllMethods(IEnumerable<(ComInterfaceContext, SequenceEqualImmutableArray<ComMethodInfo>)> ifaceAndDeclaredMethods, CancellationToken _)
+            ImmutableArray<Builder> AddMethods(ComInterfaceContext iface, IEnumerable<ComMethodInfo> declaredMethods)
             {
-                // Optimization : This step technically only needs a single interface inheritance hierarchy.
-                // We can calculate all inheritance chains in a previous step and only pass a single inheritance chain to this method.
-                // This way, when a single method changes, we would only need to recalculate this for the inheritance chain in which that method exists.
-
-                var ifaceToDeclaredMethodsMap = ifaceAndDeclaredMethods.ToDictionary(static pair => pair.Item1, static pair => pair.Item2);
-                var allMethodsCache = new Dictionary<ComInterfaceContext, ImmutableArray<Builder>>();
-                var accumulator = new List<(ComInterfaceContext TypeKeyOwner, Builder Method)>();
-                foreach (var kvp in ifaceAndDeclaredMethods)
+                if (allMethodsCache.TryGetValue(iface, out var cachedValue))
                 {
-                    var methods = AddMethods(kvp.Item1, kvp.Item2);
-                    foreach (var method in methods)
-                    {
-                        accumulator.Add((kvp.Item1, method));
-                    }
+                    return cachedValue;
                 }
-                return accumulator;
 
-                /// <summary>
-                /// Adds methods to a cache and returns inherited and declared methods for the interface in vtable order
-                /// </summary>
-                ImmutableArray<Builder> AddMethods(ComInterfaceContext iface, IEnumerable<ComMethodInfo> declaredMethods)
+                int startingIndex = 3;
+                List<Builder> methods = new();
+                // If we have a base interface, we should add the inherited methods to our list in vtable order
+                if (iface.Base is not null)
                 {
-                    if (allMethodsCache.TryGetValue(iface, out var cachedValue))
-                    {
-                        return cachedValue;
-                    }
-
-                    int startingIndex = 3;
-                    List<Builder> methods = new();
-                    // If we have a base interface, we should add the inherited methods to our list in vtable order
-                    if (iface.Base is not null)
+                    var baseComIface = iface.Base;
+                    ImmutableArray<Builder> baseMethods;
+                    if (!allMethodsCache.TryGetValue(baseComIface, out var pair))
                     {
-                        var baseComIface = iface.Base;
-                        ImmutableArray<Builder> baseMethods;
-                        if (!allMethodsCache.TryGetValue(baseComIface, out var pair))
-                        {
-                            baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]);
-                        }
-                        else
-                        {
-                            baseMethods = pair;
-                        }
-                        methods.AddRange(baseMethods);
-                        startingIndex += baseMethods.Length;
+                        baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]);
                     }
-                    // Then we append the declared methods in vtable order
-                    foreach (var method in declaredMethods)
+                    else
                     {
-                        methods.Add(new Builder(iface, method, startingIndex++));
+                        baseMethods = pair;
                     }
-                    // Cache so we don't recalculate if many interfaces inherit from the same one
-                    var imm = methods.ToImmutableArray();
-                    allMethodsCache[iface] = imm;
-                    return imm;
+                    methods.AddRange(baseMethods);
+                    startingIndex += baseMethods.Length;
+                }
+                // Then we append the declared methods in vtable order
+                foreach (var method in declaredMethods)
+                {
+                    methods.Add(new Builder(iface, method, startingIndex++));
                 }
+                // Cache so we don't recalculate if many interfaces inherit from the same one
+                var imm = methods.ToImmutableArray();
+                allMethodsCache[iface] = imm;
+                return imm;
             }
         }
     }
index 38bf19a..a259573 100644 (file)
@@ -1,8 +1,6 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
-using System.Collections.Generic;
 using System.Collections.Immutable;
 using System.Diagnostics;
 using System.Linq;
@@ -13,108 +11,105 @@ using Microsoft.CodeAnalysis.CSharp.Syntax;
 
 namespace Microsoft.Interop
 {
-    public sealed partial class ComInterfaceGenerator
+    /// <summary>
+    /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax.
+    /// </summary>
+    internal sealed record ComMethodInfo(
+        MethodDeclarationSyntax Syntax,
+        string MethodName)
     {
         /// <summary>
-        /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax.
+        /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
         /// </summary>
-        private sealed record ComMethodInfo(
-            MethodDeclarationSyntax Syntax,
-            string MethodName)
+        public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, IMethodSymbol Symbol, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken ct)
         {
-            /// <summary>
-            /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
-            /// </summary>
-            public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, IMethodSymbol Symbol, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken ct)
+            var methods = ImmutableArray.CreateBuilder<(ComMethodInfo, IMethodSymbol, Diagnostic?)>();
+            foreach (var member in data.ifaceSymbol.GetMembers())
             {
-                var methods = ImmutableArray.CreateBuilder<(ComMethodInfo, IMethodSymbol, Diagnostic?)>();
-                foreach (var member in data.ifaceSymbol.GetMembers())
+                if (IsComMethodCandidate(member))
                 {
-                    if (IsComMethodCandidate(member))
-                    {
-                        methods.Add(CalculateMethodInfo(data.ifaceContext, (IMethodSymbol)member, ct));
-                    }
+                    methods.Add(CalculateMethodInfo(data.ifaceContext, (IMethodSymbol)member, ct));
                 }
-                return methods.ToImmutable().ToSequenceEqual();
             }
+            return methods.ToImmutable().ToSequenceEqual();
+        }
 
-            private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method)
+        private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method)
+        {
+            // Verify the method has no generic types or defined implementation
+            // and is not marked static or sealed
+            if (comMethodDeclaringSyntax.TypeParameterList is not null
+                || comMethodDeclaringSyntax.Body is not null
+                || comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.SealedKeyword))
             {
-                // Verify the method has no generic types or defined implementation
-                // and is not marked static or sealed
-                if (comMethodDeclaringSyntax.TypeParameterList is not null
-                    || comMethodDeclaringSyntax.Body is not null
-                    || comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.SealedKeyword))
-                {
-                    return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, comMethodDeclaringSyntax.Identifier.GetLocation(), method.Name);
-                }
-
-                // Verify the method does not have a ref return
-                if (method.ReturnsByRef || method.ReturnsByRefReadonly)
-                {
-                    return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, comMethodDeclaringSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString());
-                }
-
-                return null;
+                return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, comMethodDeclaringSyntax.Identifier.GetLocation(), method.Name);
             }
 
-            private static bool IsComMethodCandidate(ISymbol member)
+            // Verify the method does not have a ref return
+            if (method.ReturnsByRef || method.ReturnsByRefReadonly)
             {
-                return member.Kind == SymbolKind.Method && !member.IsStatic;
+                return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, comMethodDeclaringSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString());
             }
 
-            private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo(ComInterfaceInfo ifaceContext, IMethodSymbol method, CancellationToken ct)
-            {
-                ct.ThrowIfCancellationRequested();
-                Debug.Assert(IsComMethodCandidate(method));
+            return null;
+        }
 
-                // We only support methods that are defined in the same partial interface definition as the
-                // [GeneratedComInterface] attribute.
-                // This restriction not only makes finding the syntax for a given method cheaper,
-                // but it also enables us to ensure that we can determine vtable method order easily.
-                Location interfaceLocation = ifaceContext.Declaration.GetLocation();
-                Location? methodLocationInAttributedInterfaceDeclaration = null;
-                foreach (var methodLocation in method.Locations)
-                {
-                    if (methodLocation.SourceTree == interfaceLocation.SourceTree
-                        && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan))
-                    {
-                        methodLocationInAttributedInterfaceDeclaration = methodLocation;
-                        break;
-                    }
-                }
-                // TODO: this should cause a diagnostic
-                if (methodLocationInAttributedInterfaceDeclaration is null)
-                {
-                    return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString()));
-                }
+        private static bool IsComMethodCandidate(ISymbol member)
+        {
+            return member.Kind == SymbolKind.Method && !member.IsStatic;
+        }
 
+        private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo(ComInterfaceInfo ifaceContext, IMethodSymbol method, CancellationToken ct)
+        {
+            ct.ThrowIfCancellationRequested();
+            Debug.Assert(IsComMethodCandidate(method));
 
-                // Find the matching declaration syntax
-                MethodDeclarationSyntax? comMethodDeclaringSyntax = null;
-                foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences)
-                {
-                    var declaringSyntax = declaringSyntaxReference.GetSyntax(ct);
-                    Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration));
-                    if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan))
-                    {
-                        comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax;
-                        break;
-                    }
-                }
-                if (comMethodDeclaringSyntax is null)
+            // We only support methods that are defined in the same partial interface definition as the
+            // [GeneratedComInterface] attribute.
+            // This restriction not only makes finding the syntax for a given method cheaper,
+            // but it also enables us to ensure that we can determine vtable method order easily.
+            Location interfaceLocation = ifaceContext.Declaration.GetLocation();
+            Location? methodLocationInAttributedInterfaceDeclaration = null;
+            foreach (var methodLocation in method.Locations)
+            {
+                if (methodLocation.SourceTree == interfaceLocation.SourceTree
+                    && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan))
                 {
-                    return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString()));
+                    methodLocationInAttributedInterfaceDeclaration = methodLocation;
+                    break;
                 }
+            }
+
+            if (methodLocationInAttributedInterfaceDeclaration is null)
+            {
+                return (null, method, Diagnostic.Create(GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, method.Locations.FirstOrDefault(), method.ToDisplayString()));
+            }
 
-                var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, method);
-                if (diag is not null)
+
+            // Find the matching declaration syntax
+            MethodDeclarationSyntax? comMethodDeclaringSyntax = null;
+            foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences)
+            {
+                var declaringSyntax = declaringSyntaxReference.GetSyntax(ct);
+                Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration));
+                if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan))
                 {
-                    return (null, method, diag);
+                    comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax;
+                    break;
                 }
-                var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name);
-                return (comMethodInfo, method, null);
             }
+            if (comMethodDeclaringSyntax is null)
+            {
+                return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString()));
+            }
+
+            var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, method);
+            if (diag is not null)
+            {
+                return (null, method, diag);
+            }
+            var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name);
+            return (comMethodInfo, method, null);
         }
     }
 }
diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs
new file mode 100644 (file)
index 0000000..6f0966e
--- /dev/null
@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace Microsoft.Interop
+{
+    internal sealed record GeneratedStubCodeContext(
+        ManagedTypeInfo OriginalDefiningType,
+        ContainingSyntaxContext ContainingSyntaxContext,
+        SyntaxEquivalentNode<MethodDeclarationSyntax> Stub,
+        SequenceEqualImmutableArray<Diagnostic> Diagnostics) : GeneratedMethodContextBase(OriginalDefiningType, Diagnostics);
+}
index fb0dd80..826aa4e 100644 (file)
@@ -2,9 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Collections.Generic;
 using System.Collections.Immutable;
-using System.Text;
 using Microsoft.CodeAnalysis;
 
 namespace Microsoft.Interop
diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs
deleted file mode 100644 (file)
index cb55e36..0000000
+++ /dev/null
@@ -1,124 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
-using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-
-namespace Microsoft.Interop
-{
-    internal static class InlinedTypes
-    {
-        /// <summary>
-        /// Returns the ClassDeclarationSyntax for:
-        /// <code>
-        /// public sealed unsafe class ComWrappersUnwrapper : IUnmanagedObjectUnwrapper
-        /// {
-        ///     public static object GetObjectForUnmanagedWrapper(void* ptr)
-        ///     {
-        ///         return ComWrappers.ComInterfaceDispatch.GetInstance<object>((ComWrappers.ComInterfaceDispatch*)ptr);
-        ///     }
-        /// }
-        /// </code>
-        /// </summary>
-        public static ClassDeclarationSyntax ComWrappersUnwrapper { get; } = GetComWrappersUnwrapper();
-
-        public static ClassDeclarationSyntax GetComWrappersUnwrapper()
-        {
-            return ClassDeclaration("ComWrappersUnwrapper")
-                .AddModifiers(Token(SyntaxKind.SealedKeyword),
-                              Token(SyntaxKind.UnsafeKeyword),
-                              Token(SyntaxKind.StaticKeyword),
-                              Token(SyntaxKind.FileKeyword))
-                .AddMembers(
-                    MethodDeclaration(
-                        PredefinedType(Token(SyntaxKind.ObjectKeyword)),
-                        Identifier("GetComObjectForUnmanagedWrapper"))
-                    .AddModifiers(Token(SyntaxKind.PublicKeyword),
-                                  Token(SyntaxKind.StaticKeyword))
-                    .AddParameterListParameters(
-                        Parameter(Identifier("ptr"))
-                            .WithType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))))
-                    .WithBody(body: Body()));
-
-            static BlockSyntax Body()
-            {
-                var invocation = InvocationExpression(
-                                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
-                                        MemberAccessExpression(
-                                            SyntaxKind.SimpleMemberAccessExpression,
-                                            IdentifierName("ComWrappers"),
-                                            IdentifierName("ComInterfaceDispatch")),
-                                        GenericName(
-                                            Identifier("GetInstance"),
-                                            TypeArgumentList(
-                                                SeparatedList<SyntaxNode>(
-                                                    new[] { PredefinedType(Token(SyntaxKind.ObjectKeyword)) })))))
-                                .AddArgumentListArguments(
-                                    Argument(
-                                        null,
-                                        Token(SyntaxKind.None),
-                                        CastExpression(
-                                            PointerType(
-                                                QualifiedName(
-                                                    IdentifierName("ComWrappers"),
-                                                    IdentifierName("ComInterfaceDispatch"))),
-                                            IdentifierName("ptr"))));
-
-                return Block(ReturnStatement(invocation));
-            }
-        }
-
-        /// <summary>
-        /// <code>
-        /// file static class UnmanagedObjectUnwrapper
-        /// {
-        ///     public static object GetObjectForUnmanagedWrapper<T>(void* ptr) where T : IUnmanagedObjectUnwrapper
-        ///     {
-        ///         return T.GetObjectForUnmanagedWrapper(ptr);
-        ///     }
-        /// }
-        /// </code>
-        /// </summary>
-        public static ClassDeclarationSyntax UnmanagedObjectUnwrapper { get; } = GetUnmanagedObjectUnwrapper();
-
-        private static ClassDeclarationSyntax GetUnmanagedObjectUnwrapper()
-        {
-            const string tUnwrapper = "TUnwrapper";
-            return ClassDeclaration("UnmanagedObjectUnwrapper")
-                  .AddModifiers(Token(SyntaxKind.FileKeyword),
-                                Token(SyntaxKind.StaticKeyword))
-                  .AddMembers(
-                      MethodDeclaration(
-                          PredefinedType(Token(SyntaxKind.ObjectKeyword)),
-                          Identifier("GetObjectForUnmanagedWrapper"))
-                      .AddModifiers(Token(SyntaxKind.PublicKeyword),
-                                    Token(SyntaxKind.StaticKeyword))
-                      .AddTypeParameterListParameters(
-                          TypeParameter(Identifier(tUnwrapper)))
-                      .AddParameterListParameters(
-                          Parameter(Identifier("ptr"))
-                              .WithType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))))
-                      .AddConstraintClauses(TypeParameterConstraintClause(IdentifierName(tUnwrapper))
-                           .AddConstraints(TypeConstraint(ParseTypeName(TypeNames.IUnmanagedObjectUnwrapper))))
-                      .WithBody(body: Body()));
-
-            static BlockSyntax Body()
-            {
-                var invocation = InvocationExpression(
-                                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
-                                        IdentifierName("T"),
-                                        IdentifierName("GetObjectForUnmanagedWrapper")))
-                                .AddArgumentListArguments(
-                                    Argument(
-                                        null,
-                                        Token(SyntaxKind.None),
-                                        IdentifierName("ptr")));
-
-                return Block(ReturnStatement(invocation));
-            }
-
-        }
-    }
-}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs
new file mode 100644 (file)
index 0000000..aa9ef04
--- /dev/null
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Immutable;
+using Microsoft.CodeAnalysis;
+
+namespace Microsoft.Interop
+{
+    internal sealed record SkippedStubContext(ManagedTypeInfo OriginalDefiningType) : GeneratedMethodContextBase(OriginalDefiningType, new(ImmutableArray<Diagnostic>.Empty));
+}
index 4d62b5d..7acc249 100644 (file)
@@ -4,8 +4,6 @@
 using System;
 using System.Collections.Generic;
 using System.Collections.Immutable;
-using System.Diagnostics;
-using System.Linq;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
index 35f610f..e813c8f 100644 (file)
@@ -5,7 +5,6 @@ using System;
 using System.Collections.Generic;
 using System.Collections.Immutable;
 using System.Linq;
-using System.Text;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 using Microsoft.CodeAnalysis.CSharp;
 using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
index 1d161ff..43600ca 100644 (file)
@@ -2,12 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Collections.Generic;
 using System.Collections.Immutable;
 using System.Diagnostics;
 using System.Linq;
 using System.Threading;
-using System.Xml.Linq;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
index a421cd2..204c63f 100644 (file)
@@ -2,9 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Collections.Generic;
 using System.Linq;
-using System.Text;
 using Microsoft.CodeAnalysis;
 
 namespace Microsoft.Interop