Implement byval-out marshalling for unmanaged->managed (#86666)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Wed, 24 May 2023 20:43:37 +0000 (13:43 -0700)
committerGitHub <noreply@github.com>
Wed, 24 May 2023 20:43:37 +0000 (13:43 -0700)
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs

index ee98097..990afd2 100644 (file)
@@ -248,7 +248,7 @@ namespace Microsoft.Interop
 
                 if (freeStrategy == FreeStrategy.FreeOriginal)
                 {
-                    marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
+                    marshallingStrategy = new UnmanagedToManagedOwnershipTrackingStrategy(marshallingStrategy);
                 }
 
                 if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
@@ -258,7 +258,7 @@ namespace Microsoft.Interop
 
                 if (freeStrategy == FreeStrategy.FreeOriginal)
                 {
-                    marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
+                    marshallingStrategy = new CleanupOwnedOriginalValueMarshalling(marshallingStrategy);
                 }
             }
 
@@ -327,18 +327,18 @@ namespace Microsoft.Interop
 
                 FreeStrategy freeStrategy = GetFreeStrategy(info, context);
                 IElementsMarshallingCollectionSource collectionSource = new StatefulLinearCollectionSource();
-                IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);
+                ElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);
 
                 if (freeStrategy == FreeStrategy.FreeOriginal)
                 {
-                    marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
+                    marshallingStrategy = new UnmanagedToManagedOwnershipTrackingStrategy(marshallingStrategy);
                 }
 
                 marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling, freeStrategy != FreeStrategy.NoFree);
 
                 if (freeStrategy == FreeStrategy.FreeOriginal)
                 {
-                    marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
+                    marshallingStrategy = new CleanupOwnedOriginalValueMarshalling(marshallingStrategy);
                 }
 
                 if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
@@ -355,10 +355,10 @@ namespace Microsoft.Interop
                 IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax);
                 if (freeStrategy == FreeStrategy.FreeOriginal)
                 {
-                    marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
+                    marshallingStrategy = new UnmanagedToManagedOwnershipTrackingStrategy(marshallingStrategy);
                 }
 
-                IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);
+                ElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);
 
                 marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != FreeStrategy.NoFree);
 
@@ -378,7 +378,7 @@ namespace Microsoft.Interop
 
                 if (freeStrategy == FreeStrategy.FreeOriginal)
                 {
-                    marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
+                    marshallingStrategy = new CleanupOwnedOriginalValueMarshalling(marshallingStrategy);
                 }
             }
 
@@ -437,9 +437,9 @@ namespace Microsoft.Interop
             return FreeStrategy.NoFree;
         }
 
-        private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource)
+        private static ElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource)
         {
-            IElementsMarshalling elementsMarshalling;
+            ElementsMarshalling elementsMarshalling;
 
             bool elementIsBlittable = elementMarshaller is BlittableMarshaller;
             if (elementIsBlittable)
index d5b2da8..ee5eef2 100644 (file)
@@ -53,7 +53,8 @@ namespace Microsoft.Interop
                 case StubCodeContext.Stage.Setup:
                     return _nativeTypeMarshaller.GenerateSetupStatements(info, context);
                 case StubCodeContext.Stage.Marshal:
-                    if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
+                    if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional
+                        || (context.Direction == MarshalDirection.UnmanagedToManaged && ShouldGenerateByValueOutMarshalling(info)))
                     {
                         return _nativeTypeMarshaller.GenerateMarshalStatements(info, context);
                     }
@@ -84,14 +85,14 @@ namespace Microsoft.Interop
                     break;
                 case StubCodeContext.Stage.Unmarshal:
                     if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional
-                        || (_enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)))
+                        || (context.Direction == MarshalDirection.ManagedToUnmanaged && ShouldGenerateByValueOutMarshalling(info)))
                     {
                         return _nativeTypeMarshaller.GenerateUnmarshalStatements(info, context);
                     }
                     break;
                 case StubCodeContext.Stage.GuaranteedUnmarshal:
                     if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional
-                        || (_enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)))
+                        || (context.Direction == MarshalDirection.ManagedToUnmanaged && ShouldGenerateByValueOutMarshalling(info)))
                     {
                         return _nativeTypeMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
                     }
@@ -105,6 +106,11 @@ namespace Microsoft.Interop
             return Array.Empty<StatementSyntax>();
         }
 
+        private bool ShouldGenerateByValueOutMarshalling(TypePositionInfo info)
+        {
+            return _enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out);
+        }
+
         public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context)
         {
             return _enableByValueContentsMarshalling;
index 3c1468f..c3e5a82 100644 (file)
@@ -13,69 +13,155 @@ namespace Microsoft.Interop
 {
     internal interface IElementsMarshallingCollectionSource
     {
-        StatementSyntax GetManagedValuesNumElementsAssignment(TypePositionInfo info, StubCodeContext context);
         InvocationExpressionSyntax GetUnmanagedValuesDestination(TypePositionInfo info, StubCodeContext context);
         InvocationExpressionSyntax GetManagedValuesSource(TypePositionInfo info, StubCodeContext context);
         InvocationExpressionSyntax GetUnmanagedValuesSource(TypePositionInfo info, StubCodeContext context);
         InvocationExpressionSyntax GetManagedValuesDestination(TypePositionInfo info, StubCodeContext context);
     }
 
-    internal interface IElementsMarshalling
+    internal abstract class ElementsMarshalling
     {
-        StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
-        StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context);
-        StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
-        StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
-        StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context);
+        protected IElementsMarshallingCollectionSource CollectionSource { get; }
+
+        protected ElementsMarshalling(IElementsMarshallingCollectionSource collectionSource)
+        {
+            CollectionSource = collectionSource;
+        }
+
+        public StatementSyntax GenerateClearManagedSource(TypePositionInfo info, StubCodeContext context)
+        {
+            // <GetUnmanagedValuesDestination>.Clear();
+            return ExpressionStatement(
+                InvocationExpression(
+                    MemberAccessExpression(
+                        SyntaxKind.SimpleMemberAccessExpression,
+                        CollectionSource.GetUnmanagedValuesDestination(info, context),
+                        IdentifierName("Clear"))));
+        }
+        public StatementSyntax GenerateClearUnmanagedValuesSource(TypePositionInfo info, StubCodeContext context)
+        {
+            // <GetUnmanagedValuesSource>.Clear();
+            return ExpressionStatement(
+                InvocationExpression(
+                    MemberAccessExpression(
+                        SyntaxKind.SimpleMemberAccessExpression,
+                        CollectionSource.GetUnmanagedValuesSource(info, context),
+                        IdentifierName("Clear"))));
+        }
+
+        public abstract StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
+        public abstract StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context);
+        public abstract StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
+
+        public abstract StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
+        public abstract StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context);
+    }
+
+#pragma warning disable SA1400 // Access modifier should be declared https://github.com/DotNetAnalyzers/StyleCopAnalyzers/issues/3659
+    static file class ElementsMarshallingCollectionSourceExtensions
+#pragma warning restore SA1400 // Access modifier should be declared
+    {
+        public static StatementSyntax GetNumElementsAssignmentFromManagedValuesSource(this IElementsMarshallingCollectionSource source, TypePositionInfo info, StubCodeContext context)
+        {
+            var numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+            // <numElements> = <GetManagedValuesSource>.Length;
+            return ExpressionStatement(
+                AssignmentExpression(
+                    SyntaxKind.SimpleAssignmentExpression,
+                    IdentifierName(numElementsIdentifier),
+                    MemberAccessExpression(
+                        SyntaxKind.SimpleMemberAccessExpression,
+                        source.GetManagedValuesSource(info, context),
+                        IdentifierName("Length"))));
+        }
+
+        public static StatementSyntax GetNumElementsAssignmentFromManagedValuesDestination(this IElementsMarshallingCollectionSource source, TypePositionInfo info, StubCodeContext context)
+        {
+            var numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+            // <numElements> = <GetManagedValuesDestination>.Length;
+            return ExpressionStatement(
+                AssignmentExpression(
+                    SyntaxKind.SimpleAssignmentExpression,
+                    IdentifierName(numElementsIdentifier),
+                    MemberAccessExpression(
+                        SyntaxKind.SimpleMemberAccessExpression,
+                        source.GetManagedValuesDestination(info, context),
+                        IdentifierName("Length"))));
+        }
     }
 
     /// <summary>
     /// Support for marshalling blittable elements
     /// </summary>
-    internal sealed class BlittableElementsMarshalling : IElementsMarshalling
+    internal sealed class BlittableElementsMarshalling : ElementsMarshalling
     {
         private readonly TypeSyntax _managedElementType;
         private readonly TypeSyntax _unmanagedElementType;
-        private readonly IElementsMarshallingCollectionSource _collectionSource;
 
         public BlittableElementsMarshalling(TypeSyntax managedElementType, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource)
+            :base(collectionSource)
         {
             _managedElementType = managedElementType;
             _unmanagedElementType = unmanagedElementType;
-            _collectionSource = collectionSource;
         }
 
-        public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
-            // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
-            // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
-            // <GetUnmanagedValuesDestination>.Clear();
+            ExpressionSyntax destination = CastToManagedIfNecessary(CollectionSource.GetUnmanagedValuesSource(info, context));
+
+            // MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(<GetManagedValuesSource>), <GetManagedValuesSource>.Length)
+            ExpressionSyntax source = InvocationExpression(
+                MemberAccessExpression(
+                    SyntaxKind.SimpleMemberAccessExpression,
+                    ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal),
+                    IdentifierName("CreateSpan")),
+                ArgumentList(
+                    SeparatedList(new[]
+                    {
+                        Argument(
+                            InvocationExpression(
+                                MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                    ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal),
+                                    IdentifierName("GetReference")),
+                                ArgumentList(SingletonSeparatedList(
+                                    Argument(CollectionSource.GetManagedValuesDestination(info, context))))))
+                            .WithRefKindKeyword(
+                                Token(SyntaxKind.RefKeyword)),
+                        Argument(
+                            MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                CollectionSource.GetManagedValuesDestination(info, context),
+                                IdentifierName("Length")))
+                    })));
+
+            // <source>.CopyTo(<destination>);
             return ExpressionStatement(
                 InvocationExpression(
                     MemberAccessExpression(
                         SyntaxKind.SimpleMemberAccessExpression,
-                        _collectionSource.GetUnmanagedValuesDestination(info, context),
-                        IdentifierName("Clear"))));
+                        source,
+                        IdentifierName("CopyTo")))
+                .AddArgumentListArguments(
+                    Argument(destination)));
         }
 
-        public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
-            ExpressionSyntax destination = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context));
+            ExpressionSyntax destination = CastToManagedIfNecessary(CollectionSource.GetUnmanagedValuesDestination(info, context));
 
             // <GetManagedValuesSource>.CopyTo(<destination>);
             return ExpressionStatement(
                 InvocationExpression(
                     MemberAccessExpression(
                         SyntaxKind.SimpleMemberAccessExpression,
-                        _collectionSource.GetManagedValuesSource(info, context),
+                        CollectionSource.GetManagedValuesSource(info, context),
                         IdentifierName("CopyTo")))
                 .AddArgumentListArguments(
                     Argument(destination)));
         }
 
-        public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
-            ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context));
+            ExpressionSyntax source = CastToManagedIfNecessary(CollectionSource.GetUnmanagedValuesDestination(info, context));
 
             // MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(<GetManagedValuesSource>), <GetManagedValuesSource>.Length)
             ExpressionSyntax destination = InvocationExpression(
@@ -92,12 +178,12 @@ namespace Microsoft.Interop
                                     ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal),
                                     IdentifierName("GetReference")),
                                 ArgumentList(SingletonSeparatedList(
-                                    Argument(_collectionSource.GetManagedValuesSource(info, context))))))
+                                    Argument(CollectionSource.GetManagedValuesSource(info, context))))))
                             .WithRefKindKeyword(
                                 Token(SyntaxKind.RefKeyword)),
                         Argument(
                             MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
-                                _collectionSource.GetManagedValuesSource(info, context),
+                                CollectionSource.GetManagedValuesSource(info, context),
                                 IdentifierName("Length")))
                     })));
 
@@ -112,9 +198,9 @@ namespace Microsoft.Interop
                     Argument(destination)));
         }
 
-        public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
-            ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesSource(info, context));
+            ExpressionSyntax source = CastToManagedIfNecessary(CollectionSource.GetUnmanagedValuesSource(info, context));
 
             // <source>.CopyTo(<GetManagedValuesDestination>);
             return ExpressionStatement(
@@ -124,7 +210,7 @@ namespace Microsoft.Interop
                         source,
                         IdentifierName("CopyTo")))
                 .AddArgumentListArguments(
-                    Argument(_collectionSource.GetManagedValuesDestination(info, context))));
+                    Argument(CollectionSource.GetManagedValuesDestination(info, context))));
         }
 
         private ExpressionSyntax CastToManagedIfNecessary(ExpressionSyntax expression)
@@ -150,45 +236,31 @@ namespace Microsoft.Interop
                     Argument(expression))));
         }
 
-        public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context) => EmptyStatement();
+        public override StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context) => EmptyStatement();
     }
 
     /// <summary>
     /// Support for marshalling non-blittable elements
     /// </summary>
-    internal sealed class NonBlittableElementsMarshalling : IElementsMarshalling
+    internal sealed class NonBlittableElementsMarshalling : ElementsMarshalling
     {
         private readonly TypeSyntax _unmanagedElementType;
         private readonly IMarshallingGenerator _elementMarshaller;
         private readonly TypePositionInfo _elementInfo;
-        private readonly IElementsMarshallingCollectionSource _collectionSource;
 
         public NonBlittableElementsMarshalling(
             TypeSyntax unmanagedElementType,
             IMarshallingGenerator elementMarshaller,
             TypePositionInfo elementInfo,
             IElementsMarshallingCollectionSource collectionSource)
+            :base(collectionSource)
         {
             _unmanagedElementType = unmanagedElementType;
             _elementMarshaller = elementMarshaller;
             _elementInfo = elementInfo;
-            _collectionSource = collectionSource;
         }
 
-        public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
-        {
-            // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
-            // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
-            // <GetUnmanagedValuesDestination>.Clear();
-            return ExpressionStatement(
-                InvocationExpression(
-                    MemberAccessExpression(
-                        SyntaxKind.SimpleMemberAccessExpression,
-                        _collectionSource.GetUnmanagedValuesDestination(info, context),
-                        IdentifierName("Clear"))));
-        }
-
-        public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
             string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(info, context);
             string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
@@ -204,7 +276,7 @@ namespace Microsoft.Interop
                     SingletonSeparatedList(
                         VariableDeclarator(Identifier(managedSpanIdentifier))
                         .WithInitializer(EqualsValueClause(
-                            _collectionSource.GetManagedValuesSource(info, context)))))),
+                            CollectionSource.GetManagedValuesSource(info, context)))))),
                 LocalDeclarationStatement(VariableDeclaration(
                     GenericName(
                         Identifier(TypeNames.System_Span),
@@ -213,17 +285,17 @@ namespace Microsoft.Interop
                         VariableDeclarator(
                             Identifier(nativeSpanIdentifier))
                         .WithInitializer(EqualsValueClause(
-                            _collectionSource.GetUnmanagedValuesDestination(info, context)))))),
+                            CollectionSource.GetUnmanagedValuesDestination(info, context)))))),
                 GenerateContentsMarshallingStatement(
                     info,
                     context,
                     MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                         IdentifierName(MarshallerHelpers.GetManagedSpanIdentifier(info, context)),
                         IdentifierName("Length")),
-                    StubCodeContext.Stage.Marshal));
+                    _elementInfo, _elementMarshaller, StubCodeContext.Stage.Marshal));
         }
 
-        public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
             string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(info, context);
             string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
@@ -241,7 +313,7 @@ namespace Microsoft.Interop
                         VariableDeclarator(
                             Identifier(nativeSpanIdentifier))
                         .WithInitializer(EqualsValueClause(
-                            _collectionSource.GetUnmanagedValuesSource(info, context)))))),
+                            CollectionSource.GetUnmanagedValuesSource(info, context)))))),
                 LocalDeclarationStatement(VariableDeclaration(
                     GenericName(
                         Identifier(TypeNames.System_Span),
@@ -250,16 +322,16 @@ namespace Microsoft.Interop
                         VariableDeclarator(
                             Identifier(managedSpanIdentifier))
                         .WithInitializer(EqualsValueClause(
-                            _collectionSource.GetManagedValuesDestination(info, context)))))),
+                            CollectionSource.GetManagedValuesDestination(info, context)))))),
                 GenerateContentsMarshallingStatement(
                     info,
                     context,
                     IdentifierName(numElementsIdentifier),
-                    StubCodeContext.Stage.UnmarshalCapture,
+                    _elementInfo, _elementMarshaller, StubCodeContext.Stage.UnmarshalCapture,
                     StubCodeContext.Stage.Unmarshal));
         }
 
-        public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
         {
             // Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents,
             // not the array itself.
@@ -268,7 +340,7 @@ namespace Microsoft.Interop
             string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
             string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(info, context);
 
-            var setNumElements = _collectionSource.GetManagedValuesNumElementsAssignment(info, context);
+            var setNumElements = CollectionSource.GetNumElementsAssignmentFromManagedValuesSource(info, context);
 
             // Span<TElement> <managedSpan> = MemoryMarshal.CreateSpan(ref Unsafe.AsRef(in <GetManagedValuesSource>.GetPinnableReference(), <numElements>));
             LocalDeclarationStatementSyntax managedValuesDeclaration = LocalDeclarationStatement(VariableDeclaration(
@@ -298,7 +370,7 @@ namespace Microsoft.Interop
                                                     InvocationExpression(
                                                         MemberAccessExpression(
                                                             SyntaxKind.SimpleMemberAccessExpression,
-                                                            _collectionSource.GetManagedValuesSource(info, context),
+                                                            CollectionSource.GetManagedValuesSource(info, context),
                                                             IdentifierName("GetPinnableReference")),
                                                             ArgumentList()))
                                                 .WithRefKindKeyword(
@@ -319,7 +391,7 @@ namespace Microsoft.Interop
                     VariableDeclarator(
                         Identifier(nativeSpanIdentifier))
                     .WithInitializer(EqualsValueClause(
-                        _collectionSource.GetUnmanagedValuesDestination(info, context))))));
+                        CollectionSource.GetUnmanagedValuesDestination(info, context))))));
 
             return Block(
                 setNumElements,
@@ -329,18 +401,18 @@ namespace Microsoft.Interop
                     info,
                     context,
                     IdentifierName(numElementsIdentifier),
-                    StubCodeContext.Stage.UnmarshalCapture,
+                    _elementInfo, _elementMarshaller, StubCodeContext.Stage.UnmarshalCapture,
                     StubCodeContext.Stage.Unmarshal));
         }
 
-        public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context)
+        public override StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context)
         {
             string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
             StatementSyntax contentsCleanupStatements = GenerateContentsMarshallingStatement(info, context,
                     MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                         IdentifierName(MarshallerHelpers.GetNativeSpanIdentifier(info, context)),
                         IdentifierName("Length")),
-                        StubCodeContext.Stage.Cleanup);
+                        _elementInfo, _elementMarshaller, StubCodeContext.Stage.Cleanup);
 
             if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement))
             {
@@ -350,22 +422,100 @@ namespace Microsoft.Interop
             return Block(
                 LocalDeclarationStatement(VariableDeclaration(
                 GenericName(
-                    Identifier(TypeNames.System_Span),
+                    Identifier(TypeNames.System_ReadOnlySpan),
                     TypeArgumentList(SingletonSeparatedList(_unmanagedElementType))),
                 SingletonSeparatedList(
                     VariableDeclarator(
                         Identifier(nativeSpanIdentifier))
                     .WithInitializer(EqualsValueClause(
                             context.Direction == MarshalDirection.ManagedToUnmanaged
-                                ? _collectionSource.GetUnmanagedValuesDestination(info, context)
-                                : _collectionSource.GetUnmanagedValuesSource(info, context)))))),
+                                ? CollectionSource.GetUnmanagedValuesDestination(info, context)
+                                : CollectionSource.GetUnmanagedValuesSource(info, context)))))),
                 contentsCleanupStatements);
         }
 
-        private StatementSyntax GenerateContentsMarshallingStatement(
+        public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
+        {
+            // Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents,
+            // not the array itself.
+            // This code is ugly since we're now enforcing readonly safety with ReadOnlySpan for all other scenarios,
+            // but this is an uncommon case so we don't want to design the API around enabling just it.
+            string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+            string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(info, context);
+            string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
+
+            var setNumElements = CollectionSource.GetNumElementsAssignmentFromManagedValuesDestination(info, context);
+
+            // Span<TUnmanagedElement> <nativeSpan> = MemoryMarshal.CreateSpan(ref Unsafe.AsRef(in <GetUnmanagedValuesSource>.GetPinnableReference(), <numElements>));
+            LocalDeclarationStatementSyntax unmanagedValuesSource = LocalDeclarationStatement(VariableDeclaration(
+                GenericName(
+                    Identifier(TypeNames.System_Span),
+                    TypeArgumentList(
+                        SingletonSeparatedList(_unmanagedElementType))
+                ),
+                SingletonSeparatedList(VariableDeclarator(nativeSpanIdentifier).WithInitializer(EqualsValueClause(
+                    InvocationExpression(
+                        MemberAccessExpression(
+                            SyntaxKind.SimpleMemberAccessExpression,
+                            ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal),
+                            IdentifierName("CreateSpan")))
+                    .WithArgumentList(
+                        ArgumentList(
+                            SeparatedList(
+                                new[]
+                                {
+                                    Argument(
+                                        InvocationExpression(
+                                            MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                                ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
+                                                IdentifierName("AsRef")),
+                                            ArgumentList(SingletonSeparatedList(
+                                                Argument(
+                                                    InvocationExpression(
+                                                        MemberAccessExpression(
+                                                            SyntaxKind.SimpleMemberAccessExpression,
+                                                            CollectionSource.GetUnmanagedValuesSource(info, context),
+                                                            IdentifierName("GetPinnableReference")),
+                                                            ArgumentList()))
+                                                .WithRefKindKeyword(
+                                                    Token(SyntaxKind.InKeyword))))))
+                                    .WithRefKindKeyword(
+                                        Token(SyntaxKind.RefKeyword)),
+                                    Argument(
+                                        IdentifierName(numElementsIdentifier))
+                                }))))))));
+
+            // Span<TElement> <managedSpan> = <GetManagedValuesDestination>
+            LocalDeclarationStatementSyntax managedValuesDestination = LocalDeclarationStatement(VariableDeclaration(
+                GenericName(
+                    Identifier(TypeNames.System_Span),
+                    TypeArgumentList(SingletonSeparatedList(_elementInfo.ManagedType.Syntax))),
+                SingletonSeparatedList(
+                    VariableDeclarator(
+                        Identifier(managedSpanIdentifier))
+                    .WithInitializer(EqualsValueClause(
+                        CollectionSource.GetManagedValuesDestination(info, context))))));
+
+            return Block(
+                setNumElements,
+                unmanagedValuesSource,
+                managedValuesDestination,
+                GenerateContentsMarshallingStatement(
+                    info,
+                    context,
+                    IdentifierName(numElementsIdentifier),
+                    _elementInfo,
+                    new FreeAlwaysOwnedOriginalValueGenerator(_elementMarshaller),
+                    StubCodeContext.Stage.Marshal,
+                    StubCodeContext.Stage.PinnedMarshal,
+                    StubCodeContext.Stage.Cleanup));
+        }
+        private static StatementSyntax GenerateContentsMarshallingStatement(
             TypePositionInfo info,
             StubCodeContext context,
             ExpressionSyntax lengthExpression,
+            TypePositionInfo elementInfo,
+            IMarshallingGenerator elementMarshaller,
             params StubCodeContext.Stage[] stagesToGeneratePerElement)
         {
             string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(info, context);
@@ -376,7 +526,7 @@ namespace Microsoft.Interop
                 nativeSpanIdentifier,
                 context);
 
-            TypePositionInfo localElementInfo = _elementInfo with
+            TypePositionInfo localElementInfo = elementInfo with
             {
                 InstanceIdentifier = info.InstanceIdentifier,
                 RefKind = info.IsByRef ? info.RefKind : info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind(),
@@ -388,16 +538,16 @@ namespace Microsoft.Interop
             foreach (StubCodeContext.Stage stage in stagesToGeneratePerElement)
             {
                 var elementSubContext = elementSetupSubContext with { CurrentStage = stage };
-                elementStatements.AddRange(_elementMarshaller.Generate(localElementInfo, elementSubContext));
+                elementStatements.AddRange(elementMarshaller.Generate(localElementInfo, elementSubContext));
             }
 
             if (elementStatements.Count != 0)
             {
                 StatementSyntax marshallingStatement = Block(
-                    List(_elementMarshaller.Generate(localElementInfo, elementSetupSubContext)
+                    List(elementMarshaller.Generate(localElementInfo, elementSetupSubContext)
                         .Concat(elementStatements)));
 
-                if (_elementMarshaller.AsNativeType(_elementInfo).Syntax is PointerTypeSyntax elementNativeType)
+                if (elementMarshaller.AsNativeType(elementInfo).Syntax is PointerTypeSyntax elementNativeType)
                 {
                     PointerNativeTypeAssignmentRewriter rewriter = new(elementSetupSubContext.GetIdentifiers(localElementInfo).native, elementNativeType);
                     marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement);
index 2163c71..ea798eb 100644 (file)
@@ -345,22 +345,6 @@ namespace Microsoft.Interop
                 ArgumentList(SingletonSeparatedList(
                     Argument(IdentifierName(numElementsIdentifier)))));
         }
-
-        public StatementSyntax GetManagedValuesNumElementsAssignment(TypePositionInfo info, StubCodeContext context)
-        {
-            string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
-            // int <numElements> = <GetManagedValuesSource>.Length;
-            return LocalDeclarationStatement(
-                VariableDeclaration(
-                    PredefinedType(Token(SyntaxKind.IntKeyword)),
-                    SingletonSeparatedList(
-                        VariableDeclarator(numElementsIdentifier)
-                            .WithInitializer(EqualsValueClause(
-                                MemberAccessExpression(
-                                    SyntaxKind.SimpleMemberAccessExpression,
-                                    GetManagedValuesSource(info, context),
-                                    IdentifierName("Length")))))));
-        }
     }
 
     /// <summary>
@@ -371,14 +355,14 @@ namespace Microsoft.Interop
         private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
         private readonly MarshallerShape _shape;
         private readonly ExpressionSyntax _numElementsExpression;
-        private readonly IElementsMarshalling _elementsMarshalling;
+        private readonly ElementsMarshalling _elementsMarshalling;
         private readonly bool _cleanupElements;
 
         public StatefulLinearCollectionMarshalling(
             ICustomTypeMarshallingStrategy innerMarshaller,
             MarshallerShape shape,
             ExpressionSyntax numElementsExpression,
-            IElementsMarshalling elementsMarshalling,
+            ElementsMarshalling elementsMarshalling,
             bool cleanupElements)
         {
             _innerMarshaller = innerMarshaller;
@@ -407,9 +391,6 @@ namespace Microsoft.Interop
 
         public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
         {
-            if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
-                yield break;
-
             foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context))
             {
                 yield return statement;
@@ -417,9 +398,21 @@ namespace Microsoft.Interop
 
             if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
             {
-                yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context);
+                // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
+                // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
+                yield return _elementsMarshalling.GenerateClearManagedSource(info, context);
                 yield break;
             }
+            if (context.Direction == MarshalDirection.UnmanagedToManaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
+            {
+                // If the parameter is marshalled by-value [Out] or [In, Out], then we need to unmarshal the contents of the collection
+                // into the passed-in collection value.
+                yield return _elementsMarshalling.GenerateUnmanagedToManagedByValueOutMarshalStatement(info, context);
+                yield break;
+            }
+
+            if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
+                yield break;
 
             yield return _elementsMarshalling.GenerateMarshalStatement(info, context);
         }
@@ -427,35 +420,61 @@ namespace Microsoft.Interop
         public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
         public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
         public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
-        public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
         {
+            foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context))
+            {
+                yield return statement;
+            }
+
             string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+            yield return LocalDeclarationStatement(
+                VariableDeclaration(
+                    PredefinedType(Token(SyntaxKind.IntKeyword)),
+                    SingletonSeparatedList(
+                        VariableDeclarator(numElementsIdentifier))));
+            // Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
+            // The value will never be used unless it has been initialized, so this is safe.
+            yield return MarshallerHelpers.SkipInitOrDefaultInit(
+                new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
+                {
+                    InstanceIdentifier = numElementsIdentifier
+                }, context);
+        }
 
+        public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
+        {
             if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
             {
+                // If the parameter is marshalled by-value [Out] or [In, Out], then we need to unmarshal the contents of the collection
+                // into the passed-in collection value.
                 yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context);
                 yield break;
             }
 
-            if (!_shape.HasFlag(MarshallerShape.ToManaged))
+            if (context.Direction == MarshalDirection.UnmanagedToManaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
             {
+                // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
+                // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
+                yield return _elementsMarshalling.GenerateClearUnmanagedValuesSource(info, context);
                 yield break;
             }
-            else
+
+            if (!_shape.HasFlag(MarshallerShape.ToManaged))
             {
-                // int <numElements> = <numElementsExpression>;
-                yield return LocalDeclarationStatement(
-                    VariableDeclaration(
-                        PredefinedType(Token(SyntaxKind.IntKeyword)),
-                        SingletonSeparatedList(
-                            VariableDeclarator(numElementsIdentifier)
-                                .WithInitializer(EqualsValueClause(_numElementsExpression)))));
-
-                yield return _elementsMarshalling.GenerateUnmarshalStatement(info, context);
+                yield break;
             }
 
+            string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+
+            // <numElements> = <numElementsExpression>;
+            yield return ExpressionStatement(
+                AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+                    IdentifierName(numElementsIdentifier),
+                    _numElementsExpression));
+
+            yield return _elementsMarshalling.GenerateUnmarshalStatement(info, context);
+
             foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalStatements(info, context))
             {
                 yield return statement;
index fd9459c..6c10134 100644 (file)
@@ -227,7 +227,7 @@ namespace Microsoft.Interop
                 }
                 else
                 {
-                    // <nativeIdentifier> = <marshallerType>.ConvertToUnmanaged(<managedIdentifier>, <originalValueIdentifier>__buffer);
+                    // <nativeIdentifier> = <marshallerType>.ConvertToUnmanaged(<managedIdentifier>, <nativeIdentifier>__buffer);
                     yield return ExpressionStatement(
                         AssignmentExpression(
                             SyntaxKind.SimpleAssignmentExpression,
@@ -295,134 +295,6 @@ namespace Microsoft.Interop
         public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
     }
 
-    internal sealed class StatelessUnmanagedToManagedOwnershipTracking : ICustomTypeMarshallingStrategy
-    {
-        internal const string OwnOriginalValueIdentifier = "ownOriginal";
-        internal const string OriginalValueIdentifier = "original";
-
-        private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
-
-        public StatelessUnmanagedToManagedOwnershipTracking(ICustomTypeMarshallingStrategy innerMarshaller)
-        {
-            _innerMarshaller = innerMarshaller;
-        }
-
-        public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
-
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
-        public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
-        {
-            foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context))
-            {
-                yield return statement;
-            }
-
-            // Now that we've set the new value to pass to the caller on the <native> identifier, we need to make sure that we free the old one.
-            // The caller will not see the old one any more, so it won't be able to free it.
-
-            // <ownOriginalValue> = true;
-            yield return ExpressionStatement(
-                AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
-                    IdentifierName(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)),
-                    LiteralExpression(SyntaxKind.TrueLiteralExpression)));
-        }
-
-        public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
-        public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
-        public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
-        {
-            foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context))
-            {
-                yield return statement;
-            }
-
-            // bool <ownOriginalValue> = false;
-            yield return LocalDeclarationStatement(
-                VariableDeclaration(
-                    PredefinedType(Token(SyntaxKind.BoolKeyword)),
-                    SingletonSeparatedList(
-                        VariableDeclarator(
-                            Identifier(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)),
-                            null,
-                            EqualsValueClause(
-                                LiteralExpression(SyntaxKind.FalseLiteralExpression))))));
-
-            // <nativeType> <original> = <originalValueIdentifier>;
-            yield return LocalDeclarationStatement(
-                VariableDeclaration(
-                    AsNativeType(info).Syntax,
-                    SingletonSeparatedList(
-                        VariableDeclarator(
-                            Identifier(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)),
-                            null,
-                            EqualsValueClause(
-                                IdentifierName(context.GetIdentifiers(info).native))))));
-        }
-
-        public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
-        public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
-    }
-
-    internal sealed class FreeOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy
-    {
-        private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
-
-        public FreeOwnedOriginalValueMarshalling(ICustomTypeMarshallingStrategy innerMarshaller)
-        {
-            _innerMarshaller = innerMarshaller;
-        }
-
-        public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
-
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
-        {
-            // if (<ownOriginalValue>)
-            // {
-            //     <cleanup>
-            // }
-            yield return IfStatement(
-                IdentifierName(context.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OwnOriginalValueIdentifier)),
-                Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context))));
-        }
-
-        public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
-        public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
-        public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
-        public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
-
-        public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
-        public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
-
-        private sealed record OwnedValueCodeContext(StubCodeContext InnerContext) : StubCodeContext
-        {
-            public override bool SingleFrameSpansNativeContext => InnerContext.SingleFrameSpansNativeContext;
-
-            public override bool AdditionalTemporaryStateLivesAcrossStages => InnerContext.AdditionalTemporaryStateLivesAcrossStages;
-
-            public override (TargetFramework framework, Version version) GetTargetFramework() => InnerContext.GetTargetFramework();
-
-            public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
-            {
-                var (managed, _) = InnerContext.GetIdentifiers(info);
-                return (managed, InnerContext.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OriginalValueIdentifier));
-            }
-
-            public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => InnerContext.GetAdditionalIdentifier(info, name);
-        }
-    }
-
     /// <summary>
     /// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec.
     /// </summary>
@@ -638,20 +510,6 @@ namespace Microsoft.Interop
                     IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesDestination)),
                 ArgumentList(SingletonSeparatedList(Argument(IdentifierName(managedIdentifier)))));
         }
-
-        public StatementSyntax GetManagedValuesNumElementsAssignment(TypePositionInfo info, StubCodeContext context)
-        {
-            var numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
-            // <numElements> = <GetManagedValuesSource>.Length;
-            return ExpressionStatement(
-                AssignmentExpression(
-                    SyntaxKind.SimpleAssignmentExpression,
-                    IdentifierName(numElementsIdentifier),
-                    MemberAccessExpression(
-                        SyntaxKind.SimpleMemberAccessExpression,
-                        GetManagedValuesSource(info, context),
-                        IdentifierName("Length"))));
-        }
     }
 
     /// <summary>
@@ -660,14 +518,14 @@ namespace Microsoft.Interop
     internal sealed class StatelessLinearCollectionMarshalling : ICustomTypeMarshallingStrategy
     {
         private readonly ICustomTypeMarshallingStrategy _spaceMarshallingStrategy;
-        private readonly IElementsMarshalling _elementsMarshalling;
+        private readonly ElementsMarshalling _elementsMarshalling;
         private readonly ManagedTypeInfo _unmanagedType;
         private readonly MarshallerShape _shape;
         private readonly bool _cleanupElementsAndSpace;
 
         public StatelessLinearCollectionMarshalling(
             ICustomTypeMarshallingStrategy spaceMarshallingStrategy,
-            IElementsMarshalling elementsMarshalling,
+            ElementsMarshalling elementsMarshalling,
             ManagedTypeInfo unmanagedType,
             MarshallerShape shape,
             bool cleanupElementsAndSpace)
@@ -705,21 +563,31 @@ namespace Microsoft.Interop
 
         public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
         {
-            foreach (var statement in _spaceMarshallingStrategy.GenerateMarshalStatements(info, context))
+            if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
             {
-                yield return statement;
-            }
-            if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
+                // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
+                // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
+                yield return _elementsMarshalling.GenerateClearManagedSource(info, context);
                 yield break;
+            }
 
-            if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
+            if (context.Direction == MarshalDirection.UnmanagedToManaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
             {
-                yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context);
+                // If the parameter is marshalled by-value [Out] or [In, Out], then we need to unmarshal the contents of the collection
+                // into the passed-in collection value.
+                yield return _elementsMarshalling.GenerateUnmanagedToManagedByValueOutMarshalStatement(info, context);
+                yield break;
             }
-            else
+
+            foreach (var statement in _spaceMarshallingStrategy.GenerateMarshalStatements(info, context))
             {
-                yield return _elementsMarshalling.GenerateMarshalStatement(info, context);
+                yield return statement;
             }
+
+            if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
+                yield break;
+
+            yield return _elementsMarshalling.GenerateMarshalStatement(info, context);
         }
 
         public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _spaceMarshallingStrategy.GenerateNotifyForSuccessfulInvokeStatements(info, context);
@@ -734,10 +602,20 @@ namespace Microsoft.Interop
         {
             if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
             {
+                // If the parameter is marshalled by-value [Out] or [In, Out], then we need to unmarshal the contents of the collection
+                // into the passed-in collection value.
                 yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context);
                 yield break;
             }
 
+            if (context.Direction == MarshalDirection.UnmanagedToManaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
+            {
+                // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
+                // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
+                yield return _elementsMarshalling.GenerateClearUnmanagedValuesSource(info, context);
+                yield break;
+            }
+
             if (!_shape.HasFlag(MarshallerShape.ToManaged))
             {
                 yield break;
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs
new file mode 100644 (file)
index 0000000..4343547
--- /dev/null
@@ -0,0 +1,220 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+
+namespace Microsoft.Interop
+{
+    /// <summary>
+    /// Marshalling strategy that introduces a variable to hold the initial value of the provided <see cref="TypePositionInfo"/> and a variable to track if the original value has been replaced.
+    /// </summary>
+    /// <seealso cref="CleanupOwnedOriginalValueMarshalling" />
+    internal sealed class UnmanagedToManagedOwnershipTrackingStrategy : ICustomTypeMarshallingStrategy
+    {
+        private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
+
+        public UnmanagedToManagedOwnershipTrackingStrategy(ICustomTypeMarshallingStrategy innerMarshaller)
+        {
+            _innerMarshaller = innerMarshaller;
+        }
+
+        public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
+
+        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context))
+            {
+                yield return statement;
+            }
+
+            // Now that we've set the new value to pass to the caller on the <native> identifier, we need to make sure that we free the old one.
+            // The caller will not see the old one any more, so it won't be able to free it.
+
+            // <ownOriginalValue> = true;
+            yield return ExpressionStatement(
+                AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+                    IdentifierName(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)),
+                    LiteralExpression(SyntaxKind.TrueLiteralExpression)));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
+        public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context))
+            {
+                yield return statement;
+            }
+
+            // bool <ownOriginalValue> = false;
+            yield return LocalDeclarationStatement(
+                VariableDeclaration(
+                    PredefinedType(Token(SyntaxKind.BoolKeyword)),
+                    SingletonSeparatedList(
+                        VariableDeclarator(
+                            Identifier(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)),
+                            null,
+                            EqualsValueClause(
+                                LiteralExpression(SyntaxKind.FalseLiteralExpression))))));
+
+            yield return OwnershipTrackingHelpers.DeclareOriginalValueIdentifier(info, context, AsNativeType(info));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
+        public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
+    }
+
+    /// <summary>
+    /// Marshalling strategy that uses the tracking variables introduced by <see cref="UnmanagedToManagedOwnershipTrackingStrategy"/> to cleanup the original value if the original value is owned
+    /// in the <see cref="StubCodeContext.Stage.Cleanup"/> stage.
+    /// </summary>
+    internal sealed class CleanupOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy
+    {
+        private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
+
+        public CleanupOwnedOriginalValueMarshalling(ICustomTypeMarshallingStrategy innerMarshaller)
+        {
+            _innerMarshaller = innerMarshaller;
+        }
+
+        public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
+
+        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            // if (<ownOriginalValue>)
+            // {
+            //     <cleanup>
+            // }
+            yield return IfStatement(
+                IdentifierName(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)),
+                Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context))));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
+        public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
+
+        public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
+        public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
+    }
+
+    /// <summary>
+    /// Marshalling strategy to cache the initial value of a given <see cref="TypePositionInfo"/> in a local variable and cleanup that value in the cleanup stage.
+    /// Useful in scenarios where the value is always owned in all code-paths that reach the <see cref="StubCodeContext.Stage.Cleanup"/> stage, so additional ownership tracking is extraneous.
+    /// </summary>
+    internal sealed class FreeAlwaysOwnedOriginalValueGenerator : IMarshallingGenerator
+    {
+        private readonly IMarshallingGenerator _inner;
+
+        public FreeAlwaysOwnedOriginalValueGenerator(IMarshallingGenerator inner)
+        {
+            _inner = inner;
+        }
+
+        public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _inner.AsNativeType(info);
+        public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
+        {
+            if (context.CurrentStage == StubCodeContext.Stage.Setup)
+            {
+                return GenerateSetupStatements();
+            }
+
+            if (context.CurrentStage == StubCodeContext.Stage.Cleanup)
+            {
+                return GenerateStatementsFromInner(new OwnedValueCodeContext(context));
+            }
+
+            return GenerateStatementsFromInner(context);
+
+            IEnumerable<StatementSyntax> GenerateSetupStatements()
+            {
+                foreach (var statement in GenerateStatementsFromInner(context))
+                {
+                    yield return statement;
+                }
+
+                yield return OwnershipTrackingHelpers.DeclareOriginalValueIdentifier(info, context, AsNativeType(info));
+            }
+
+            IEnumerable<StatementSyntax> GenerateStatementsFromInner(StubCodeContext contextForStage)
+            {
+                return _inner.Generate(info, contextForStage);
+            }
+        }
+
+        public SignatureBehavior GetNativeSignatureBehavior(TypePositionInfo info) => _inner.GetNativeSignatureBehavior(info);
+        public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, StubCodeContext context) => _inner.GetValueBoundaryBehavior(info, context);
+        public bool IsSupported(TargetFramework target, Version version) => _inner.IsSupported(target, version);
+        public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) => _inner.SupportsByValueMarshalKind(marshalKind, context);
+        public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _inner.UsesNativeIdentifier(info, context);
+    }
+
+#pragma warning disable SA1400 // Access modifier should be declared https://github.com/DotNetAnalyzers/StyleCopAnalyzers/issues/3659
+    sealed file record OwnedValueCodeContext : StubCodeContext
+#pragma warning restore SA1400 // Access modifier should be declared
+    {
+        private readonly StubCodeContext _innerContext;
+
+        public OwnedValueCodeContext(StubCodeContext innerContext)
+        {
+            _innerContext = innerContext;
+            CurrentStage = innerContext.CurrentStage;
+            Direction = innerContext.Direction;
+        }
+
+        public override bool SingleFrameSpansNativeContext => _innerContext.SingleFrameSpansNativeContext;
+
+        public override bool AdditionalTemporaryStateLivesAcrossStages => _innerContext.AdditionalTemporaryStateLivesAcrossStages;
+
+        public override (TargetFramework framework, Version version) GetTargetFramework() => _innerContext.GetTargetFramework();
+
+        public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
+        {
+            var (managed, _) = _innerContext.GetIdentifiers(info);
+            return (managed, _innerContext.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OriginalValueIdentifier));
+        }
+
+        public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => _innerContext.GetAdditionalIdentifier(info, name);
+    }
+
+#pragma warning disable SA1400 // Access modifier should be declared https://github.com/DotNetAnalyzers/StyleCopAnalyzers/issues/3659
+    static file class OwnershipTrackingHelpers
+#pragma warning restore SA1400 // Access modifier should be declared
+    {
+        public const string OwnOriginalValueIdentifier = "ownOriginal";
+        public const string OriginalValueIdentifier = "original";
+
+        public static StatementSyntax DeclareOriginalValueIdentifier(TypePositionInfo info, StubCodeContext context, ManagedTypeInfo nativeType)
+        {
+            // <nativeType> <original> = <nativeValueIdentifier>;
+            return LocalDeclarationStatement(
+                VariableDeclaration(
+                    nativeType.Syntax,
+                    SingletonSeparatedList(
+                        VariableDeclarator(
+                            Identifier(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)),
+                            null,
+                            EqualsValueClause(
+                                IdentifierName(context.GetIdentifiers(info).native))))));
+        }
+    }
+}
index 224ea68..54e96f3 100644 (file)
@@ -198,14 +198,13 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
-        [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")]
         public unsafe void ValidateArrayElementsByValueOutFreed_Stateless()
         {
             const int startingValue = 13;
 
             ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
 
-            void* wrapper = VTableGCHandlePair<NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObjectStateful>.Allocate(impl);
+            void* wrapper = VTableGCHandlePair<NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject>.Allocate(impl);
 
             try
             {
@@ -222,7 +221,7 @@ namespace ComInterfaceGenerator.Tests
             }
             finally
             {
-                VTableGCHandlePair<NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObjectStateful>.Free(wrapper);
+                VTableGCHandlePair<NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject>.Free(wrapper);
             }
         }
 
@@ -282,7 +281,6 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
-        [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")]
         public unsafe void ValidateArrayElementsByValueOutFreed_Stateful()
         {
             const int startingValue = 13;
@@ -392,7 +390,7 @@ namespace ComInterfaceGenerator.Tests
 
                 public Span<TManaged> GetManagedValuesDestination(int numElements)
                 {
-                    return _managed = new TManaged[numElements];
+                    return _managed ??= new TManaged[numElements];
                 }
 
                 public ReadOnlySpan<TUnmanaged> GetUnmanagedValuesSource(int numElements)