[mono][wasm] marshal-ilgen is dropped when not required (#86035)
authorJan Dupej <109523496+jandupej@users.noreply.github.com>
Wed, 14 Jun 2023 10:51:44 +0000 (12:51 +0200)
committerGitHub <noreply@github.com>
Wed, 14 Jun 2023 10:51:44 +0000 (12:51 +0200)
* Replicating naricc's PInvokeScanner.

* MarshalingPInvokeScanner now detects and outputs incompatible assemblies.

* Allowing void return type.

* PInvoke

* Two-pass searching in progress.

* Second pass resolves inconclusive types.

* Cleanup.

* Modifying the wasm toolchain to omit marshal-ilgen when possible.

* Hopefully fix incorrect app dir.

* Added definitions to MarshalingPInvokeScannerPath hopefully where needed.

* Adding definitions of MarshalingPInvokeScannerPath to more locations.

* Adding missing references to PInvoke scanner.

* Changed task ordering, assemblies list.

* Removed metadata load context.

* Fixed code analyzer issues.

* Fixed file name.

* Moved MarshalingPInvokeScanner to MonoTargetsTask.

* Removed BlazorApp.

* Implemented more marshaling validation rules, removed warning message that got "promoted" to an error.

* Catching bad image exceptions, giving reason for requiring marshal-ilgen.

* Cleaned up references to standalone MarshalingPInvokeScanner project, now that the analyzer is in MonoTargtesTask.

* More cleanup.

* Fixed P/Invoke return value in marshal-lightweight.

* Removed incompatible assemblies listing.

* Restoring minimal functionality to marshal-ilgen-stub.

* Addressed feedback.

* Tweaked identification of blittable types. Added explanation to Compatibility enum.

* Moved PInvokeCollector.cs and SignatureMapper.cs back to WasmAppBuilder.

* Addressed feedback.

src/mono/mono/component/marshal-ilgen-stub.c
src/mono/mono/metadata/marshal-lightweight.c
src/mono/wasm/build/WasmApp.Native.targets
src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs [new file with mode: 0644]
src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs [new file with mode: 0644]
src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj
src/tasks/WasmAppBuilder/PInvokeCollector.cs [new file with mode: 0644]
src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs

index 39a9b4e..277261c 100644 (file)
@@ -9,11 +9,46 @@ marshal_ilgen_available (void)
        return false;
 }
 
+static void emit_throw_exception (MonoMarshalLightweightCallbacks* lightweight_cb, 
+               MonoMethodBuilder* mb, const char* exc_nspace, const char* exc_name, const char* msg)
+{
+       lightweight_cb->mb_emit_exception (mb, exc_nspace, exc_name, msg);
+}
+
 static int
-stub_emit_marshal_ilgen (EmitMarshalContext *m, int argnum, MonoType *t,
-               MonoMarshalSpec *spec, int conv_arg,    
-               MonoType **conv_arg_type, MarshalAction action, MonoMarshalLightweightCallbacks* lightweight_cb)
+stub_emit_marshal_ilgen (EmitMarshalContext* m, int argnum, MonoType* t,
+               MonoMarshalSpecspec, int conv_arg,    
+               MonoType** conv_arg_type, MarshalAction action, MonoMarshalLightweightCallbacks* lightweight_cb)
 {
+       if (spec) {
+               g_assert (spec->native != MONO_NATIVE_ASANY);
+               g_assert (spec->native != MONO_NATIVE_CUSTOM);
+       }
+       
+       g_assert (!m_type_is_byref(t));
+
+       switch (t->type) {
+       case MONO_TYPE_PTR:
+       case MONO_TYPE_I1:
+       case MONO_TYPE_U1:
+       case MONO_TYPE_I2:
+       case MONO_TYPE_U2:
+       case MONO_TYPE_I4:
+       case MONO_TYPE_U4:
+       case MONO_TYPE_I:
+       case MONO_TYPE_U:
+       case MONO_TYPE_R4:
+       case MONO_TYPE_R8:
+       case MONO_TYPE_I8:
+       case MONO_TYPE_U8:
+       case MONO_TYPE_FNPTR:
+               return lightweight_cb->emit_marshal_scalar (m, argnum, t, spec, conv_arg, conv_arg_type, action);
+       default:
+               emit_throw_exception (lightweight_cb, m->mb, "System", "ApplicationException",
+                       g_strdup("Cannot marshal nonblittlable types without marshal-ilgen."));
+               break;
+       }
+
        return 0;
 }
 
index a4fb731..8871b2f 100644 (file)
@@ -523,7 +523,7 @@ emit_runtime_invoke_body_ilgen (MonoMethodBuilder *mb, const char **param_names,
        emit_thread_force_interrupt_checkpoint (mb);
        emit_invoke_call (mb, method, sig, callsig, loc_res, virtual_, need_direct_wrapper);
 
-       mono_mb_emit_ldloc (mb, 0);
+       mono_mb_emit_ldloc (mb, loc_res);
        mono_mb_emit_byte (mb, CEE_RET);
 }
 
index 85b5e18..8093fb9 100644 (file)
@@ -3,9 +3,11 @@
 
   <UsingTask TaskName="Microsoft.WebAssembly.Build.Tasks.ManagedToNativeGenerator" AssemblyFile="$(WasmAppBuilderTasksAssemblyPath)" />
   <UsingTask TaskName="Microsoft.WebAssembly.Build.Tasks.EmccCompile" AssemblyFile="$(WasmAppBuilderTasksAssemblyPath)" />
+  <UsingTask TaskName="MonoTargetsTasks.MarshalingPInvokeScanner" AssemblyFile="$(MonoTargetsTasksAssemblyPath)" />
 
   <PropertyGroup>
     <_WasmBuildNativeCoreDependsOn>
+      _ScanAssembliesDecideLightweightMarshaler;
       _WasmAotCompileApp;
       _WasmStripAOTAssemblies;
       _PrepareForWasmBuildNative;
@@ -33,7 +35,7 @@
   <ItemGroup Condition="'$(Configuration)' == 'Debug' and '@(_MonoComponent->Count())' == 0">
     <_MonoComponent Include="hot_reload;debugger" />
   </ItemGroup>
-  <ItemGroup>
+  <ItemGroup Condition="'@(MonoLightweightMarshallerIncompatibleAssemblies->Count())' > 0">
     <_MonoComponent Include="marshal-ilgen" />
   </ItemGroup>
 
     </ItemGroup>
   </Target>
 
+  <Target Name="_ScanAssembliesDecideLightweightMarshaler">
+     <ItemGroup>
+      <AssembliesToScan Include="@(_WasmAssembliesInternal)" />
+    </ItemGroup>
+    
+    <MarshalingPInvokeScanner Assemblies ="@(AssembliesToScan)">
+      <Output TaskParameter="IncompatibleAssemblies" ItemName="MonoLightweightMarshallerIncompatibleAssemblies" />
+    </MarshalingPInvokeScanner>
+  </Target>
+
   <!-- '$(ArchiveTests)' != 'true' is to skip on CI for now -->
   <Target Name="_WasmStripAOTAssemblies" Condition="'$(_WasmShouldAOT)' == 'true' and '$(WasmStripAOTAssemblies)' == 'true' and '$(AOTMode)' != 'LLVMOnlyInterp' and '$(ArchiveTests)' != 'true'">
     <PropertyGroup>
diff --git a/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs
new file mode 100644 (file)
index 0000000..de5e6e3
--- /dev/null
@@ -0,0 +1,157 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection.Metadata;
+using System.Reflection.Metadata.Ecma335;
+using System.Collections.Immutable;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Reflection;
+using System.Reflection.PortableExecutable;
+using Microsoft.Build.Framework;
+using Microsoft.Build.Utilities;
+
+namespace MonoTargetsTasks
+{
+    public class MarshalingPInvokeScanner : Task
+    {
+        [Required]
+        public string[] Assemblies { get; set; } = Array.Empty<string>();
+
+        [Output]
+        public string[]? IncompatibleAssemblies { get; private set; }
+
+        public override bool Execute()
+        {
+            if (Assemblies is null || Assemblies!.Length == 0)
+            {
+                Log.LogError($"{nameof(MarshalingPInvokeScanner)}.{nameof(Assemblies)} cannot be empty");
+                return false;
+            }
+
+            try
+            {
+                ExecuteInternal();
+                return !Log.HasLoggedErrors;
+            }
+            catch (LogAsErrorException e)
+            {
+                Log.LogError(e.Message);
+                return false;
+            }
+        }
+
+        private void ExecuteInternal()
+        {
+            IncompatibleAssemblies = ScanAssemblies(Assemblies);
+        }
+
+        private string[] ScanAssemblies(string[] assemblies)
+        {
+            HashSet<string> incompatible = new HashSet<string>();
+            MinimalMarshalingTypeCompatibilityProvider mmtcp = new(Log);
+            foreach (string aname in assemblies)
+            {
+                if (IsAssemblyIncompatible(aname, mmtcp))
+                    incompatible.Add(aname);
+            }
+
+            if (mmtcp.IsSecondPassNeeded)
+            {
+                foreach (string aname in assemblies)
+                    ResolveInconclusiveTypes(incompatible, aname, mmtcp);
+            }
+
+            return incompatible.ToArray();
+        }
+
+        private static string GetMethodName(MetadataReader mr, MethodDefinition md) => mr.GetString(md.Name);
+
+        private void ResolveInconclusiveTypes(HashSet<string> incompatible, string assyPath, MinimalMarshalingTypeCompatibilityProvider mmtcp)
+        {
+            string assyName = MetadataReader.GetAssemblyName(assyPath).Name!;
+            HashSet<string> inconclusiveTypes = mmtcp.GetInconclusiveTypesForAssembly(assyName);
+            if(inconclusiveTypes.Count == 0)
+                return;
+
+            using FileStream file = new FileStream(assyPath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
+            using PEReader peReader = new PEReader(file);
+            MetadataReader mdtReader = peReader.GetMetadataReader();
+
+            SignatureDecoder<Compatibility, object> decoder = new(mmtcp, mdtReader, null!);
+
+            foreach (TypeDefinitionHandle typeDefHandle in mdtReader.TypeDefinitions)
+            {
+                TypeDefinition typeDef = mdtReader.GetTypeDefinition(typeDefHandle);
+                string fullTypeName = string.Join(":", mdtReader.GetString(typeDef.Namespace), mdtReader.GetString(typeDef.Name));
+
+                // This is not perfect, but should work right for enums defined in other assemblies,
+                // which is the only case where we use Compatibility.Inconclusive.
+                if (inconclusiveTypes.Contains(fullTypeName) &&
+                    mmtcp.GetTypeFromDefinition(mdtReader, typeDefHandle, 0) != Compatibility.Compatible)
+                {
+                    Log.LogMessage(MessageImportance.Low, string.Format("Type {0} is marshaled and requires marshal-ilgen.", fullTypeName));
+
+                    incompatible.Add("(unknown assembly)");
+                }
+            }
+        }
+
+        private bool IsAssemblyIncompatible(string assyPath, MinimalMarshalingTypeCompatibilityProvider mmtcp)
+        {
+            using FileStream file = new FileStream(assyPath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
+            using PEReader peReader = new PEReader(file);
+            MetadataReader mdtReader = peReader.GetMetadataReader();
+
+            foreach(CustomAttributeHandle attrHandle in mdtReader.CustomAttributes)
+            {
+                CustomAttribute attr = mdtReader.GetCustomAttribute(attrHandle);
+
+                if(attr.Constructor.Kind == HandleKind.MethodDefinition)
+                {
+                    MethodDefinitionHandle mdh = (MethodDefinitionHandle)attr.Constructor;
+                    MethodDefinition md = mdtReader.GetMethodDefinition(mdh);
+                    TypeDefinitionHandle tdh = md.GetDeclaringType();
+                    TypeDefinition td = mdtReader.GetTypeDefinition(tdh);
+
+                    if(mdtReader.GetString(td.Namespace) == "System.Runtime.CompilerServices" &&
+                        mdtReader.GetString(td.Name) == "DisableRuntimeMarshallingAttribute")
+                        return false;
+                }
+            }
+
+            foreach (TypeDefinitionHandle typeDefHandle in mdtReader.TypeDefinitions)
+            {
+                TypeDefinition typeDef = mdtReader.GetTypeDefinition(typeDefHandle);
+                string ns = mdtReader.GetString(typeDef.Namespace);
+                string name = mdtReader.GetString(typeDef.Name);
+
+                foreach(MethodDefinitionHandle mthDefHandle in typeDef.GetMethods())
+                {
+                    MethodDefinition mthDef = mdtReader.GetMethodDefinition(mthDefHandle);
+                    if(!mthDef.Attributes.HasFlag(MethodAttributes.PinvokeImpl))
+                        continue;
+
+                    BlobReader sgnBlobReader = mdtReader.GetBlobReader(mthDef.Signature);
+                    SignatureDecoder<Compatibility, object> decoder = new(mmtcp, mdtReader, null!);
+
+                    MethodSignature<Compatibility> sgn = decoder.DecodeMethodSignature(ref sgnBlobReader);
+                    if(sgn.ReturnType == Compatibility.Incompatible || sgn.ParameterTypes.Any(p => p == Compatibility.Incompatible))
+                    {
+                        Log.LogMessage(MessageImportance.Low, string.Format("Assembly {0} requires marhsal-ilgen for method {1}.{2}:{3} (first pass).",
+                            assyPath, ns, name, mdtReader.GetString(mthDef.Name)));
+
+                        return true;
+                    }
+                }
+            }
+
+            return false;
+        }
+    }
+}
diff --git a/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs
new file mode 100644 (file)
index 0000000..8916c1d
--- /dev/null
@@ -0,0 +1,167 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection.Metadata;
+using System.Reflection.Metadata.Ecma335;
+using System.Collections.Immutable;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Reflection;
+using System.Reflection.PortableExecutable;
+using Microsoft.Build.Framework;
+using Microsoft.Build.Utilities;
+
+namespace MonoTargetsTasks
+{
+    // For some valuetypes we cannot determine if they are compatible with disabled
+    // runtime marshaling without first resolving their base types. In this case we
+    // first mark the assembly as Inconclusive and do a second pass over the collected
+    // base type references in order to decide. If the base types are System.Enum,
+    // then the valuetypes are enumerations, and are compatible.
+    internal enum Compatibility
+    {
+        Compatible,
+        Incompatible,
+        Inconclusive
+    }
+
+    internal sealed class InconclusiveCompatibilityCollection
+    {
+        private readonly Dictionary<string, HashSet<string>> _data = new();
+
+        public bool IsEmpty => _data.Count == 0;
+
+        public void Add(string assyName, string namespaceName, string typeName)
+        {
+            HashSet<string>? incAssyTypes;
+
+            if(!_data.TryGetValue(assyName, out incAssyTypes))
+            {
+                incAssyTypes = new();
+                _data.Add(assyName, incAssyTypes);
+            }
+
+            incAssyTypes.Add($"{namespaceName}:{typeName}");
+        }
+
+        public HashSet<string> EnumerateForAssembly(string assyName)
+        {
+            if(_data.TryGetValue(assyName, out HashSet<string>? incAssyTypes))
+                return incAssyTypes!;
+
+            return new HashSet<string>();
+        }
+    }
+
+    internal sealed class MinimalMarshalingTypeCompatibilityProvider : ISignatureTypeProvider<Compatibility, object>
+    {
+        internal MinimalMarshalingTypeCompatibilityProvider(TaskLoggingHelper log)
+        {
+          _log = log;
+        }
+
+        private readonly TaskLoggingHelper _log;
+
+        // assembly name -> set of types needed for second pass
+        private readonly InconclusiveCompatibilityCollection _inconclusive = new();
+
+        public bool IsSecondPassNeeded => !_inconclusive.IsEmpty;
+        public HashSet<string> GetInconclusiveTypesForAssembly(string assyName) => _inconclusive.EnumerateForAssembly(assyName);
+
+        public Compatibility GetArrayType(Compatibility elementType, ArrayShape shape) => Compatibility.Incompatible;
+        public Compatibility GetByReferenceType(Compatibility elementType) => Compatibility.Incompatible;
+        public Compatibility GetFunctionPointerType(MethodSignature<Compatibility> signature) => Compatibility.Compatible;
+        public Compatibility GetGenericInstantiation(Compatibility genericType, ImmutableArray<Compatibility> typeArguments) => genericType;
+        public Compatibility GetGenericMethodParameter(object genericContext, int index) => Compatibility.Incompatible;
+        public Compatibility GetGenericTypeParameter(object genericContext, int index) => Compatibility.Incompatible;
+        public Compatibility GetModifiedType(Compatibility modifier, Compatibility unmodifiedType, bool isRequired) => Compatibility.Incompatible;
+        public Compatibility GetPinnedType(Compatibility elementType) => Compatibility.Compatible;
+        public Compatibility GetPointerType(Compatibility elementType) => Compatibility.Compatible;
+        public Compatibility GetPrimitiveType(PrimitiveTypeCode typeCode)
+        {
+            return typeCode switch
+            {
+            PrimitiveTypeCode.Object => Compatibility.Incompatible,
+            PrimitiveTypeCode.String => Compatibility.Incompatible,
+            PrimitiveTypeCode.TypedReference => Compatibility.Incompatible,
+            _ => Compatibility.Compatible
+            };
+        }
+
+        public Compatibility GetSZArrayType(Compatibility elementType) => Compatibility.Incompatible;
+
+        public Compatibility GetTypeFromDefinition(MetadataReader reader, TypeDefinitionHandle handle, byte rawTypeKind)
+        {
+            TypeDefinition typeDef = reader.GetTypeDefinition(handle);
+            if (reader.GetString(typeDef.Namespace) == "System" &&
+                reader.GetString(typeDef.Name) == "Enum")
+                return Compatibility.Compatible;
+
+            try
+            {
+                EntityHandle baseTypeHandle = typeDef.BaseType;
+                if (baseTypeHandle.Kind == HandleKind.TypeReference)
+                {
+                    TypeReference baseType = reader.GetTypeReference((TypeReferenceHandle)baseTypeHandle);
+                    if (reader.GetString(typeDef.Namespace) == "System" &&
+                        reader.GetString(baseType.Name) == "Enum")
+                        return Compatibility.Compatible;
+                }
+                else if (baseTypeHandle.Kind == HandleKind.TypeSpecification)
+                {
+                    TypeSpecification specInner = reader.GetTypeSpecification((TypeSpecificationHandle)baseTypeHandle);
+                    return specInner.DecodeSignature<Compatibility, object>(this, new object());
+                }
+                else if (baseTypeHandle.Kind == HandleKind.TypeDefinition)
+                {
+                    TypeDefinitionHandle handleInner = (TypeDefinitionHandle)baseTypeHandle;
+                    if (handle != handleInner)
+                        return GetTypeFromDefinition(reader, handleInner, rawTypeKind);
+                }
+            }
+            catch(BadImageFormatException ex)
+            {
+                _log.LogMessage(MessageImportance.Low, ex.Message);
+            }
+
+            return Compatibility.Incompatible;
+        }
+
+        public Compatibility GetTypeFromReference(MetadataReader reader, TypeReferenceHandle handle, byte rawTypeKind)
+        {
+            if (rawTypeKind == 0x11 /*ELEMENT_TYPE_VALUETYPE*/)
+            {
+                TypeReference typeRef = reader.GetTypeReference(handle);
+                EntityHandle scope = typeRef.ResolutionScope;
+
+                if (scope.Kind == HandleKind.AssemblyReference)
+                {
+                    AssemblyReferenceHandle assyRefHandle = (AssemblyReferenceHandle)typeRef.ResolutionScope;
+                    AssemblyReference assyRef = reader.GetAssemblyReference(assyRefHandle);
+
+                    _inconclusive.Add(assyName: reader.GetString(assyRef.Name),
+                        namespaceName: reader.GetString(typeRef.Namespace), typeName: reader.GetString(typeRef.Name));
+                    return Compatibility.Inconclusive;
+                }
+                else
+                {
+                    throw new NotImplementedException(string.Format("Unsupported ResolutionScope kind '{0}' used in type {1}:{2}.",
+                        scope.Kind.ToString(), reader.GetString(typeRef.Namespace), reader.GetString(typeRef.Name)));
+                }
+            }
+
+            return Compatibility.Incompatible;
+        }
+
+        public Compatibility GetTypeFromSpecification(MetadataReader reader, object genericContext, TypeSpecificationHandle handle, byte rawTypeKind)
+        {
+            TypeSpecification spec = reader.GetTypeSpecification((TypeSpecificationHandle)handle);
+            return spec.DecodeSignature<Compatibility, object>(this, genericContext);
+        }
+    }
+}
index 181538f..340cc4e 100644 (file)
@@ -17,6 +17,7 @@
     <PackageReference Include="System.Reflection.Metadata" Version="$(SystemReflectionMetadataVersion)" PrivateAssets="All" />
     <!-- These versions should not be newer than what Visual Studio MSBuild uses -->
     <PackageReference Include="System.Threading.Tasks.Extensions" Version="$(SystemThreadingTasksExtensionsVersion)" PrivateAssets="all" />
+    <PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
   </ItemGroup>
 
   <ItemGroup>
@@ -30,6 +31,8 @@
     <Compile Include="JsonToItemsTaskFactory\JsonToItemsTaskFactory.cs" />
     <Compile Include="NetTraceToMibcConverterTask\NetTraceToMibcConverter.cs" />
     <Compile Include="..\Common\LogAsErrorException.cs" />
+    <Compile Include="MarshalingPInvokeScanner\MinimalMarshalingTypeCompatibilityProvider.cs" />
+    <Compile Include="MarshalingPInvokeScanner\MarshalingPInvokeScanner.cs" />
     <Compile Include="$(RepoRoot)src\libraries\System.Private.CoreLib\src\System\Diagnostics\CodeAnalysis\NullableAttributes.cs" Condition="'$(TargetFrameworkIdentifier)' == '.NETFramework'" />
   </ItemGroup>
 
diff --git a/src/tasks/WasmAppBuilder/PInvokeCollector.cs b/src/tasks/WasmAppBuilder/PInvokeCollector.cs
new file mode 100644 (file)
index 0000000..0eeecab
--- /dev/null
@@ -0,0 +1,250 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System;
+using System.Linq;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection;
+using Microsoft.Build.Framework;
+using Microsoft.Build.Utilities;
+using Microsoft.Build.Tasks;
+
+#pragma warning disable CA1067
+#pragma warning disable CS0649
+internal sealed class PInvoke : IEquatable<PInvoke>
+#pragma warning restore CA1067
+{
+    public PInvoke(string entryPoint, string module, MethodInfo method)
+    {
+        EntryPoint = entryPoint;
+        Module = module;
+        Method = method;
+    }
+
+    public string EntryPoint;
+    public string Module;
+    public MethodInfo Method;
+    public bool Skip;
+
+    public bool Equals(PInvoke? other)
+        => other != null &&
+            string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) &&
+            string.Equals(Module, other.Module, StringComparison.Ordinal) &&
+            string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal);
+
+    public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}";
+}
+#pragma warning restore CS0649
+
+internal sealed class PInvokeComparer : IEqualityComparer<PInvoke>
+{
+    public bool Equals(PInvoke? x, PInvoke? y)
+    {
+        if (x == null && y == null)
+            return true;
+        if (x == null || y == null)
+            return false;
+
+        return x.Equals(y);
+    }
+
+    public int GetHashCode(PInvoke pinvoke)
+        => $"{pinvoke.EntryPoint}{pinvoke.Module}{pinvoke.Method}".GetHashCode();
+}
+
+
+internal sealed class PInvokeCollector {
+    private readonly Dictionary<Assembly, bool> _assemblyDisableRuntimeMarshallingAttributeCache = new();
+    private TaskLoggingHelper Log { get; init; }
+
+    public PInvokeCollector(TaskLoggingHelper log)
+    {
+        Log = log;
+    }
+
+    public void CollectPInvokes(List<PInvoke> pinvokes, List<PInvokeCallback> callbacks, List<string> signatures, Type type)
+    {
+        foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
+        {
+            try
+            {
+                CollectPInvokesForMethod(method);
+                if (DoesMethodHaveCallbacks(method))
+                    callbacks.Add(new PInvokeCallback(method));
+            }
+            catch (Exception ex) when (ex is not LogAsErrorException)
+            {
+                Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
+                        $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'");
+            }
+        }
+
+        if (HasAttribute(type, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute"))
+        {
+            var method = type.GetMethod("Invoke");
+
+            if (method != null)
+            {
+                string? signature = SignatureMapper.MethodToSignature(method!);
+                if (signature == null)
+                    throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
+
+
+                Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
+                signatures.Add(signature);
+            }
+        }
+
+        void CollectPInvokesForMethod(MethodInfo method)
+        {
+            if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
+            {
+                var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
+                var module = (string)dllimport.ConstructorArguments[0].Value!;
+                var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!;
+                pinvokes.Add(new PInvoke(entrypoint, module, method));
+
+                string? signature = SignatureMapper.MethodToSignature(method);
+                if (signature == null)
+                {
+                    throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
+                }
+
+                Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
+                signatures.Add(signature);
+            }
+        }
+
+        bool DoesMethodHaveCallbacks(MethodInfo method)
+        {
+            if (!MethodHasCallbackAttributes(method))
+                return false;
+
+            if (TryIsMethodGetParametersUnsupported(method, out string? reason))
+            {
+                Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
+                        $"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'.");
+                return false;
+            }
+
+            if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly))
+                return true;
+
+            // No DisableRuntimeMarshalling attribute, so check if the params/ret-type are
+            // blittable
+            bool isVoid = method.ReturnType.FullName == "System.Void";
+            if (!isVoid && !IsBlittable(method.ReturnType))
+                Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
+
+            foreach (var p in method.GetParameters())
+            {
+                if (!IsBlittable(p.ParameterType))
+                    Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
+            }
+
+            return true;
+        }
+
+        static bool MethodHasCallbackAttributes(MethodInfo method)
+        {
+            foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
+            {
+                try
+                {
+                    if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
+                        cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
+                    {
+                        return true;
+                    }
+                }
+                catch
+                {
+                    // Assembly not found, ignore
+                }
+            }
+
+            return false;
+        }
+    }
+
+    public static bool IsBlittable(Type type)
+    {
+        if (type.IsPrimitive || type.IsByRef || type.IsPointer || type.IsEnum)
+            return true;
+        else
+            return false;
+    }
+
+    private static void Error(string msg) => throw new LogAsErrorException(msg);
+
+    private static bool HasAttribute(MemberInfo element, params string[] attributeNames)
+    {
+        foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element))
+        {
+            try
+            {
+                for (int i = 0; i < attributeNames.Length; ++i)
+                {
+                    if (cattr.AttributeType.FullName == attributeNames [i] ||
+                        cattr.AttributeType.Name == attributeNames[i])
+                    {
+                        return true;
+                    }
+                }
+            }
+            catch
+            {
+                // Assembly not found, ignore
+            }
+        }
+        return false;
+    }
+
+    private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotNullWhen(true)] out string? reason)
+    {
+        try
+        {
+            method.GetParameters();
+        }
+        catch (NotSupportedException nse)
+        {
+            reason = nse.Message;
+            return true;
+        }
+        catch
+        {
+            // not concerned with other exceptions
+        }
+
+        reason = null;
+        return false;
+    }
+
+    private bool HasAssemblyDisableRuntimeMarshallingAttribute(Assembly assembly)
+    {
+        if (!_assemblyDisableRuntimeMarshallingAttributeCache.TryGetValue(assembly, out var value))
+        {
+            _assemblyDisableRuntimeMarshallingAttributeCache[assembly] = value = assembly
+                .GetCustomAttributesData()
+                .Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute");
+        }
+
+       value = assembly.GetCustomAttributesData().Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute");
+
+        return value;
+    }
+}
+
+#pragma warning disable CS0649
+internal sealed class PInvokeCallback
+{
+    public PInvokeCallback(MethodInfo method)
+    {
+        Method = method;
+    }
+
+    public MethodInfo Method;
+    public string? EntryName;
+}
+#pragma warning restore CS0649
index b8ecfa7..6bec3a9 100644 (file)
@@ -35,6 +35,8 @@ internal sealed class PInvokeTableGenerator
         var pinvokes = new List<PInvoke>();
         var callbacks = new List<PInvokeCallback>();
 
+        PInvokeCollector pinvokeCollector = new(Log);
+
         var resolver = new PathAssemblyResolver(assemblies);
         using var mlc = new MetadataLoadContext(resolver, "System.Private.CoreLib");
 
@@ -46,7 +48,7 @@ internal sealed class PInvokeTableGenerator
             Log.LogMessage(MessageImportance.Low, $"Loading {asmPath} to scan for pinvokes");
             var a = mlc.LoadFromAssemblyPath(asmPath);
             foreach (var type in a.GetTypes())
-                CollectPInvokes(pinvokes, callbacks, signatures, type);
+                pinvokeCollector.CollectPInvokes(pinvokes, callbacks, signatures, type);
         }
 
         string tmpFileName = Path.GetTempFileName();
@@ -71,111 +73,6 @@ internal sealed class PInvokeTableGenerator
         return signatures;
     }
 
-    private void CollectPInvokes(List<PInvoke> pinvokes, List<PInvokeCallback> callbacks, List<string> signatures, Type type)
-    {
-        foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
-        {
-            try
-            {
-                CollectPInvokesForMethod(method);
-                if (DoesMethodHaveCallbacks(method))
-                    callbacks.Add(new PInvokeCallback(method));
-            }
-            catch (Exception ex) when (ex is not LogAsErrorException)
-            {
-                Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
-                        $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'");
-            }
-        }
-
-        if (HasAttribute(type, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute"))
-        {
-            var method = type.GetMethod("Invoke");
-
-            if (method != null)
-            {
-                string? signature = SignatureMapper.MethodToSignature(method!);
-                if (signature == null)
-                    throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
-
-
-                Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
-                signatures.Add(signature);
-            }
-        }
-
-        void CollectPInvokesForMethod(MethodInfo method)
-        {
-            if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
-            {
-                var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
-                var module = (string)dllimport.ConstructorArguments[0].Value!;
-                var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!;
-                pinvokes.Add(new PInvoke(entrypoint, module, method));
-
-                string? signature = SignatureMapper.MethodToSignature(method);
-                if (signature == null)
-                {
-                    throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
-                }
-
-                Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
-                signatures.Add(signature);
-            }
-        }
-
-        bool DoesMethodHaveCallbacks(MethodInfo method)
-        {
-            if (!MethodHasCallbackAttributes(method))
-                return false;
-
-            if (TryIsMethodGetParametersUnsupported(method, out string? reason))
-            {
-                Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
-                        $"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'.");
-                return false;
-            }
-
-            if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly))
-                return true;
-
-            // No DisableRuntimeMarshalling attribute, so check if the params/ret-type are
-            // blittable
-            bool isVoid = method.ReturnType.FullName == "System.Void";
-            if (!isVoid && !IsBlittable(method.ReturnType))
-                Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
-
-            foreach (var p in method.GetParameters())
-            {
-                if (!IsBlittable(p.ParameterType))
-                    Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
-            }
-
-            return true;
-        }
-
-        static bool MethodHasCallbackAttributes(MethodInfo method)
-        {
-            foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
-            {
-                try
-                {
-                    if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
-                        cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
-                    {
-                        return true;
-                    }
-                }
-                catch
-                {
-                    // Assembly not found, ignore
-                }
-            }
-
-            return false;
-        }
-    }
-
     private static bool HasAttribute(MemberInfo element, params string[] attributeNames)
     {
         foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element))
@@ -516,55 +413,3 @@ internal sealed class PInvokeTableGenerator
 
     private static void Error(string msg) => throw new LogAsErrorException(msg);
 }
-
-#pragma warning disable CA1067
-internal sealed class PInvoke : IEquatable<PInvoke>
-#pragma warning restore CA1067
-{
-    public PInvoke(string entryPoint, string module, MethodInfo method)
-    {
-        EntryPoint = entryPoint;
-        Module = module;
-        Method = method;
-    }
-
-    public string EntryPoint;
-    public string Module;
-    public MethodInfo Method;
-    public bool Skip;
-
-    public bool Equals(PInvoke? other)
-        => other != null &&
-            string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) &&
-            string.Equals(Module, other.Module, StringComparison.Ordinal) &&
-            string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal);
-
-    public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}";
-}
-
-internal sealed class PInvokeComparer : IEqualityComparer<PInvoke>
-{
-    public bool Equals(PInvoke? x, PInvoke? y)
-    {
-        if (x == null && y == null)
-            return true;
-        if (x == null || y == null)
-            return false;
-
-        return x.Equals(y);
-    }
-
-    public int GetHashCode(PInvoke pinvoke)
-        => $"{pinvoke.EntryPoint}{pinvoke.Module}{pinvoke.Method}".GetHashCode();
-}
-
-internal sealed class PInvokeCallback
-{
-    public PInvokeCallback(MethodInfo method)
-    {
-        Method = method;
-    }
-
-    public MethodInfo Method;
-    public string? EntryName;
-}