Cleanup caller allocated and callee allocated resources separately (#89982)
authorJackson Schuster <36744439+jtschuster@users.noreply.github.com>
Tue, 8 Aug 2023 21:47:45 +0000 (16:47 -0500)
committerGitHub <noreply@github.com>
Tue, 8 Aug 2023 21:47:45 +0000 (14:47 -0700)
This PR separates cleaning up caller allocated resources and callee allocated resources into separate stages in the managed to unmanaged direction. Caller allocated parameters (anything except 'out') will clean up the same way. Callee allocated parameters ('out' parameters) will be cleaned up only if the invocation succeeded.

22 files changed:
docs/design/libraries/LibraryImportGenerator/Pipeline.md
src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs
src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionMarshallerFactory.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/SafeHandleMarshaller.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IArrayOfStatelessElements.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulFinallyMarshalling.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs [new file with mode: 0644]

index 9533241..bd38953 100644 (file)
@@ -91,8 +91,12 @@ The stub code generator itself will handle some initial setup and variable decla
     - Call `Generate` on the marshalling generator for every parameter
 1. `GuaranteedUnmarshal`: conversion of native to managed data even when an exception is thrown
     - Call `Generate` on the marshalling generator for every parameter.
-1. `Cleanup`: free any allocated resources
+    - If this stage has any statements, put them in an if statement where the condition represents whether the call succeeded
+1. `CleanupCallerAllocated`: free any resources allocated by the caller
     - Call `Generate` on the marshalling generator for every parameter
+1. `CleanupCalleeAllocated`: if the native method succeeded, free any resources allocated by the callee (`out` parameters and return values)
+    - Call `Generate` on the marshalling generator for every parameter
+    - If this stage has any statements, put them in an if statement where the condition represents whether the call succeeded
 
 Generated P/Invoke structure (if no code is generated for `GuaranteedUnmarshal` and `Cleanup`, the `try-finally` is omitted):
 ```C#
@@ -113,7 +117,8 @@ try
 finally
 {
     << GuaranteedUnmarshal >>
-    << Cleanup >>
+    << CleanupCalleeAllocated >>
+    << CleanupCallerAllocated >>
 }
 ```
 
@@ -138,12 +143,12 @@ Support for these features is indicated in code by the `abstract` `SingleFrameSp
 
 The various scenarios mentioned above have different levels of support for these specialized features:
 
-| Scenarios | Pinning and Stack allocation across the native context | Storing additional temporary state in locals |
-|------|-----|-----|
-| P/Invoke | supported | supported |
-| Reverse P/Invoke | unsupported | supported |
-| User-defined structure content marshalling | unsupported | unsupported |
-| non-blittable array marshalling | unsupported | unuspported |
+| Scenarios                                  | Pinning and Stack allocation across the native context | Storing additional temporary state in locals |
+|--------------------------------------------|--------------------------------------------------------|----------------------------------------------|
+| P/Invoke                                   | supported                                              | supported                                    |
+| Reverse P/Invoke                           | unsupported                                            | supported                                    |
+| User-defined structure content marshalling | unsupported                                            | unsupported                                  |
+| non-blittable array marshalling            | unsupported                                            | unuspported                                  |
 
 To help enable developers to use the full model described in the [Struct Marshalling design](./StructMarshalling.md), we declare that in contexts where `AdditionalTemporaryStateLivesAcrossStages` is false, developers can still assume that state declared in the `Setup` phase is valid in any phase, but any side effects in code emitted in a phase other than `Setup` will not be guaranteed to be visible in other phases. This enables developers to still use the identifiers declared in the `Setup` phase in their other phases, but they'll need to take care to design their generators to handle these rules.
 
index 9221bef..0a18241 100644 (file)
@@ -2,13 +2,13 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Linq;
 using System.Collections.Generic;
 using System.Collections.Immutable;
+using System.Linq;
+using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-using Microsoft.CodeAnalysis;
 
 namespace Microsoft.Interop.JavaScript
 {
@@ -61,13 +61,13 @@ namespace Microsoft.Interop.JavaScript
         {
             StatementSyntax invoke = InvokeSyntax();
             GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context);
-            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty;
+            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty;
             VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, shouldInitializeVariables);
 
             var setupStatements = new List<StatementSyntax>();
             SetupSyntax(setupStatements);
 
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true));
             }
@@ -81,7 +81,7 @@ namespace Microsoft.Interop.JavaScript
 
             tryStatements.Add(invoke);
 
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
                     IdentifierName(InvokeSucceededIdentifier),
@@ -94,12 +94,12 @@ namespace Microsoft.Interop.JavaScript
 
             List<StatementSyntax> allStatements = setupStatements;
             List<StatementSyntax> finallyStatements = new List<StatementSyntax>();
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
-                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal)));
+                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated))));
             }
 
-            finallyStatements.AddRange(statements.Cleanup);
+            finallyStatements.AddRange(statements.CleanupCallerAllocated);
             if (finallyStatements.Count > 0)
             {
                 allStatements.Add(
index 2c1d49f..1f415b8 100644 (file)
@@ -67,14 +67,14 @@ namespace Microsoft.Interop.JavaScript
         {
             StatementSyntax invoke = InvokeSyntax();
             GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context);
-            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty;
+            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty;
             VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForManagedToUnmanaged(_marshallers, _context, shouldInitializeVariables);
 
             var setupStatements = new List<StatementSyntax>();
             BindSyntax(setupStatements);
             SetupSyntax(setupStatements);
 
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true));
             }
@@ -88,7 +88,7 @@ namespace Microsoft.Interop.JavaScript
             tryStatements.AddRange(statements.PinnedMarshal);
 
             tryStatements.Add(invoke);
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
                     IdentifierName(InvokeSucceededIdentifier),
@@ -100,12 +100,12 @@ namespace Microsoft.Interop.JavaScript
 
             List<StatementSyntax> allStatements = setupStatements;
             List<StatementSyntax> finallyStatements = new List<StatementSyntax>();
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
-                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal)));
+                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated))));
             }
 
-            finallyStatements.AddRange(statements.Cleanup);
+            finallyStatements.AddRange(statements.CleanupCallerAllocated);
             if (finallyStatements.Count > 0)
             {
                 // Add try-finally block if there are any statements in the finally block
index 883ca85..dd18f01 100644 (file)
@@ -131,7 +131,7 @@ namespace Microsoft.Interop
                         BracketedArgumentList(SingletonSeparatedList(
                             Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(index)))))),
                     callConv));
-            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty;
+            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty;
             VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForManagedToUnmanaged(_marshallers, _context, shouldInitializeVariables);
 
             if (_setLastError)
@@ -143,7 +143,7 @@ namespace Microsoft.Interop
                     initializeToDefault: false));
             }
 
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true));
             }
@@ -174,7 +174,7 @@ namespace Microsoft.Interop
             tryStatements.AddRange(statements.NotifyForSuccessfulInvoke);
 
             // <invokeSucceeded> = true;
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
                     IdentifierName(InvokeSucceededIdentifier),
@@ -197,12 +197,12 @@ namespace Microsoft.Interop
 
             List<StatementSyntax> allStatements = setupStatements;
             List<StatementSyntax> finallyStatements = new List<StatementSyntax>();
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
-                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal)));
+                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated))));
             }
 
-            finallyStatements.AddRange(statements.Cleanup);
+            finallyStatements.AddRange(statements.CleanupCallerAllocated);
             if (finallyStatements.Count > 0)
             {
                 // Add try-finally block if there are any statements in the finally block
index e396ea5..63017cf 100644 (file)
@@ -83,7 +83,7 @@ namespace Microsoft.Interop
             {
                 Debug.Assert(info.MarshallingAttributeInfo is ManagedHResultExceptionMarshallingInfo);
 
-                if (context.CurrentStage != StubCodeContext.Stage.Unmarshal)
+                if (context.CurrentStage != StubCodeContext.Stage.NotifyForSuccessfulInvoke)
                 {
                     yield break;
                 }
index 64060d6..cc76d85 100644 (file)
@@ -4,6 +4,7 @@
 using System;
 using System.Collections.Generic;
 using System.Collections.Immutable;
+using System.Diagnostics;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
@@ -52,9 +53,11 @@ namespace Microsoft.Interop
                 _marshallers,
                 _context,
                 methodToInvoke);
+            Debug.Assert(statements.CleanupCalleeAllocated.IsEmpty);
+
             bool shouldInitializeVariables =
                 !statements.GuaranteedUnmarshal.IsEmpty
-                || !statements.Cleanup.IsEmpty
+                || !statements.CleanupCallerAllocated.IsEmpty
                 || !statements.ManagedExceptionCatchClauses.IsEmpty;
             VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, shouldInitializeVariables);
 
@@ -77,7 +80,7 @@ namespace Microsoft.Interop
 
             SyntaxList<CatchClauseSyntax> catchClauses = List(statements.ManagedExceptionCatchClauses);
 
-            finallyStatements.AddRange(statements.Cleanup);
+            finallyStatements.AddRange(statements.CleanupCallerAllocated);
             if (finallyStatements.Count > 0)
             {
                 allStatements.Add(
index 34d09a7..939a623 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Linq;
 using System.Collections.Generic;
 using System.Collections.Immutable;
 using Microsoft.CodeAnalysis.CSharp;
@@ -107,7 +108,7 @@ namespace Microsoft.Interop
         public BlockSyntax GeneratePInvokeBody(string dllImportName)
         {
             GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context, IdentifierName(dllImportName));
-            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty;
+            bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty;
             VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForManagedToUnmanaged(_marshallers, _context, shouldInitializeVariables);
 
             var setupStatements = new List<StatementSyntax>();
@@ -121,7 +122,7 @@ namespace Microsoft.Interop
                     initializeToDefault: false));
             }
 
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true));
             }
@@ -148,7 +149,7 @@ namespace Microsoft.Interop
             }
             tryStatements.Add(statements.Pin.NestFixedStatements(fixedBlock));
             // <invokeSucceeded> = true;
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
                 tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
                     IdentifierName(InvokeSucceededIdentifier),
@@ -160,12 +161,12 @@ namespace Microsoft.Interop
 
             List<StatementSyntax> allStatements = setupStatements;
             List<StatementSyntax> finallyStatements = new List<StatementSyntax>();
-            if (!statements.GuaranteedUnmarshal.IsEmpty)
+            if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty))
             {
-                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal)));
+                finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated))));
             }
 
-            finallyStatements.AddRange(statements.Cleanup);
+            finallyStatements.AddRange(statements.CleanupCallerAllocated);
             if (finallyStatements.Count > 0)
             {
                 // Add try-finally block if there are any statements in the finally block
index e8e3de0..d611eb5 100644 (file)
@@ -21,7 +21,8 @@ namespace Microsoft.Interop
         public ImmutableArray<StatementSyntax> Unmarshal { get; init; }
         public ImmutableArray<StatementSyntax> NotifyForSuccessfulInvoke { get; init; }
         public ImmutableArray<StatementSyntax> GuaranteedUnmarshal { get; init; }
-        public ImmutableArray<StatementSyntax> Cleanup { get; init; }
+        public ImmutableArray<StatementSyntax> CleanupCallerAllocated { get; init; }
+        public ImmutableArray<StatementSyntax> CleanupCalleeAllocated { get; init; }
 
         public ImmutableArray<CatchClauseSyntax> ManagedExceptionCatchClauses { get; init; }
 
@@ -38,7 +39,8 @@ namespace Microsoft.Interop
                             .AddRange(GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Unmarshal })),
                 NotifyForSuccessfulInvoke = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.NotifyForSuccessfulInvoke }),
                 GuaranteedUnmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.GuaranteedUnmarshal }),
-                Cleanup = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Cleanup }),
+                CleanupCallerAllocated = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.CleanupCallerAllocated }),
+                CleanupCalleeAllocated = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.CleanupCalleeAllocated }),
                 ManagedExceptionCatchClauses = GenerateCatchClauseForManagedException(marshallers, context)
             };
         }
@@ -182,7 +184,8 @@ namespace Microsoft.Interop
                 StubCodeContext.Stage.Invoke => "Call the P/Invoke.",
                 StubCodeContext.Stage.UnmarshalCapture => "Capture the native data into marshaller instances in case conversion to managed data throws an exception.",
                 StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.",
-                StubCodeContext.Stage.Cleanup => "Perform required cleanup.",
+                StubCodeContext.Stage.CleanupCallerAllocated => "Perform cleanup of caller allocated resources.",
+                StubCodeContext.Stage.CleanupCalleeAllocated => "Perform cleanup of callee allocated resources.",
                 StubCodeContext.Stage.NotifyForSuccessfulInvoke => "Keep alive any managed objects that need to stay alive across the call.",
                 StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.",
                 _ => throw new ArgumentOutOfRangeException(nameof(stage))
index c992f63..30f333d 100644 (file)
@@ -97,8 +97,10 @@ namespace Microsoft.Interop
                         return _nativeTypeMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
                     }
                     break;
-                case StubCodeContext.Stage.Cleanup:
-                    return _nativeTypeMarshaller.GenerateCleanupStatements(info, context);
+                case StubCodeContext.Stage.CleanupCallerAllocated:
+                    return _nativeTypeMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context);
+                case StubCodeContext.Stage.CleanupCalleeAllocated:
+                    return _nativeTypeMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context);
                 default:
                     break;
             }
index 7b71e9e..136de0e 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System;
 using System.Collections.Generic;
 using System.Linq;
 using Microsoft.CodeAnalysis;
@@ -440,7 +441,7 @@ namespace Microsoft.Interop
                 indexConstraintName,
                 _elementInfo,
                 _elementMarshaller,
-                StubCodeContext.Stage.Cleanup);
+                context.CurrentStage);
 
             if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement))
             {
@@ -531,6 +532,18 @@ namespace Microsoft.Interop
                     .WithInitializer(EqualsValueClause(
                         CollectionSource.GetManagedValuesDestination(info, context))))));
 
+            StubCodeContext.Stage[] stagesToGenerate;
+
+            // Until we separate CalleeAllocated cleanup and CallerAllocated cleanup in unmanaged to managed, we'll need this hack
+            if (context.Direction is MarshalDirection.UnmanagedToManaged && info.ByValueContentsMarshalKind is ByValueContentsMarshalKind.Out)
+            {
+                stagesToGenerate = new[] { StubCodeContext.Stage.Marshal, StubCodeContext.Stage.PinnedMarshal };
+            }
+            else
+            {
+                stagesToGenerate = new[] { StubCodeContext.Stage.Marshal, StubCodeContext.Stage.PinnedMarshal, StubCodeContext.Stage.CleanupCallerAllocated, StubCodeContext.Stage.CleanupCalleeAllocated };
+            }
+
             return Block(
                 setNumElements,
                 unmanagedValuesSource,
@@ -541,9 +554,7 @@ namespace Microsoft.Interop
                     IdentifierName(numElementsIdentifier),
                     _elementInfo,
                     new FreeAlwaysOwnedOriginalValueGenerator(_elementMarshaller),
-                    StubCodeContext.Stage.Marshal,
-                    StubCodeContext.Stage.PinnedMarshal,
-                    StubCodeContext.Stage.Cleanup));
+                    stagesToGenerate));
         }
 
         private static List<StatementSyntax> GenerateElementStages(
index f9da5d3..5322365 100644 (file)
@@ -13,7 +13,9 @@ namespace Microsoft.Interop
     {
         ManagedTypeInfo AsNativeType(TypePositionInfo info);
 
-        IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context);
+        IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context);
+
+        IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context);
 
         IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context);
 
index ebb1f01..325f474 100644 (file)
@@ -406,6 +406,26 @@ namespace Microsoft.Interop
         }
 
         /// <summary>
+        /// Returns which stage cleanup should be performed for the parameter.
+        /// </summary>
+        public static StubCodeContext.Stage GetCleanupStage(TypePositionInfo info, StubCodeContext context)
+        {
+            // Unmanaged to managed doesn't properly handle lifetimes right now and will default to the original behavior.
+            // Failures will only occur when marshalling fails, and would only cause leaks, not double frees.
+            // See https://github.com/dotnet/runtime/issues/89483 for more details
+            if (context.Direction is MarshalDirection.UnmanagedToManaged)
+                return StubCodeContext.Stage.CleanupCallerAllocated;
+
+            return GetMarshalDirection(info, context) switch
+            {
+                MarshalDirection.UnmanagedToManaged => StubCodeContext.Stage.CleanupCalleeAllocated,
+                MarshalDirection.ManagedToUnmanaged => StubCodeContext.Stage.CleanupCallerAllocated,
+                MarshalDirection.Bidirectional => StubCodeContext.Stage.CleanupCallerAllocated,
+                _ => throw new UnreachableException()
+            };
+        }
+
+        /// <summary>
         /// Ensure that the count of a collection is available at call time if the parameter is not an out parameter.
         /// It only looks at an indirection level of 0 (the size of the outer array), so there are some holes in
         /// analysis if the parameter is a multidimensional array, but that case seems very unlikely to be hit.
@@ -417,10 +437,10 @@ namespace Microsoft.Interop
             if (stubDirection is MarshalDirection.ManagedToUnmanaged)
                 return;
 
-            if (info.MarshallingAttributeInfo is NativeLinearCollectionMarshallingInfo collectionMarshallingInfo
-                && collectionMarshallingInfo.ElementCountInfo is CountElementCountInfo countInfo
-                && !(info.RefKind is RefKind.Out
-                    || info.ManagedIndex is TypePositionInfo.ReturnIndex))
+            if (!(info.RefKind is RefKind.Out
+                    || info.ManagedIndex is TypePositionInfo.ReturnIndex)
+                && info.MarshallingAttributeInfo is NativeLinearCollectionMarshallingInfo collectionMarshallingInfo
+                && collectionMarshallingInfo.ElementCountInfo is CountElementCountInfo countInfo)
             {
                 if (countInfo.ElementInfo.IsByRef && countInfo.ElementInfo.RefKind is RefKind.Out)
                 {
@@ -444,6 +464,5 @@ namespace Microsoft.Interop
                 // If the parameter is multidimensional and a higher indirection level parameter is ByValue [Out], then we should warn.
             }
         }
-
     }
 }
index 4e0b3bc..9f4700e 100644 (file)
@@ -214,7 +214,7 @@ namespace Microsoft.Interop
                                         IdentifierName(newHandleObjectIdentifier)))));
                     }
                     break;
-                case StubCodeContext.Stage.Cleanup:
+                case StubCodeContext.Stage.CleanupCallerAllocated:
                     if (!info.IsManagedReturnPosition && (!info.IsByRef || info.RefKind == RefKind.In))
                     {
                         yield return IfStatement(
index edb226f..22c1c33 100644 (file)
@@ -30,8 +30,28 @@ namespace Microsoft.Interop
 
         public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated)
+                yield break;
+
+            if (!_shape.HasFlag(MarshallerShape.Free))
+                yield break;
+
+            // <marshaller>.Free();
+            yield return ExpressionStatement(
+                InvocationExpression(
+                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                        IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)),
+                        IdentifierName(ShapeMemberNames.Free)),
+                    ArgumentList()));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated)
+                yield break;
+
             if (!_shape.HasFlag(MarshallerShape.Free))
                 yield break;
 
@@ -213,9 +233,14 @@ namespace Microsoft.Interop
             return _innerMarshaller.AsNativeType(info);
         }
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
-            return _innerMarshaller.GenerateCleanupStatements(info, context);
+            return _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context);
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            return _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context);
         }
 
         public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
@@ -371,12 +396,26 @@ namespace Microsoft.Interop
         }
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
+            // We don't have anything to cleanup specifically related to this value, just the elements. We let the element marshaller decide whether to cleanup in callee or caller cleanup stage
             if (!_cleanupElements)
-            {
                 yield break;
+
+            StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);
+
+            if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
+            {
+                yield return elementCleanup;
             }
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            // We don't have anything to cleanup specifically related to this value, just the elements. We let the element marshaller decide whether to cleanup in callee or caller cleanup stage
+            if (!_cleanupElements)
+                yield break;
 
             StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);
 
@@ -385,6 +424,7 @@ namespace Microsoft.Interop
                 yield return elementCleanup;
             }
         }
+
         public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
 
         public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
@@ -504,13 +544,36 @@ namespace Microsoft.Interop
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            foreach (var statement in _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context))
+            {
+                yield return statement;
+            }
+
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated)
+                yield break;
+
+            string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
+            // <marshaller>.Free();
+            yield return ExpressionStatement(
+                InvocationExpression(
+                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                        IdentifierName(marshaller),
+                        IdentifierName(ShapeMemberNames.Free)),
+                    ArgumentList()));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
-            foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context))
+            foreach (var statement in _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context))
             {
                 yield return statement;
             }
 
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated)
+                yield break;
+
             string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
             // <marshaller>.Free();
             yield return ExpressionStatement(
index 6303abe..2ed2653 100644 (file)
@@ -33,7 +33,9 @@ namespace Microsoft.Interop
 
         public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
 
         public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
         {
@@ -159,7 +161,8 @@ namespace Microsoft.Interop
         }
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context);
         public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
 
         public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
@@ -266,9 +269,31 @@ namespace Microsoft.Interop
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
-            foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context))
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated)
+                yield break;
+
+            foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context))
+            {
+                yield return statement;
+            }
+            // <marshallerType>.Free(<nativeIdentifier>);
+            yield return ExpressionStatement(
+                InvocationExpression(
+                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                        _marshallerType,
+                        IdentifierName(ShapeMemberNames.Free)),
+                    ArgumentList(SingletonSeparatedList(
+                        Argument(IdentifierName(context.GetIdentifiers(info).native))))));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated)
+                yield break;
+
+            foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context))
             {
                 yield return statement;
             }
@@ -316,8 +341,25 @@ namespace Microsoft.Interop
             return _unmanagedType;
         }
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated)
+                yield break;
+
+            string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+            // <numElements> = <numElementsExpression>;
+            yield return ExpressionStatement(
+                AssignmentExpression(
+                    SyntaxKind.SimpleAssignmentExpression,
+                    IdentifierName(numElementsIdentifier),
+                    _numElementsExpression));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated)
+                yield break;
+
             if (MarshallerHelpers.GetMarshalDirection(info, context) == MarshalDirection.ManagedToUnmanaged)
             {
                 yield return EmptyStatement();
@@ -325,6 +367,7 @@ namespace Microsoft.Interop
             }
 
             string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+            // <numElements> = <numElementsExpression>;
             yield return ExpressionStatement(
                 AssignmentExpression(
                     SyntaxKind.SimpleAssignmentExpression,
@@ -397,6 +440,7 @@ namespace Microsoft.Interop
         public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
         public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
         {
+            // int <numElements>;
             string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
             yield return LocalDeclarationStatement(
                 VariableDeclaration(
@@ -554,12 +598,13 @@ namespace Microsoft.Interop
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _unmanagedType;
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
             if (!_cleanupElementsAndSpace)
             {
                 yield break;
             }
+
             StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);
 
             if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
@@ -567,6 +612,7 @@ namespace Microsoft.Interop
                 // If we don't have the numElements variable still available from unmarshal or marshal stage, we need to reassign that again
                 if (!context.AdditionalTemporaryStateLivesAcrossStages)
                 {
+                    // <numElements> = <numElementsExpression>;
                     string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
                     yield return ExpressionStatement(
                         AssignmentExpression(
@@ -577,9 +623,45 @@ namespace Microsoft.Interop
                 yield return elementCleanup;
             }
 
-            foreach (var statement in _spaceMarshallingStrategy.GenerateCleanupStatements(info, context))
+            if (MarshallerHelpers.GetCleanupStage(info, context) is StubCodeContext.Stage.CleanupCallerAllocated)
             {
-                yield return statement;
+                foreach (var statement in _spaceMarshallingStrategy.GenerateCleanupCallerAllocatedResourcesStatements(info, context))
+                {
+                    yield return statement;
+                }
+            }
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            if (!_cleanupElementsAndSpace)
+            {
+                yield break;
+            }
+            StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);
+
+            if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
+            {
+                // If we don't have the numElements variable still available from unmarshal or marshal stage, we need to reassign that again
+                if (!context.AdditionalTemporaryStateLivesAcrossStages)
+                {
+                    // <numElements> = <numElementsExpression>;
+                    string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
+                    yield return ExpressionStatement(
+                        AssignmentExpression(
+                            SyntaxKind.SimpleAssignmentExpression,
+                            IdentifierName(numElementsIdentifier),
+                            _numElementsExpression));
+                }
+                yield return elementCleanup;
+            }
+
+            if (MarshallerHelpers.GetCleanupStage(info, context) is StubCodeContext.Stage.CleanupCallerAllocated)
+            {
+                foreach (var statement in _spaceMarshallingStrategy.GenerateCleanupCalleeAllocatedResourcesStatements(info, context))
+                {
+                    yield return statement;
+                }
             }
         }
 
@@ -646,6 +728,10 @@ namespace Microsoft.Interop
                 // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
                 // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
                 yield return _elementsMarshalling.GenerateClearManagedValuesDestination(info, context);
+                foreach (var statement in _spaceMarshallingStrategy.GenerateUnmarshalStatements(info, context))
+                {
+                    yield return statement;
+                }
                 yield break;
             }
 
index 5b9d086..ba43659 100644 (file)
@@ -3,7 +3,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
@@ -25,7 +24,8 @@ namespace Microsoft.Interop
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context);
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context);
 
         public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
         public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
@@ -78,7 +78,7 @@ namespace Microsoft.Interop
 
     /// <summary>
     /// Marshalling strategy that uses the tracking variables introduced by <see cref="UnmanagedToManagedOwnershipTrackingStrategy"/> to cleanup the original value if the original value is owned
-    /// in the <see cref="StubCodeContext.Stage.Cleanup"/> stage.
+    /// in the <see cref="StubCodeContext.Stage.CleanupCallerAllocated"/> stage.
     /// </summary>
     internal sealed class CleanupOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy
     {
@@ -91,15 +91,30 @@ namespace Microsoft.Interop
 
         public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
 
-        public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+        public IEnumerable<StatementSyntax> GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
         {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated)
+                yield break;
             // if (<ownOriginalValue>)
             // {
             //     <cleanup>
             // }
             yield return IfStatement(
                 IdentifierName(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)),
-                Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context))));
+                Block(_innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, new OwnedValueCodeContext(context))));
+        }
+
+        public IEnumerable<StatementSyntax> GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context)
+        {
+            if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated)
+                yield break;
+            // if (<ownOriginalValue>)
+            // {
+            //     <cleanup>
+            // }
+            yield return IfStatement(
+                IdentifierName(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)),
+                Block(_innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, new OwnedValueCodeContext(context))));
         }
 
         public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
@@ -119,7 +134,7 @@ namespace Microsoft.Interop
 
     /// <summary>
     /// Marshalling strategy to cache the initial value of a given <see cref="TypePositionInfo"/> in a local variable and cleanup that value in the cleanup stage.
-    /// Useful in scenarios where the value is always owned in all code-paths that reach the <see cref="StubCodeContext.Stage.Cleanup"/> stage, so additional ownership tracking is extraneous.
+    /// Useful in scenarios where the value is always owned in all code-paths that reach the <see cref="StubCodeContext.Stage.CleanupCallerAllocated"/> stage, so additional ownership tracking is extraneous.
     /// </summary>
     internal sealed class FreeAlwaysOwnedOriginalValueGenerator : IMarshallingGenerator
     {
@@ -138,7 +153,7 @@ namespace Microsoft.Interop
                 return GenerateSetupStatements();
             }
 
-            if (context.CurrentStage == StubCodeContext.Stage.Cleanup)
+            if (context.CurrentStage == StubCodeContext.Stage.CleanupCallerAllocated)
             {
                 return GenerateStatementsFromInner(new OwnedValueCodeContext(context));
             }
index 35e07f1..50af716 100644 (file)
@@ -64,9 +64,14 @@ namespace Microsoft.Interop
             NotifyForSuccessfulInvoke,
 
             /// <summary>
-            /// Perform any cleanup required
+            /// Perform any cleanup required on caller allocated resources
             /// </summary>
-            Cleanup,
+            CleanupCallerAllocated,
+
+            /// <summary>
+            /// Perform any cleanup required on callee allocated resources
+            /// </summary>
+            CleanupCalleeAllocated,
 
             /// <summary>
             /// Convert native data to managed data even in the case of an exception during
index 5157e3c..98390d9 100644 (file)
@@ -85,6 +85,63 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
+        public void IArrayOfStatelessElements()
+        {
+            var obj = CreateWrapper<ArrayOfStatelessElements, IArrayOfStatelessElements>();
+            var data = new StatelessType[10];
+
+            // ByValueContentsOut should only free the returned values
+            var oldFreeCount = StatelessTypeMarshaller.AllFreeCount;
+            obj.MethodContentsOut(data, data.Length);
+            Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount);
+
+            // ByValueContentsOut should only free the elements after the call
+            oldFreeCount = StatelessTypeMarshaller.AllFreeCount;
+            obj.MethodContentsIn(data, data.Length);
+            Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount);
+
+            // ByValueContentsInOut should free elements in both directions
+            oldFreeCount = StatelessTypeMarshaller.AllFreeCount;
+            obj.MethodContentsInOut(data, data.Length);
+            Assert.Equal(oldFreeCount + 20, StatelessTypeMarshaller.AllFreeCount);
+        }
+
+        [Fact]
+        public void IArrayOfStatelessElementsThrows()
+        {
+            var obj = CreateWrapper<ArrayOfStatelessElementsThrows, IArrayOfStatelessElements>();
+            var data = new StatelessType[10];
+            var oldFreeCount = StatelessTypeMarshaller.AllFreeCount;
+            try
+            {
+                obj.MethodContentsOut(data, 10);
+            }
+            catch (Exception) { }
+            Assert.Equal(oldFreeCount, StatelessTypeMarshaller.AllFreeCount);
+
+            for (int i = 0; i < 10; i++)
+            {
+                data[i] = new StatelessType() { I = i };
+            }
+
+            oldFreeCount = StatelessTypeMarshaller.AllFreeCount;
+            try
+            {
+                obj.MethodContentsIn(data, 10);
+            }
+            catch (Exception) { }
+            Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount);
+
+            oldFreeCount = StatelessTypeMarshaller.AllFreeCount;
+            try
+            {
+                obj.MethodContentsInOut(data, 10);
+            }
+            catch (Exception) { }
+            Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount);
+        }
+
+        [Fact]
         public void IJaggedIntArray()
         {
             int[][] data = new int[][] { new int[] { 1, 2, 3 }, new int[] { 4, 5 }, new int[] { 6, 7, 8, 9 } };
@@ -138,35 +195,36 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
-        [ActiveIssue("https://github.com/dotnet/runtime/issues/89747")]
         public void ICollectionMarshallingFails()
         {
+            Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception);
+
             var obj = CreateWrapper<ICollectionMarshallingFailsImpl, ICollectionMarshallingFails>();
 
             Assert.Throws<MarshallingFailureException>(() =>
-                _ = obj.GetConstSize()
+                obj.Set(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 }, 10)
             );
 
-            Assert.Throws<MarshallingFailureException>(() =>
-                _ = obj.Get(out _)
+            Assert.Throws(hrExceptionType, () =>
+                _ = obj.GetConstSize()
             );
 
-            Assert.Throws<MarshallingFailureException>(() =>
-                obj.Set(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 }, 10)
+            Assert.Throws(hrExceptionType, () =>
+                _ = obj.Get(out _)
             );
         }
 
         [Fact]
-        [ActiveIssue("https://github.com/dotnet/runtime/issues/89747")]
         public void IJaggedArrayMarshallingFails()
         {
+            Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception);
             var obj = CreateWrapper<IJaggedIntArrayMarshallingFailsImpl, IJaggedIntArrayMarshallingFails>();
 
-            Assert.Throws<MarshallingFailureException>(() =>
+            Assert.Throws(hrExceptionType, () =>
                 _ = obj.GetConstSize()
             );
 
-            Assert.Throws<MarshallingFailureException>(() =>
+            Assert.Throws(hrExceptionType, () =>
                 _ = obj.Get(out _, out _)
             );
             var array = new int[][] { new int[] { 1, 2, 3 }, new int[] { 4, 5, }, new int[] { 6, 7, 8, 9 } };
index abb8842..658606d 100644 (file)
@@ -72,4 +72,16 @@ namespace SharedTypes.ComInterfaces
             }
         }
     }
+
+    [GeneratedComClass]
+    internal partial class ArrayOfStatelessElementsThrows : IArrayOfStatelessElements
+    {
+        public void Method(StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+        public void MethodContentsIn(StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+        public void MethodContentsInOut(StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+        public void MethodContentsOut(StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+        public void MethodIn(in StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+        public void MethodOut(out StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+        public void MethodRef(ref StatelessType[] param, int size) => throw new ManagedComMethodFailureException();
+    }
 }
index 37fce22..3c9951c 100644 (file)
@@ -19,6 +19,7 @@ namespace SharedTypes.ComInterfaces
         [PreserveSig]
         StatefulFinallyType ReturnPreserveSig();
     }
+
     [GeneratedComClass]
     internal partial class StatefulFinallyMarshalling : IStatefulFinallyMarshalling
     {
index 12c98b6..66ccc0c 100644 (file)
@@ -48,6 +48,8 @@ namespace SharedTypes.ComInterfaces
     [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementRef, typeof(Bidirectional))]
     internal static class StatelessTypeMarshaller
     {
+        public static int AllFreeCount => Bidirectional.FreeCount + UnmanagedToManaged.FreeCount + ManagedToUnmanaged.FreeCount;
+
         internal static class Bidirectional
         {
             public static int FreeCount { get; private set; }
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs
new file mode 100644 (file)
index 0000000..de5aea6
--- /dev/null
@@ -0,0 +1,11 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+
+namespace SharedTypes.ComInterfaces
+{
+    internal class ManagedComMethodFailureException : Exception
+    {
+    }
+}