[wasm] Correctly escape library names when generating symbols for .c files (#79007)
author Ankit Jain <radical@gmail.com>
Wed, 30 Nov 2022 23:31:32 +0000 (18:31 -0500)
committer GitHub <noreply@github.com>
Wed, 30 Nov 2022 23:31:32 +0000 (18:31 -0500)
* [wasm] Correctly escape library names when generating symbols for .c files.
Use the existing `FixupSymbolName` method for fixing library names too,
when converting to symbols.

* [wasm] *TableGenerator task: Cache the symbol name fixups
.. as the fixup method is called frequently, and often with repeated strings.
For a consolewasm template build, we get 490 calls, but only 140 of them
are for unique strings.

* Add tests

Fixes https://github.com/dotnet/runtime/issues/78992 .

src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs
src/tasks/WasmAppBuilder/IcallTableGenerator.cs
src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs
src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs

index 4c42a30..5473667 100644 (file)
@@ -4,6 +4,7 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Text;
 using Xunit;
 using Xunit.Abstractions;
 
@@ -551,6 +552,72 @@ namespace Wasm.Build.Tests
             Assert.Contains("square: 25", output);
         }
 
+        [Theory]
+        [BuildAndRun(host: RunHost.Chrome, parameters: new object[] { new object[] {
+                "with-hyphen",
+                "with#hash-and-hyphen",
+                "with.per.iod",
+                "with🚀unicode#"
+            } })]
+
+        public void CallIntoLibrariesWithNonAlphanumericCharactersInTheirNames(BuildArgs buildArgs, string[] libraryNames, RunHost host, string id)
+        {
+            buildArgs = ExpandBuildArgs(buildArgs,
+                                        extraItems: @$"<NativeFileReference Include=""*.c"" />",
+                                        extraProperties: buildArgs.AOT
+                                                            ? string.Empty
+                                                            : "<WasmBuildNative>true</WasmBuildNative>");
+
+            int baseArg = 10;
+            (_, string output) = BuildProject(buildArgs,
+                                        id: id,
+                                        new BuildProjectOptions(
+                                            InitProject: () => GenerateSourceFiles(_projectDir!, baseArg),
+                                            Publish: buildArgs.AOT,
+                                            DotnetWasmFromRuntimePack: false
+                                            ));
+
+            output = RunAndTestWasmApp(buildArgs,
+                                       buildDir: _projectDir,
+                                       expectedExitCode: 42,
+                                       host: host,
+                                       id: id);
+
+            for (int i = 0; i < libraryNames.Length; i ++)
+            {
+                Assert.Contains($"square_{i}: {(i + baseArg) * (i + baseArg)}", output);
+            }
+
+            void GenerateSourceFiles(string outputPath, int baseArg)
+            {
+                StringBuilder csBuilder = new($@"
+                    using System;
+                    using System.Runtime.InteropServices;
+                ");
+
+                StringBuilder dllImportsBuilder = new();
+                for (int i = 0; i < libraryNames.Length; i ++)
+                {
+                    dllImportsBuilder.AppendLine($"[DllImport(\"{libraryNames[i]}\")] static extern int square_{i}(int x);");
+                    csBuilder.AppendLine($@"Console.WriteLine($""square_{i}: {{square_{i}({i + baseArg})}}"");");
+
+                    string nativeCode = $@"
+                        #include <stdarg.h>
+
+                        int square_{i}(int x)
+                        {{
+                            return x * x;
+                        }}";
+                    File.WriteAllText(Path.Combine(outputPath, $"{libraryNames[i]}.c"), nativeCode);
+                }
+
+                csBuilder.AppendLine("return 42;");
+                csBuilder.Append(dllImportsBuilder);
+
+                File.WriteAllText(Path.Combine(outputPath, "Program.cs"), csBuilder.ToString());
+            }
+        }
+
         private (BuildArgs, string) BuildForVariadicFunctionTests(string programText, BuildArgs buildArgs, string id, string? verbosity = null, string extraProperties = "")
         {
             extraProperties += "<AllowUnsafeBlocks>true</AllowUnsafeBlocks><_WasmDevel>true</_WasmDevel>";
index c40b6b4..ba6f900 100644 (file)
@@ -3,9 +3,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Diagnostics;
-using System.Diagnostics.CodeAnalysis;
 using System.IO;
 using System.Linq;
 using System.Text;
@@ -23,8 +20,13 @@ internal sealed class IcallTableGenerator
     private Dictionary<string, IcallClass> _runtimeIcalls = new Dictionary<string, IcallClass>();
 
     private TaskLoggingHelper Log { get; set; }
+    private readonly Func<string, string> _fixupSymbolName;
 
-    public IcallTableGenerator(TaskLoggingHelper log) => Log = log;
+    public IcallTableGenerator(Func<string, string> fixupSymbolName, TaskLoggingHelper log)
+    {
+        Log = log;
+        _fixupSymbolName = fixupSymbolName;
+    }
 
     //
     // Given the runtime generated icall table, and a set of assemblies, generate
@@ -86,7 +88,7 @@ internal sealed class IcallTableGenerator
             if (assembly == "System.Private.CoreLib")
                 aname = "corlib";
             else
-                aname = assembly.Replace(".", "_");
+                aname = _fixupSymbolName(assembly);
             w.WriteLine($"#define ICALL_TABLE_{aname} 1\n");
 
             w.WriteLine($"static int {aname}_icall_indexes [] = {{");
index df48afa..1dff74a 100644 (file)
@@ -1,21 +1,13 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
 using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
-using System.IO;
 using System.Linq;
 using System.Text;
-using System.Text.Json;
-using System.Reflection;
 using Microsoft.Build.Framework;
 using Microsoft.Build.Utilities;
 
-#nullable enable
-
 public class ManagedToNativeGenerator : Task
 {
     [Required]
@@ -37,6 +29,11 @@ public class ManagedToNativeGenerator : Task
     [Output]
     public string[]? FileWrites { get; private set; }
 
+    private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
+
+    // Avoid sharing this cache with all the invocations of this task throughout the build
+    private readonly Dictionary<string, string> _symbolNameFixups = new();
+
     public override bool Execute()
     {
         if (Assemblies!.Length == 0)
@@ -65,8 +62,8 @@ public class ManagedToNativeGenerator : Task
 
     private void ExecuteInternal()
     {
-        var pinvoke = new PInvokeTableGenerator(Log);
-        var icall = new IcallTableGenerator(Log);
+        var pinvoke = new PInvokeTableGenerator(FixupSymbolName, Log);
+        var icall = new IcallTableGenerator(FixupSymbolName, Log);
 
         IEnumerable<string> cookies = Enumerable.Concat(
             pinvoke.Generate(PInvokeModules, Assemblies!, PInvokeOutputPath!),
@@ -80,4 +77,37 @@ public class ManagedToNativeGenerator : Task
             ? new string[] { PInvokeOutputPath, IcallOutputPath, InterpToNativeOutputPath }
             : new string[] { PInvokeOutputPath, InterpToNativeOutputPath };
     }
+
+    public string FixupSymbolName(string name)
+    {
+        if (_symbolNameFixups.TryGetValue(name, out string? fixedName))
+            return fixedName;
+
+        UTF8Encoding utf8 = new();
+        byte[] bytes = utf8.GetBytes(name);
+        StringBuilder sb = new();
+
+        foreach (byte b in bytes)
+        {
+            if ((b >= (byte)'0' && b <= (byte)'9') ||
+                (b >= (byte)'a' && b <= (byte)'z') ||
+                (b >= (byte)'A' && b <= (byte)'Z') ||
+                (b == (byte)'_'))
+            {
+                sb.Append((char)b);
+            }
+            else if (s_charsToReplace.Contains((char)b))
+            {
+                sb.Append('_');
+            }
+            else
+            {
+                sb.Append($"_{b:X}_");
+            }
+        }
+
+        fixedName = sb.ToString();
+        _symbolNameFixups[name] = fixedName;
+        return fixedName;
+    }
 }
index 74349b3..92784d8 100644 (file)
@@ -13,12 +13,16 @@ using Microsoft.Build.Utilities;
 
 internal sealed class PInvokeTableGenerator
 {
-    private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
     private readonly Dictionary<Assembly, bool> _assemblyDisableRuntimeMarshallingAttributeCache = new();
 
     private TaskLoggingHelper Log { get; set; }
+    private readonly Func<string, string> _fixupSymbolName;
 
-    public PInvokeTableGenerator(TaskLoggingHelper log) => Log = log;
+    public PInvokeTableGenerator(Func<string, string> fixupSymbolName, TaskLoggingHelper log)
+    {
+        Log = log;
+        _fixupSymbolName = fixupSymbolName;
+    }
 
     public IEnumerable<string> Generate(string[] pinvokeModules, string[] assemblies, string outputPath)
     {
@@ -234,14 +238,14 @@ internal sealed class PInvokeTableGenerator
 
         foreach (var module in modules.Keys)
         {
-            string symbol = ModuleNameToId(module) + "_imports";
+            string symbol = _fixupSymbolName(module) + "_imports";
             w.WriteLine("static PinvokeImport " + symbol + " [] = {");
 
             var assemblies_pinvokes = pinvokes.
                 Where(l => l.Module == module && !l.Skip).
                 OrderBy(l => l.EntryPoint).
                 GroupBy(d => d.EntryPoint).
-                Select(l => "{\"" + FixupSymbolName(l.Key) + "\", " + FixupSymbolName(l.Key) + "}, " +
+                Select(l => "{\"" + _fixupSymbolName(l.Key) + "\", " + _fixupSymbolName(l.Key) + "}, " +
                                 "// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n)));
 
             foreach (var pinvoke in assemblies_pinvokes)
@@ -255,7 +259,7 @@ internal sealed class PInvokeTableGenerator
         w.Write("static void *pinvoke_tables[] = { ");
         foreach (var module in modules.Keys)
         {
-            string symbol = ModuleNameToId(module) + "_imports";
+            string symbol = _fixupSymbolName(module) + "_imports";
             w.Write(symbol + ",");
         }
         w.WriteLine("};");
@@ -266,18 +270,6 @@ internal sealed class PInvokeTableGenerator
         }
         w.WriteLine("};");
 
-        static string ModuleNameToId(string name)
-        {
-            if (name.IndexOfAny(s_charsToReplace) < 0)
-                return name;
-
-            string fixedName = name;
-            foreach (char c in s_charsToReplace)
-                fixedName = fixedName.Replace(c, '_');
-
-            return fixedName;
-        }
-
         static bool ShouldTreatAsVariadic(PInvoke[] candidates)
         {
             if (candidates.Length < 2)
@@ -295,35 +287,7 @@ internal sealed class PInvokeTableGenerator
         }
     }
 
-    private static string FixupSymbolName(string name)
-    {
-        UTF8Encoding utf8 = new();
-        byte[] bytes = utf8.GetBytes(name);
-        StringBuilder sb = new();
-
-        foreach (byte b in bytes)
-        {
-            if ((b >= (byte)'0' && b <= (byte)'9') ||
-                (b >= (byte)'a' && b <= (byte)'z') ||
-                (b >= (byte)'A' && b <= (byte)'Z') ||
-                (b == (byte)'_'))
-            {
-                sb.Append((char)b);
-            }
-            else if (s_charsToReplace.Contains((char)b))
-            {
-                sb.Append('_');
-            }
-            else
-            {
-                sb.Append($"_{b:X}_");
-            }
-        }
-
-        return sb.ToString();
-    }
-
-    private static string SymbolNameForMethod(MethodInfo method)
+    private string SymbolNameForMethod(MethodInfo method)
     {
         StringBuilder sb = new();
         Type? type = method.DeclaringType;
@@ -331,7 +295,7 @@ internal sealed class PInvokeTableGenerator
         sb.Append($"{(type!.IsNested ? type!.FullName : type!.Name)}_");
         sb.Append(method.Name);
 
-        return FixupSymbolName(sb.ToString());
+        return _fixupSymbolName(sb.ToString());
     }
 
     private static string MapType(Type t) => t.Name switch
@@ -374,7 +338,7 @@ internal sealed class PInvokeTableGenerator
         {
             // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
             // https://github.com/dotnet/runtime/issues/43791
-            sb.Append($"int {FixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);");
+            sb.Append($"int {_fixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);");
             return sb.ToString();
         }
 
@@ -390,7 +354,7 @@ internal sealed class PInvokeTableGenerator
         }
 
         sb.Append(MapType(method.ReturnType));
-        sb.Append($" {FixupSymbolName(pinvoke.EntryPoint)} (");
+        sb.Append($" {_fixupSymbolName(pinvoke.EntryPoint)} (");
         int pindex = 0;
         var pars = method.GetParameters();
         foreach (var p in pars)
@@ -404,7 +368,7 @@ internal sealed class PInvokeTableGenerator
         return sb.ToString();
     }
 
-    private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
+    private void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
     {
         // Generate native->interp entry functions
         // These are called by native code, so they need to obtain
@@ -450,7 +414,7 @@ internal sealed class PInvokeTableGenerator
 
             bool is_void = method.ReturnType.Name == "Void";
 
-            string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
+            string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!);
             uint token = (uint)method.MetadataToken;
             string class_name = method.DeclaringType.Name;
             string method_name = method.Name;
@@ -517,7 +481,7 @@ internal sealed class PInvokeTableGenerator
         foreach (var cb in callbacks)
         {
             var method = cb.Method;
-            string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
+            string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!);
             string class_name = method.DeclaringType.Name;
             string method_name = method.Name;
             w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\",");