[wasm] Implement pinvoke callback support. (#38932)
authorZoltan Varga <vargaz@gmail.com>
Fri, 10 Jul 2020 11:21:19 +0000 (07:21 -0400)
committerGitHub <noreply@github.com>
Fri, 10 Jul 2020 11:21:19 +0000 (07:21 -0400)
* [wasm] Implement pinvoke callback support.

* Allow void return types.

* Fix name mangling for assembly names.

* Handle MonoPInvokeCallbackAttribute as well for now.

src/mono/wasm/runtime/driver.c
src/mono/wasm/runtime/pinvoke.c
src/mono/wasm/runtime/pinvoke.h
tools-local/tasks/mobile.tasks/WasmAppBuilder/PInvokeTableGenerator.cs

index cafc032..f9d190e 100644 (file)
@@ -14,6 +14,7 @@
 #include <mono/utils/mono-logger.h>
 #include <mono/utils/mono-dl-fallback.h>
 #include <mono/jit/jit.h>
+#include <mono/jit/mono-private-unstable.h>
 
 #include "pinvoke.h"
 
@@ -37,6 +38,7 @@ char *monoeg_g_getenv(const char *variable);
 int monoeg_g_setenv(const char *variable, const char *value, int overwrite);
 void mono_free (void*);
 int32_t mini_parse_debug_option (const char *option);
+char *mono_method_get_full_name (MonoMethod *method);
 
 static MonoClass* datetime_class;
 static MonoClass* datetimeoffset_class;
@@ -293,6 +295,37 @@ icall_table_lookup_symbol (void *func)
 
 #endif
 
+/*
+ * get_native_to_interp:
+ *
+ *   Return a pointer to a wasm function which can be used to enter the interpreter to
+ * execute METHOD from native code.
+ * EXTRA_ARG is the argument passed to the interp entry functions in the runtime.
+ */
+void*
+get_native_to_interp (MonoMethod *method, void *extra_arg)
+{
+       uint32_t token = mono_method_get_token (method);
+       MonoClass *klass = mono_method_get_class (method);
+       MonoImage *image = mono_class_get_image (klass);
+       MonoAssembly *assembly = mono_image_get_assembly (image);
+       MonoAssemblyName *aname = mono_assembly_get_name (assembly);
+       const char *name = mono_assembly_name_get_name (aname);
+       char key [128];
+       int len;
+
+       assert (strlen (name) < 100);
+       sprintf (key, "%s_%d", name, token);
+       len = strlen (key);
+       for (int i = 0; i < len; ++i) {
+               if (key [i] == '.')
+                       key [i] = '_';
+       }
+
+       void *addr = wasm_dl_get_native_to_interp (key, extra_arg);
+       return addr;
+}
+
 void mono_initialize_internals ()
 {
        mono_add_internal_call ("Interop/Runtime::InvokeJS", mono_wasm_invoke_js);
@@ -328,6 +361,7 @@ mono_wasm_load_runtime (const char *managed_path, int enable_debugging)
        mini_parse_debug_option ("top-runtime-invoke-unhandled");
 
        mono_dl_fallback_register (wasm_dl_load, wasm_dl_symbol, NULL, NULL);
+       mono_wasm_install_get_native_to_interp_tramp (get_native_to_interp);
 
 #ifdef ENABLE_AOT
        // Defined in driver-gen.c
index 6b57893..3804005 100644 (file)
@@ -1,6 +1,3 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
 #include "pinvoke.h"
 
 #include <stdint.h>
@@ -11,6 +8,7 @@
  */
 #define NULL ((void*)0)
 int strcmp (const char *s1, const char *s2);
+void mono_wasm_printerr (const char *s);
 
 #ifdef GEN_PINVOKE
 #include "pinvoke-table.h"
@@ -18,6 +16,12 @@ int strcmp (const char *s1, const char *s2);
 #include "pinvoke-tables-default.h"
 #endif
 
+void
+mono_wasm_pinvoke_vararg_stub (void)
+{
+       /* This is just a stub used to mark vararg pinvokes */
+}
+
 void*
 wasm_dl_lookup_pinvoke_table (const char *name)
 {
@@ -38,3 +42,20 @@ wasm_dl_is_pinvoke_table (void *handle)
        }
        return 0;
 }
+
+void*
+wasm_dl_get_native_to_interp (const char *key, void *extra_arg)
+{
+#ifdef GEN_PINVOKE
+       for (int i = 0; i < sizeof (wasm_native_to_interp_map) / sizeof (void*); ++i) {
+               if (!strcmp (wasm_native_to_interp_map [i], key)) {
+                       void *addr = wasm_native_to_interp_funcs [i];
+                       wasm_native_to_interp_ftndescs [i] = *(InterpFtnDesc*)extra_arg;
+                       return addr;
+               }
+       }
+       return NULL;
+#else
+       return NULL;
+#endif
+}
index 790bcf0..03c5ce4 100644 (file)
@@ -6,10 +6,21 @@ typedef struct {
        void *func;
 } PinvokeImport;
 
+typedef struct {
+       void *func;
+       void *arg;
+} InterpFtnDesc;
+
 void*
 wasm_dl_lookup_pinvoke_table (const char *name);
 
 int
 wasm_dl_is_pinvoke_table (void *handle);
 
+void*
+wasm_dl_get_native_to_interp (const char *key, void *extra_arg);
+
+void
+mono_wasm_pinvoke_vararg_stub (void);
+
 #endif
index 14b66c4..0447fc8 100644 (file)
@@ -31,11 +31,12 @@ public class PInvokeTableGenerator : Task
 
     private void GenPInvokeTable(string[] pinvokeModules, string[] assemblies)
     {
-        var modules = new Dictionary<string, string> ();
+        var modules = new Dictionary<string, string>();
         foreach (var module in pinvokeModules)
             modules [module] = module;
 
         var pinvokes = new List<PInvoke>();
+        var callbacks = new List<PInvokeCallback>();
 
         var resolver = new PathAssemblyResolver(assemblies);
         var mlc = new MetadataLoadContext(resolver, "System.Private.CoreLib");
@@ -43,7 +44,7 @@ public class PInvokeTableGenerator : Task
         {
             var a = mlc.LoadFromAssemblyPath(aname);
             foreach (var type in a.GetTypes())
-                CollectPInvokes(pinvokes, type);
+                CollectPInvokes(pinvokes, callbacks, type);
         }
 
         Log.LogMessage(MessageImportance.Normal, $"Generating pinvoke table to '{OutputPath}'.");
@@ -51,18 +52,34 @@ public class PInvokeTableGenerator : Task
         using (var w = File.CreateText(OutputPath!))
         {
             EmitPInvokeTable(w, modules, pinvokes);
+            EmitNativeToInterp(w, callbacks);
         }
     }
 
-    private void CollectPInvokes(List<PInvoke> pinvokes, Type type)
+    private void CollectPInvokes(List<PInvoke> pinvokes, List<PInvokeCallback> callbacks, Type type)
     {
-        foreach (var method in type.GetMethods (BindingFlags.DeclaredOnly|BindingFlags.Public|BindingFlags.NonPublic|BindingFlags.Static|BindingFlags.Instance)) {
-            if ((method.Attributes & MethodAttributes.PinvokeImpl) == 0)
-                continue;
-            var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
-            var module = (string)dllimport.ConstructorArguments[0].Value!;
-            var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!;
-            pinvokes.Add(new PInvoke(entrypoint, module, method));
+        foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly|BindingFlags.Public|BindingFlags.NonPublic|BindingFlags.Static|BindingFlags.Instance)) {
+            if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
+            {
+                var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
+                var module = (string)dllimport.ConstructorArguments[0].Value!;
+                var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!;
+                pinvokes.Add(new PInvoke(entrypoint, module, method));
+            }
+
+            foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
+            {
+                try
+                {
+                    if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
+                        cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
+                        callbacks.Add(new PInvokeCallback(method));
+                }
+                catch
+                {
+                    // Assembly not found, ignore
+                }
+            }
         }
     }
 
@@ -86,17 +103,17 @@ public class PInvokeTableGenerator : Task
                 if (pinvoke.Module == module)
                     w.WriteLine("{\"" + pinvoke.EntryPoint + "\", " + pinvoke.EntryPoint + "},");
             }
-            w.WriteLine ("{NULL, NULL}");
-            w.WriteLine ("};");
+            w.WriteLine("{NULL, NULL}");
+            w.WriteLine("};");
         }
-        w.Write ("static void *pinvoke_tables[] = { ");
+        w.Write("static void *pinvoke_tables[] = { ");
         foreach (var module in modules.Keys)
         {
             string symbol = module.Replace(".", "_") + "_imports";
             w.Write(symbol + ",");
         }
         w.WriteLine("};");
-        w.Write ("static char *pinvoke_names[] = { ");
+        w.Write("static char *pinvoke_names[] = { ");
         foreach (var module in modules.Keys)
         {
             w.Write("\"" + module + "\"" + ",");
@@ -125,19 +142,148 @@ public class PInvokeTableGenerator : Task
     {
         var sb = new StringBuilder();
         var method = pinvoke.Method;
-        sb.Append(MapType (method.ReturnType));
+        sb.Append(MapType(method.ReturnType));
         sb.Append($" {pinvoke.EntryPoint} (");
         int pindex = 0;
-        var pars = method.GetParameters ();
+        var pars = method.GetParameters();
         foreach (var p in pars) {
             if (pindex > 0)
                 sb.Append(",");
-            sb.Append(MapType (pars[pindex].ParameterType));
+            sb.Append(MapType(pars[pindex].ParameterType));
             pindex ++;
         }
         sb.Append(");");
         return sb.ToString();
     }
+
+       void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> callbacks)
+    {
+               // Generate native->interp entry functions
+               // These are called by native code, so they need to obtain
+               // the interp entry function/arg from a global array
+               // They also need to have a signature matching what the
+               // native code expects, which is the native signature
+               // of the delegate invoke in the [MonoPInvokeCallback]
+               // attribute.
+               // Only blittable parameter/return types are supposed.
+               int cb_index = 0;
+
+               // Arguments to interp entry functions in the runtime
+               w.WriteLine("InterpFtnDesc wasm_native_to_interp_ftndescs[" + callbacks.Count + "];");
+
+               foreach (var cb in callbacks) {
+                       var method = cb.Method;
+
+                       if (method.ReturnType != typeof(void) && !IsBlittable(method.ReturnType))
+                               Error("The return type of pinvoke callback method '" + method + "' needs to be blittable.");
+                       foreach (var p in method.GetParameters()) {
+                               if (!IsBlittable(p.ParameterType))
+                                       Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
+                       }
+               }
+
+               foreach (var cb in callbacks) {
+                       var sb = new StringBuilder();
+                       var method = cb.Method;
+
+                       // The signature of the interp entry function
+                       // This is a gsharedvt_in signature
+                       sb.Append("typedef void ");
+                       sb.Append(" (*WasmInterpEntrySig_" + cb_index + ") (");
+                       int pindex = 0;
+                       if (method.ReturnType.Name != "Void") {
+                               sb.Append("int");
+                               pindex ++;
+                       }
+                       foreach (var p in method.GetParameters()) {
+                               if (pindex > 0)
+                                       sb.Append(",");
+                               sb.Append("int");
+                               pindex ++;
+                       }
+                       if (pindex > 0)
+                               sb.Append(",");
+                       // Extra arg
+                       sb.Append("int");
+                       sb.Append(");\n");
+
+                       bool is_void = method.ReturnType.Name == "Void";
+
+                       string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
+                       uint token = (uint)method.MetadataToken;
+                       string entry_name = $"wasm_native_to_interp_{module_symbol}_{token}";
+            cb.EntryName = entry_name;
+                       sb.Append(MapType(method.ReturnType));
+                       sb.Append($" {entry_name} (");
+                       pindex = 0;
+                       foreach (var p in method.GetParameters()) {
+                               if (pindex > 0)
+                                       sb.Append(",");
+                               sb.Append(MapType(method.GetParameters()[pindex].ParameterType));
+                               sb.Append(" arg" + pindex);
+                               pindex ++;
+                       }
+                       sb.Append(") { \n");
+                       if (!is_void)
+                               sb.Append(MapType(method.ReturnType) + " res;\n");
+                       sb.Append("((WasmInterpEntrySig_" + cb_index + ")wasm_native_to_interp_ftndescs [" + cb_index + "].func) (");
+                       pindex = 0;
+                       if (!is_void) {
+                               sb.Append("&res");
+                               pindex ++;
+                       }
+                       int aindex = 0;
+                       foreach (var p in method.GetParameters()) {
+                               if (pindex > 0)
+                                       sb.Append(", ");
+                               sb.Append("&arg" + aindex);
+                               pindex ++;
+                               aindex ++;
+                       }
+                       if (pindex > 0)
+                               sb.Append(", ");
+                       sb.Append($"wasm_native_to_interp_ftndescs [{cb_index}].arg");
+                       sb.Append(");\n");
+                       if (!is_void)
+                               sb.Append("return res;\n");
+                       sb.Append("}");
+                       w.WriteLine(sb);
+                       cb_index ++;
+               }
+
+               // Array of function pointers
+               w.Write ("static void *wasm_native_to_interp_funcs[] = { ");
+               foreach (var cb in callbacks) {
+                       w.Write (cb.EntryName + ",");
+               }
+               w.WriteLine ("};");
+
+               // Lookup table from method->interp entry
+               // The key is a string of the form <assembly name>_<method token>
+               // FIXME: Use a better encoding
+               w.Write ("static const char *wasm_native_to_interp_map[] = { ");
+               foreach (var cb in callbacks) {
+                       var method = cb.Method;
+                       string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
+                       uint token = (uint)method.MetadataToken;
+                       w.WriteLine ($"\"{module_symbol}_{token}\",");
+               }
+               w.WriteLine ("};");
+       }
+
+    static bool IsBlittable (Type type)
+    {
+        if (type.IsPrimitive || type.IsByRef || type.IsPointer)
+            return true;
+        else
+            return false;
+    }
+
+    static void Error (string msg)
+    {
+        // FIXME:
+        throw new Exception(msg);
+    }
 }
 
 class PInvoke
@@ -153,3 +299,14 @@ class PInvoke
     public string Module;
     public MethodInfo Method;
 }
+
+class PInvokeCallback
+{
+    public PInvokeCallback(MethodInfo method)
+    {
+        Method = method;
+    }
+
+    public MethodInfo Method;
+    public string? EntryName;
+}