[mono] Introduce designated direct pinvokes to mono aot compiler (#79721)
authorMitchell Hwang <16830051+mdh1418@users.noreply.github.com>
Thu, 2 Feb 2023 08:18:56 +0000 (03:18 -0500)
committerGitHub <noreply@github.com>
Thu, 2 Feb 2023 08:18:56 +0000 (09:18 +0100)
Introduce designated direct pinvokes to mono aot compiler.

src/mono/mono/mini/aot-compiler.c
src/tasks/AotCompilerTask/MonoAOTCompiler.cs

index 2af1c13..9e59274 100644 (file)
@@ -208,6 +208,8 @@ typedef struct MonoAotOptions {
        gboolean gen_msym_dir;
        char *gen_msym_dir_path;
        gboolean direct_pinvoke;
+       GList *direct_pinvokes;
+       GList *direct_pinvoke_lists;
        gboolean direct_icalls;
        gboolean direct_extern_calls;
        gboolean no_direct_calls;
@@ -307,6 +309,7 @@ typedef struct MonoAotCompile {
        GHashTable *method_to_cfg;
        GHashTable *token_info_hash;
        GHashTable *method_to_pinvoke_import;
+       GHashTable *direct_pinvokes;
        GHashTable *method_to_external_icall_symbol_name;
        GPtrArray *extra_methods;
        GPtrArray *image_table;
@@ -6144,6 +6147,122 @@ method_is_externally_callable (MonoAotCompile *acfg, MonoMethod *method)
        }
 }
 
+#ifdef MONO_ARCH_AOT_SUPPORTED
+//---------------------------------------------------------------------------------------
+//
+// get_pinvoke_import:
+//
+// Returns whether or not module and entrypoint pinvoke information could be grabbed
+// from the MonoMethod. It populates module and entrypoint if they are not NULL. A
+// hash table is populated with key value pairs corresponding to the MonoMethod and
+// module entrypoint string array to serve as a fast path cache. The module and
+// entrypoint string data are owned by the hash table.
+//
+// Arguments:
+//  * acfg - the MonoAotCompiler instance
+//  * method - the MonoMethod to grab pinvoke scope and import information from
+//  ** module - the pointer to the module name string (owned by the hashtable)
+//  ** entrypoint - the pointer to the entrypoint name string (owned by the hashtable)
+//
+// Return Value:
+//  gboolean corresponding to whether or not module and entrypoint pinvoke information
+//  could be grabbed from the provided MonoMethod.
+//
+
+static gboolean
+get_pinvoke_import (MonoAotCompile *acfg, MonoMethod *method, const char **module, const char **entrypoint)
+{
+       MonoImage *image = m_class_get_image (method->klass);
+       MonoMethodPInvoke *piinfo = (MonoMethodPInvoke *) method;
+       MonoTableInfo *tables = image->tables;
+       MonoTableInfo *im = &tables [MONO_TABLE_IMPLMAP];
+       MonoTableInfo *mr = &tables [MONO_TABLE_MODULEREF];
+       guint32 im_cols [MONO_IMPLMAP_SIZE];
+       int module_idx;
+       char **scope_import;
+       guint32 scope_token;
+       char *module_ref_basename;
+       char *module_ref_basename_extension;
+
+       if (g_hash_table_lookup_extended (acfg->method_to_pinvoke_import, method, NULL, (gpointer *)&scope_import) && scope_import) {
+               if (module)
+                       *module = scope_import[0];
+               if (entrypoint)
+                       *entrypoint = scope_import[1];
+               return TRUE;
+       }
+
+       if (piinfo->implmap_idx == 0 || mono_metadata_table_bounds_check (image, MONO_TABLE_IMPLMAP, piinfo->implmap_idx))
+               return FALSE;
+
+       mono_metadata_decode_row (im, piinfo->implmap_idx - 1, im_cols, MONO_IMPLMAP_SIZE);
+
+       module_idx = im_cols [MONO_IMPLMAP_SCOPE];
+       if (module_idx == 0 || mono_metadata_table_bounds_check (image, MONO_TABLE_MODULEREF, module_idx))
+               return FALSE;
+
+       scope_import = (char **) g_malloc0 (2 * sizeof (char *));
+       scope_token = mono_metadata_decode_row_col (mr, im_cols [MONO_IMPLMAP_SCOPE] - 1, MONO_MODULEREF_NAME);
+       module_ref_basename = g_path_get_basename (mono_metadata_string_heap (image, scope_token));
+       module_ref_basename_extension = strrchr (module_ref_basename, '.');
+       if (module_ref_basename_extension) {
+               const char **suffixes = mono_dl_get_so_suffixes ();
+               for (int i = 0; suffixes [i] && suffixes [i][0] != '\0'; i++) {
+                       if (!strcmp (module_ref_basename_extension, suffixes [i])) {
+                               *module_ref_basename_extension= '\0';
+                               break;
+                       }
+               }
+       }
+
+       scope_import [0] = module_ref_basename;
+       scope_import [1] = g_strdup_printf ("%s", mono_metadata_string_heap (image, im_cols [MONO_IMPLMAP_NAME]));
+
+       g_hash_table_insert (acfg->method_to_pinvoke_import, method, scope_import);
+
+       if (module)
+               *module = scope_import [0];
+       if (entrypoint)
+               *entrypoint = scope_import [1];
+
+       return TRUE;
+}
+#else
+static gboolean
+get_pinvoke_import (MonoAotCompile *acfg, MonoMethod *method, const char **module, const char **entrypoint)
+{
+       return FALSE;
+}
+#endif
+
+/*
+ * is_direct_pinvoke_specified_for_method:
+ *
+ * Returns whether the method is specified to be directly pinvoked based on
+ * the direct_pinvoke HashTable populated in process_specified_direct_pinvokes.
+ */
+static gboolean
+is_direct_pinvoke_specified_for_method (MonoAotCompile *acfg, MonoMethod *method)
+{
+       const char *module_name, *sym = NULL;
+       GHashTable *val;
+
+       if (acfg->aot_opts.direct_pinvoke)
+               return TRUE;
+
+       if (!acfg->aot_opts.direct_pinvokes && !acfg->aot_opts.direct_pinvoke_lists)
+               return FALSE;
+
+       if (get_pinvoke_import (acfg, method, &module_name, &sym) && g_hash_table_lookup_extended (acfg->direct_pinvokes, module_name, NULL, (gpointer *)&val)) {
+               if (!val)
+                       return TRUE;
+
+               return g_hash_table_contains (val, sym);
+       }
+
+       return FALSE;
+}
+
 /*
  * is_direct_callable:
  *
@@ -6200,8 +6319,7 @@ is_direct_callable (MonoAotCompile *acfg, MonoMethod *method, MonoJumpInfo *patc
                /* Cross assembly calls */
                return method_is_externally_callable (acfg, patch_info->data.method);
        } else if ((patch_info->type == MONO_PATCH_INFO_ICALL_ADDR_CALL && patch_info->data.method->flags & METHOD_ATTRIBUTE_PINVOKE_IMPL)) {
-               if (acfg->aot_opts.direct_pinvoke)
-                       return TRUE;
+               return is_direct_pinvoke_specified_for_method (acfg, patch_info->data.method);
        } else if (patch_info->type == MONO_PATCH_INFO_ICALL_ADDR_CALL) {
                if (acfg->aot_opts.direct_icalls)
                        return TRUE;
@@ -6211,44 +6329,6 @@ is_direct_callable (MonoAotCompile *acfg, MonoMethod *method, MonoJumpInfo *patc
        return FALSE;
 }
 
-#ifdef MONO_ARCH_AOT_SUPPORTED
-static const char *
-get_pinvoke_import (MonoAotCompile *acfg, MonoMethod *method)
-{
-       MonoImage *image = m_class_get_image (method->klass);
-       MonoMethodPInvoke *piinfo = (MonoMethodPInvoke *) method;
-       MonoTableInfo *tables = image->tables;
-       MonoTableInfo *im = &tables [MONO_TABLE_IMPLMAP];
-       guint32 im_cols [MONO_IMPLMAP_SIZE];
-       char *import;
-
-       import = (char *)g_hash_table_lookup (acfg->method_to_pinvoke_import, method);
-       if (import != NULL)
-               return import;
-
-       if (piinfo->implmap_idx == 0 || mono_metadata_table_bounds_check (image, MONO_TABLE_IMPLMAP, piinfo->implmap_idx))
-               return NULL;
-
-       mono_metadata_decode_row (im, piinfo->implmap_idx - 1, im_cols, MONO_IMPLMAP_SIZE);
-
-       int module_idx = im_cols [MONO_IMPLMAP_SCOPE];
-       if (module_idx == 0 || mono_metadata_table_bounds_check (image, MONO_TABLE_MODULEREF, module_idx))
-               return NULL;
-
-       import = g_strdup_printf ("%s", mono_metadata_string_heap (image, im_cols [MONO_IMPLMAP_NAME]));
-
-       g_hash_table_insert (acfg->method_to_pinvoke_import, method, import);
-
-       return import;
-}
-#else
-static const char *
-get_pinvoke_import (MonoAotCompile *acfg, MonoMethod *method)
-{
-       return NULL;
-}
-#endif
-
 static gint
 compare_lne (MonoDebugLineNumberEntry *a, MonoDebugLineNumberEntry *b)
 {
@@ -6398,7 +6478,7 @@ emit_and_reloc_code (MonoAotCompile *acfg, MonoMethod *method, guint8 *code, gui
        gboolean direct_call, external_call;
        guint32 got_slot;
        const char *direct_call_target = 0;
-       const char *direct_pinvoke;
+       const char *direct_pinvoke = NULL;
 #endif
 
        if (acfg->gas_line_numbers && method && debug_info) {
@@ -6528,7 +6608,7 @@ emit_and_reloc_code (MonoAotCompile *acfg, MonoMethod *method, guint8 *code, gui
                                                if (!(patch_info->data.method->flags & METHOD_ATTRIBUTE_PINVOKE_IMPL))
                                                        direct_pinvoke = lookup_icall_symbol_name_aot (patch_info->data.method);
                                                else
-                                                       direct_pinvoke = get_pinvoke_import (acfg, patch_info->data.method);
+                                                       get_pinvoke_import (acfg, patch_info->data.method, NULL, &direct_pinvoke);
                                                if (direct_pinvoke && !never_direct_pinvoke (direct_pinvoke)) {
                                                        direct_call = TRUE;
                                                        g_assert (strlen (direct_pinvoke) < 1000);
@@ -8534,6 +8614,24 @@ mono_aot_parse_options (const char *aot_options, MonoAotOptions *opts)
                        mini_debug_options.no_seq_points_compact_data = FALSE;
                        opts->gen_msym_dir = TRUE;
                        opts->gen_msym_dir_path = g_strdup (arg + strlen ("msym_dir="));
+               } else if (str_begins_with (arg, "direct-pinvokes=")) {
+                       char *direct_pinvokes = g_strdup (arg + strlen ("direct-pinvokes="));
+                       gchar *direct_pinvoke_ctx = NULL;
+                       gchar *direct_pinvoke = strtok_r (direct_pinvokes, ";", &direct_pinvoke_ctx);
+                       while (direct_pinvoke) {
+                               opts->direct_pinvokes = g_list_append (opts->direct_pinvokes, g_strdup (direct_pinvoke));
+                               direct_pinvoke = strtok_r (NULL, ";", &direct_pinvoke_ctx);
+                       }
+                       g_free (direct_pinvokes);
+               } else if (str_begins_with (arg, "direct-pinvoke-lists=")) {
+                       char *direct_pinvoke_lists = g_strdup (arg + strlen ("direct-pinvoke-lists="));
+                       gchar *direct_pinvoke_list_ctx = NULL;
+                       gchar *direct_pinvoke_list = strtok_r (direct_pinvoke_lists, ";", &direct_pinvoke_list_ctx);
+                       while (direct_pinvoke_list) {
+                               opts->direct_pinvoke_lists = g_list_append (opts->direct_pinvoke_lists, g_strdup (direct_pinvoke_list));
+                               direct_pinvoke_list = strtok_r (NULL, ";", &direct_pinvoke_list_ctx);
+                       }
+                       g_free (direct_pinvoke_lists);
                } else if (str_begins_with (arg, "direct-pinvoke")) {
                        opts->direct_pinvoke = TRUE;
                } else if (str_begins_with (arg, "direct-icalls")) {
@@ -8645,56 +8743,59 @@ mono_aot_parse_options (const char *aot_options, MonoAotOptions *opts)
                        opts->depfile = g_strdup (arg + strlen ("depfile="));
                } else if (str_begins_with (arg, "help") || str_begins_with (arg, "?")) {
                        printf ("Supported options for --aot:\n");
-                       printf ("    asmonly\n");
-                       printf ("    bind-to-runtime-version\n");
-                       printf ("    bitcode\n");
-                       printf ("    data-outfile=\n");
-                       printf ("    direct-icalls\n");
-                       printf ("    direct-pinvoke\n");
-                       printf ("    dwarfdebug\n");
-                       printf ("    full\n");
-                       printf ("    hybrid\n");
-                       printf ("    info\n");
-                       printf ("    keep-temps\n");
-                       printf ("    llvm\n");
-                       printf ("    llvmonly\n");
-                       printf ("    llvm-outfile=\n");
-                       printf ("    llvm-path=\n");
-                       printf ("    msym-dir=\n");
-                       printf ("    mtriple\n");
-                       printf ("    nimt-trampolines=\n");
-                       printf ("    nodebug\n");
-                       printf ("    no-direct-calls\n");
-                       printf ("    no-write-symbols\n");
-                       printf ("    nrgctx-trampolines=\n");
-                       printf ("    nrgctx-fetch-trampolines=\n");
-                       printf ("    ngsharedvt-trampolines=\n");
-                       printf ("    nftnptr-arg-trampolines=\n");
-                       printf ("    nunbox-arbitrary-trampolines=\n");
-                       printf ("    ntrampolines=\n");
-                       printf ("    outfile=\n");
-                       printf ("    profile=\n");
-                       printf ("    profile-only\n");
-                       printf ("    print-skipped-methods\n");
-                       printf ("    readonly-value=\n");
-                       printf ("    save-temps\n");
-                       printf ("    soft-debug\n");
-                       printf ("    static\n");
-                       printf ("    stats\n");
-                       printf ("    temp-path=\n");
-                       printf ("    tool-prefix=\n");
-                       printf ("    threads=\n");
-                       printf ("    write-symbols\n");
-                       printf ("    verbose\n");
-                       printf ("    allow-errors\n");
-                       printf ("    no-opt\n");
-                       printf ("    llvmopts=\n");
-                       printf ("    llvmllc=\n");
-                       printf ("    clangxx=\n");
-                       printf ("    depfile=\n");
-                       printf ("    mcpu=\n");
-                       printf ("    mattr=\n");
-                       printf ("    help/?\n");
+                       printf ("    asmonly                              - \n");
+                       printf ("    bind-to-runtime-version              - \n");
+                       printf ("    bitcode                              - \n");
+                       printf ("    data-outfile=<string>                - \n");
+                       printf ("    direct-icalls                        - \n");
+                       printf ("    direct-pinvokes=<string>             - Specific direct pinvokes to generate direct calls for an entire 'module' or specific 'module!entrypoint' separated by semi-colons. Incompatible with 'direct-pinvoke' option.\n");
+                       printf ("    direct-pinvoke-lists=<string>        - Files containing specific direct pinvokes to generate direct calls for an entire 'module' or specific 'module!entrypoint' on separate lines. Incompatible with 'direct-pinvoke' option.\n");
+                       printf ("    direct-pinvoke                       - Generate direct calls for all direct pinvokes encountered in the managed assembly.\n");
+                       printf ("    dwarfdebug                           - \n");
+                       printf ("    full                                 - \n");
+                       printf ("    hybrid                               - \n");
+                       printf ("    info                                 - \n");
+                       printf ("    keep-temps                           - \n");
+                       printf ("    llvm                                 - \n");
+                       printf ("    llvmonly                             - \n");
+                       printf ("    llvm-outfile=<string>                - \n");
+                       printf ("    llvm-path=<string>                   - \n");
+                       printf ("    msym-dir=<string>                    - \n");
+                       printf ("    mtriple                              - \n");
+                       printf ("    nimt-trampolines=<value>             - \n");
+                       printf ("    nodebug                              - \n");
+                       printf ("    no-direct-calls                      - \n");
+                       printf ("    no-write-symbols                     - \n");
+                       printf ("    nrgctx-trampolines=<value>           - \n");
+                       printf ("    nrgctx-fetch-trampolines=<value>     - \n");
+                       printf ("    ngsharedvt-trampolines=<value>       - \n");
+                       printf ("    nftnptr-arg-trampolines=<value>      - \n");
+                       printf ("    nunbox-arbitrary-trampolines=<value> - \n");
+                       printf ("    ntrampolines=<value>                 - \n");
+                       printf ("    outfile=<string>                     - \n");
+                       printf ("    profile=<string>                     - \n");
+                       printf ("    profile-only                         - \n");
+                       printf ("    mibc-profile=<string>                - \n");
+                       printf ("    print-skipped-methods                - \n");
+                       printf ("    readonly-value=<value>               - \n");
+                       printf ("    save-temps                           - \n");
+                       printf ("    soft-debug                           - \n");
+                       printf ("    static                               - \n");
+                       printf ("    stats                                - \n");
+                       printf ("    temp-path=<string>                   - \n");
+                       printf ("    tool-prefix=<value>                  - \n");
+                       printf ("    threads=<value>                      - \n");
+                       printf ("    write-symbols                        - \n");
+                       printf ("    verbose                              - \n");
+                       printf ("    allow-errors                         - \n");
+                       printf ("    no-opt                               - \n");
+                       printf ("    llvmopts=<value>                     - \n");
+                       printf ("    llvmllc=<value>                      - \n");
+                       printf ("    clangxx=<value>                      - \n");
+                       printf ("    depfile=<value>                      - \n");
+                       printf ("    mcpu=<value>                         - \n");
+                       printf ("    mattr=<value>                        - \n");
+                       printf ("    help/?                                 \n");
                        exit (0);
                } else {
                        fprintf (stderr, "AOT : Unknown argument '%s'.\n", arg);
@@ -9059,6 +9160,12 @@ add_referenced_patch (MonoAotCompile *acfg, MonoJumpInfo *patch_info, int depth)
        }
 }
 
+static inline gboolean
+is_direct_pinvoke_enabled (const MonoAotCompile *acfg)
+{
+       return acfg->aot_opts.direct_pinvoke || acfg->aot_opts.direct_pinvokes || acfg->aot_opts.direct_pinvoke_lists;
+}
+
 /*
  * compile_method:
  *
@@ -9161,7 +9268,7 @@ compile_method (MonoAotCompile *acfg, MonoMethod *method)
                flags = (JitFlags)(flags | JIT_FLAG_LLVM_ONLY | JIT_FLAG_EXPLICIT_NULL_CHECKS);
        if (acfg->aot_opts.no_direct_calls)
                flags = (JitFlags)(flags | JIT_FLAG_NO_DIRECT_ICALLS);
-       if (acfg->aot_opts.direct_pinvoke)
+       if (is_direct_pinvoke_enabled (acfg))
                flags = (JitFlags)(flags | JIT_FLAG_DIRECT_PINVOKE);
        if (acfg->aot_opts.interp)
                flags = (JitFlags)(flags | JIT_FLAG_INTERP);
@@ -10180,8 +10287,8 @@ mono_aot_get_direct_call_symbol (MonoJumpInfoType type, gconstpointer data)
                        MonoMethod *method = (MonoMethod *)data;
                        if (!(method->flags & METHOD_ATTRIBUTE_PINVOKE_IMPL))
                                sym = lookup_icall_symbol_name_aot (method);
-                       else if (llvm_acfg->aot_opts.direct_pinvoke)
-                               sym = get_pinvoke_import (llvm_acfg, method);
+                       else if (is_direct_pinvoke_specified_for_method (llvm_acfg, method))
+                               get_pinvoke_import (llvm_acfg, method, NULL, &sym);
                } else if (type == MONO_PATCH_INFO_JIT_ICALL_ID) {
                        MonoJitICallInfo const * const info = mono_find_jit_icall_info ((MonoJitICallId)(gsize)data);
                        char const * const name = info->c_symbol;
@@ -13706,6 +13813,17 @@ add_mibc_profile_methods (MonoAotCompile *acfg, char *filename)
 }
 
 static void
+free_method_pinvoke_import_value (gpointer data)
+{
+       gchar **value = (gchar **)data;
+       if (!value)
+               return;
+       g_free (value[0]);
+       g_free (value[1]);
+       g_free (value);
+}
+
+static void
 init_got_info (GotInfo *info)
 {
        int i;
@@ -13731,7 +13849,8 @@ acfg_create (MonoAssembly *ass, guint32 jit_opts)
        acfg->patch_to_plt_entry = g_new0 (GHashTable*, MONO_PATCH_INFO_NUM);
        acfg->method_to_cfg = g_hash_table_new (NULL, NULL);
        acfg->token_info_hash = g_hash_table_new_full (NULL, NULL, NULL, NULL);
-       acfg->method_to_pinvoke_import = g_hash_table_new_full (NULL, NULL, NULL, g_free);
+       acfg->method_to_pinvoke_import = g_hash_table_new_full (NULL, NULL, NULL, (GDestroyNotify)free_method_pinvoke_import_value);
+       acfg->direct_pinvokes = g_hash_table_new_full (g_str_hash, g_str_equal, g_free, (GDestroyNotify)g_hash_table_destroy);
        acfg->method_to_external_icall_symbol_name = g_hash_table_new_full (NULL, NULL, NULL, g_free);
        acfg->image_hash = g_hash_table_new (NULL, NULL);
        acfg->image_table = g_ptr_array_new ();
@@ -13776,6 +13895,41 @@ got_info_free (GotInfo *info)
 }
 
 static void
+aot_opts_free (MonoAotOptions *aot_opts)
+{
+       g_free (aot_opts->outfile);
+       g_free (aot_opts->llvm_outfile);
+       g_free (aot_opts->data_outfile);
+       for (GList *elem = aot_opts->profile_files; elem; elem = elem->next)
+               g_free (elem->data);
+       g_list_free (aot_opts->profile_files);
+       for (GList *elem = aot_opts->mibc_profile_files; elem; elem = elem->next)
+               g_free (elem->data);
+       g_list_free (aot_opts->mibc_profile_files);
+       g_free (aot_opts->gen_msym_dir_path);
+       for (GList *elem = aot_opts->direct_pinvokes; elem; elem = elem->next)
+               g_free (elem->data);
+       g_list_free (aot_opts->direct_pinvokes);
+       for (GList *elem = aot_opts->direct_pinvoke_lists; elem; elem = elem->next)
+               g_free (elem->data);
+       g_list_free (aot_opts->direct_pinvoke_lists);
+       g_free (aot_opts->dedup_include);
+       g_free (aot_opts->tool_prefix);
+       g_free (aot_opts->ld_flags);
+       g_free (aot_opts->ld_name);
+       g_free (aot_opts->mtriple);
+       g_free (aot_opts->llvm_path);
+       g_free (aot_opts->temp_path);
+       g_free (aot_opts->instances_logfile_path);
+       g_free (aot_opts->logfile);
+       g_free (aot_opts->llvm_opts);
+       g_free (aot_opts->llvm_llc);
+       g_free (aot_opts->llvm_cpu_attr);
+       g_free (aot_opts->clangxx);
+       g_free (aot_opts->depfile);
+}
+
+static void
 acfg_free (MonoAotCompile *acfg)
 {
 #ifdef ENABLE_LLVM
@@ -13807,6 +13961,7 @@ acfg_free (MonoAotCompile *acfg)
        g_hash_table_destroy (acfg->method_to_cfg);
        g_hash_table_destroy (acfg->token_info_hash);
        g_hash_table_destroy (acfg->method_to_pinvoke_import);
+       g_hash_table_destroy (acfg->direct_pinvokes);
        g_hash_table_destroy (acfg->method_to_external_icall_symbol_name);
        g_hash_table_destroy (acfg->image_hash);
        g_hash_table_destroy (acfg->unwind_info_offsets);
@@ -14156,6 +14311,200 @@ init_options (MonoAotOptions *aot_opts)
        aot_opts->clangxx = g_strdup ("clang++");
 }
 
+//---------------------------------------------------------------------------------------
+//
+// is_direct_pinvoke_parsable checks whether the direct pinvoke should be parsed into
+// the corresponding module and entrypoint names. Empty lines and comments are skipped.
+//
+// Arguments:
+//  * direct_pinvoke - the string corresponding to the direct pinvoke to parse
+//
+// Return Value:
+//  gboolean pertaining to whether or not the direct pinvoke should be parsed.
+//  into the respective module_name and entrypoint_name.
+//
+
+static gboolean
+is_direct_pinvoke_parsable (const char *direct_pinvoke)
+{
+       if (!direct_pinvoke)
+               return FALSE;
+
+       if (direct_pinvoke[0] == '\0')
+               return FALSE;
+
+       if (direct_pinvoke[0] == '#')
+               return FALSE;
+
+       return TRUE;
+}
+
+//---------------------------------------------------------------------------------------
+//
+// parsed_direct_pinvoke parses the direct pinvoke for the module_name and entrypoint_name
+// It presumes that is_direct_pinvoke_parsable (direct_pinvoke) returned true.
+//
+// Parsing:
+// The direct pinvoke is stripped of any leading or trailing whitespace.
+// The remaining is assumed to be of the form MODULE or MODULE!ENTRYPOINT.
+// The passed in pointers are then reassigned to the corresponding strings.
+//
+// Arguments:
+//  * acfg - the MonoAotCompiler instance
+//  * direct pinvoke - the string corresponding to the direct pinvoke to parse
+//  ** module_name_ptr - the pointer to the module name string
+//  ** entrypoint_name_ptr - the pointer to the entrypoint name string
+//
+// Return Value:
+//  enum pertaining to whether or not the direct pinvoke entry should be skipped, was
+//  successfully split into the respective module_name and entrypoint_name, or if
+//  an error has occurred.
+//
+
+static gboolean
+parsed_direct_pinvoke (MonoAotCompile *acfg, const char *direct_pinvoke, char **module_name_ptr, char **entrypoint_name_ptr)
+{
+       gboolean parsed = FALSE;
+
+       *module_name_ptr = NULL;
+       *entrypoint_name_ptr = NULL;
+
+       char **direct_pinvoke_split = g_strsplit (direct_pinvoke, "!", 2);
+       if (direct_pinvoke_split) {
+               *module_name_ptr = g_strdup (direct_pinvoke_split [0]);
+               *entrypoint_name_ptr = g_strdup (direct_pinvoke_split [1]);
+
+               // ENTRYPOINT can be NULL if Direct PInvoke is just MODULE
+               if (*module_name_ptr && (!direct_pinvoke_split [1] || *entrypoint_name_ptr))
+                       parsed = TRUE;
+
+               g_strfreev (direct_pinvoke_split);
+       }
+
+       if (!parsed) {
+               aot_printerrf (acfg, "Failed to parse the specified direct pinvoke '%s'. 'g_strsplit' or 'g_strdup' failed, possible due to insufficient memory.\n", direct_pinvoke);
+               g_free (*module_name_ptr);
+               *module_name_ptr = NULL;
+
+               g_free (*entrypoint_name_ptr);
+               *entrypoint_name_ptr = NULL;
+       }
+
+       return parsed;
+}
+
+//---------------------------------------------------------------------------------------
+//
+// add_direct_pinvoke adds the module and entrypoint of a specified direct pinvoke to the
+// MonoAotCompile instance's HashTable of module/entrypoint entries.
+// It is presumed that module_name_ptr and entrypoint_name_ptr point to valid strings.
+// It transfers ownership of the module_name and entrypoint_name strings data to the HashTable
+// This function takes ownership of the module and entrypoint strings pointers regardless of
+// success or failure (in which it frees the memory) and sets them to NULL.
+//
+// Arguments:
+//  * acfg - the MonoAotCompiler instance
+//  ** module_name_ptr - the pointer to the module name (assumed not NULL)
+//  ** entrypoint_name_ptr - the pointer to the entrypoint name
+//
+// Return Value:
+//  gboolean corresponding to whether or not the direct_pinvoke was successfully added
+//  or if an error occurred.
+//
+
+static gboolean
+add_direct_pinvoke (MonoAotCompile *acfg, char **module_name_ptr, char **entrypoint_name_ptr)
+{
+       gboolean success = TRUE;
+       GHashTable *entrypoints;
+       // If there is an entry for the module
+       if (g_hash_table_lookup_extended (acfg->direct_pinvokes, *module_name_ptr, NULL, (gpointer *)&entrypoints)) {
+               // Not all entrypoints are direct, if specifying a new entrypoint
+               if (entrypoints && *entrypoint_name_ptr && !g_hash_table_contains (entrypoints, *entrypoint_name_ptr)) {
+                       g_hash_table_insert (entrypoints, *entrypoint_name_ptr, NULL);
+                       *entrypoint_name_ptr = NULL;
+               }
+       // New entry for module, All entrypoints are direct
+       } else if (!*entrypoint_name_ptr) {
+               g_hash_table_insert (acfg->direct_pinvokes, *module_name_ptr, NULL);
+               *module_name_ptr = NULL;
+       // New entry for module, specifying an entrypoint
+       } else {
+               entrypoints = g_hash_table_new_full (g_str_hash, g_str_equal, g_free, NULL);
+               if (!entrypoints) {
+                       aot_printerrf (acfg, "Failed to allocate new entrypoints HashTable.\n");
+                       success = FALSE;
+               } else {
+                       g_hash_table_insert (entrypoints, *entrypoint_name_ptr, NULL);
+                       g_hash_table_insert (acfg->direct_pinvokes, *module_name_ptr, entrypoints);
+                       *entrypoint_name_ptr = NULL;
+                       *module_name_ptr = NULL;
+               }
+       }
+
+       g_free (*module_name_ptr);
+       g_free (*entrypoint_name_ptr);
+
+       *module_name_ptr = NULL;
+       *entrypoint_name_ptr = NULL;
+       return success;
+}
+
+//---------------------------------------------------------------------------------------
+//
+// process_specified_direct_pinvokes processes the direct pinvokes and direct pinvoke lists
+// the user specifies in the direct-pinvokes and direct-pinvoke-lists options and adds the
+// entire module or module and set of entrypoints to the MonoAotCompile instance's
+// direct_pinvoke HashTable.
+//
+// Format of direct_pinvoke HashTable:
+// The direct_pinvoke HashTable keys are module names, and its values are HashTables
+// corresponding to entrypoint names within the module to be direct pinvoked.
+// A NULL value in the direct_pinvoke HashTable is understood to mean that all entrypoints
+// from the library are direct. It will overrule previously added HashTable of entrypoints,
+// and it will prevent new HashTable of entrypoints from being added.
+//
+// Processing:
+// The specified direct pinvoke, dpi, is ignored if it is empty or considered a comment.
+// It is then understood to be in the format of MODULE or MODULE!ENTRYPOINT.
+// A direct pinvoke in the form of MODULE is understood as enabling direct pinvoke for all
+// entrypoints within the particular module, and will override any previously added set
+// of entrypoint names.
+// A direct pinvoke in the form of MODULE!ENTRYPOINT is understood as enabling direct pinvoke
+// for the specific entrypoint in the module. It will not be added if the entire module
+// should be direct pinvoked, but otherwise will be added to a set of entrypoint names for
+// the particular module.
+//
+// Arguments:
+//  * acfg - the MonoAotCompiler instance
+//  * dpi (direct pinvoke) - the string passed in specifying a direct pinvoke
+//
+// Return Value:
+//  gboolean corresponding to whether the specified direct pinvoke was successfully
+//  processed (regardless of it being added or skipped) or if an error occurred.
+//
+//  Note - There are no extensive format checks, and is intended to behave akin to
+//  ConfigurablePInvokePolicy AddDirectPInvoke in NativeAOT
+//
+
+static gboolean
+process_specified_direct_pinvokes (MonoAotCompile *acfg, const char *dpi)
+{
+       gboolean result = FALSE;
+       char *direct_pinvoke = g_strdup (dpi);
+       if (direct_pinvoke && g_strstrip (direct_pinvoke)) {
+               if (is_direct_pinvoke_parsable (direct_pinvoke)) {
+                       char *module, *entrypoint;
+                       if (parsed_direct_pinvoke (acfg, direct_pinvoke, &module, &entrypoint))
+                               result = add_direct_pinvoke (acfg, &module, &entrypoint);
+               } else {
+                       result = TRUE;
+               }
+       }
+       g_free (direct_pinvoke);
+       return result;
+}
+
 static int
 aot_assembly (MonoAssembly *ass, guint32 jit_opts, MonoAotOptions *aot_options)
 {
@@ -14239,10 +14588,52 @@ aot_assembly (MonoAssembly *ass, guint32 jit_opts, MonoAotOptions *aot_options)
        }
 #endif
 
-       if (acfg->aot_opts.direct_pinvoke && !acfg->aot_opts.static_link) {
-               aot_printerrf (acfg, "The 'direct-pinvoke' AOT option also requires the 'static' AOT option.\n");
+       if ((is_direct_pinvoke_enabled (acfg)) && !acfg->aot_opts.static_link) {
+               aot_printerrf (acfg, "The 'direct-pinvoke' flag, 'direct-pinvokes', and 'direct-pinvoke-lists' AOT options also require the 'static' AOT option.\n");
                return 1;
        }
+       if (acfg->aot_opts.direct_pinvoke && (acfg->aot_opts.direct_pinvokes || acfg->aot_opts.direct_pinvoke_lists)) {
+               aot_printerrf (acfg, "The 'direct-pinvoke' flag trumps specified 'direct-pinvokes' and 'direct-pinvoke-lists' arguments. Unset either the flag or the specific direct pinvoke arguments.\n");
+               return 1;
+       }
+
+       gboolean added_direct_pinvoke = TRUE;
+       if (acfg->aot_opts.direct_pinvokes) {
+               for (GList *l = acfg->aot_opts.direct_pinvokes; l; l = l->next) {
+                       added_direct_pinvoke = process_specified_direct_pinvokes (acfg, (const char*)l->data);
+                       if (!added_direct_pinvoke)
+                               return 1;
+               }
+       }
+       if (acfg->aot_opts.direct_pinvoke_lists) {
+               for (GList *l = acfg->aot_opts.direct_pinvoke_lists; l; l = l->next) {
+                       const char *direct_pinvoke_list = (const char*)l->data;
+                       gchar *direct_pinvoke_list_content = NULL;
+                       gchar *direct_pinvoke_list_content_ctx = NULL;
+                       gchar *direct_pinvoke_list_content_line = NULL;
+
+                       if (!g_file_get_contents (direct_pinvoke_list, &direct_pinvoke_list_content, NULL, NULL)) {
+                               aot_printerrf (acfg, "Failed to open and read the provided 'direct-pinvoke-list' '%s'.\n", direct_pinvoke_list);
+                               return 1;
+                       }
+
+                       direct_pinvoke_list_content_line = strtok_r (direct_pinvoke_list_content, "\n", &direct_pinvoke_list_content_ctx);
+                       while (direct_pinvoke_list_content_line && added_direct_pinvoke) {
+                               // Strip whitespace from line read
+                               g_strstrip (direct_pinvoke_list_content_line);
+
+                               // Skip empty direct_pinvokes and comments
+                               if (direct_pinvoke_list_content_line [0] != '\0' && direct_pinvoke_list_content_line [0] != '#')
+                                       added_direct_pinvoke = process_specified_direct_pinvokes (acfg, direct_pinvoke_list_content_line);
+
+                               direct_pinvoke_list_content_line = strtok_r (NULL, "\n", &direct_pinvoke_list_content_ctx);
+                       }
+
+                       g_free (direct_pinvoke_list_content);
+                       if (!added_direct_pinvoke)
+                               return 1;
+               }
+       }
 
        if (acfg->aot_opts.static_link)
                acfg->aot_opts.asm_writer = TRUE;
@@ -14818,14 +15209,15 @@ emit_aot_image (MonoAotCompile *acfg)
 int
 mono_aot_assemblies (MonoAssembly **assemblies, int nassemblies, guint32 jit_opts, const char *aot_options)
 {
-       int res;
+       int res = 0;
        MonoAotOptions aot_opts;
 
        init_options (&aot_opts);
        mono_aot_parse_options (aot_options, &aot_opts);
        if (aot_opts.direct_extern_calls && !(aot_opts.llvm && aot_opts.static_link)) {
                fprintf (stderr, "The 'direct-extern-calls' option requires the 'llvm' and 'static' options.\n");
-               return 1;
+               res = 1;
+               goto early_exit;
        }
 
        if (aot_opts.dedup_include) {
@@ -14839,7 +15231,8 @@ mono_aot_assemblies (MonoAssembly **assemblies, int nassemblies, guint32 jit_opt
                }
                if (dedup_aindex == -1) {
                        fprintf (stderr, "Can't find --dedup-include assembly '%s' among the assemblies to be compiled.\n", aot_opts.dedup_include);
-                       return 1;
+                       res = 1;
+                       goto early_exit;
                }
 
                dedup_assembly = assemblies [dedup_aindex];
@@ -14856,10 +15249,14 @@ mono_aot_assemblies (MonoAssembly **assemblies, int nassemblies, guint32 jit_opt
                res = aot_assembly (assemblies [i], jit_opts, &aot_opts);
                if (res != 0) {
                        fprintf (stderr, "AOT of image %s failed.\n", assemblies [i]->image->name);
-                       return 1;
+                       res = 1;
+                       goto early_exit;
                }
        }
-       return 0;
+
+early_exit:
+       aot_opts_free (&aot_opts);
+       return res;
 }
 
 #else
index af73971..30db263 100644 (file)
@@ -107,6 +107,46 @@ public class MonoAOTCompiler : Microsoft.Build.Utilities.Task
     public bool UseDirectPInvoke { get; set; }
 
     /// <summary>
+    /// When this option is specified, the mono aot compiler will generate direct calls for only specified direct pinvokes.
+    /// Specified direct pinvokes can be in the format of 'module' to generate direct calls for all entrypoints in the module,
+    /// or 'module!entrypoint' to generate direct calls for individual entrypoints in a module. 'module' will trump 'module!entrypoint'.
+    /// For a direct call to be generated, the managed code must call the native function through a direct pinvoke, e.g.
+    ///
+    /// [DllImport("module", EntryPoint="entrypoint")]
+    /// public static extern <ret> ManagedName (arg)
+    ///
+    /// or
+    ///
+    /// [DllImport("module")]
+    /// public static extern <ret> entrypoint (arg)
+    ///
+    /// The native sources must be supplied in the direct pinvoke sources parammeter in the LibraryBuilder to generate a shared library.
+    /// If not using the LibraryBuilder, the native sources must be linked manually in the final executable or library.
+    /// This requires UseStaticLinking=true, can be used in conjunction with DirectPInvokeLists, but is incompatible with UseDirectPInvoke.
+    /// </summary>
+    public ITaskItem[] DirectPInvokes { get; set; } = Array.Empty<ITaskItem>();
+
+    /// <summary>
+    /// When this option is specified, the mono aot compiler will generate direct calls for only specified direct pinvokes in the provided files.
+    /// Specified direct pinvokes can be in the format of 'module' to generate direct calls for all entrypoints in the module,
+    /// or 'module!entrypoint' to generate direct calls for individual entrypoints in a module. 'module' will trump 'module!entrypoint'.
+    /// For a direct call to be generated, the managed code must call the native function through a direct pinvoke, e.g.
+    ///
+    /// [DllImport("module", EntryPoint="entrypoint")]
+    /// public static extern <ret> ManagedName (arg)
+    ///
+    /// or
+    ///
+    /// [DllImport("module")]
+    /// public static extern <ret> entrypoint (arg)
+    ///
+    /// The native sources must be supplied in the direct pinvoke sources parammeter in the LibraryBuilder to generate a shared library.
+    /// If not using the LibraryBuilder, the native sources must be linked manually in the final executable or library.
+    /// This requires UseStaticLinking=true, can be used in conjunction with DirectPInvokes, but is incompatible with UseDirectPInvoke.
+    /// </summary>
+    public ITaskItem[] DirectPInvokeLists { get; set; } = Array.Empty<ITaskItem>();
+
+    /// <summary>
     /// Instructs the AOT compiler to emit DWARF debugging information.
     /// </summary>
     public bool UseDwarfDebug { get; set; }
@@ -364,9 +404,21 @@ public class MonoAOTCompiler : Microsoft.Build.Utilities.Task
             throw new LogAsErrorException($"'{nameof(UseDirectIcalls)}' can only be used with '{nameof(UseStaticLinking)}=true'.");
         }
 
-        if (UseDirectPInvoke && !UseStaticLinking)
+        if (UseDirectPInvoke && (DirectPInvokes.Length > 0 || DirectPInvokeLists.Length > 0))
         {
-            throw new LogAsErrorException($"'{nameof(UseDirectPInvoke)}' can only be used with '{nameof(UseStaticLinking)}=true'.");
+            throw new LogAsErrorException($"'{nameof(UseDirectPInvoke)}' flag trumps specified '{nameof(DirectPInvokes)}' and '{nameof(DirectPInvokeLists)}' arguments. Unset either the flag or the specific direct pinvoke arguments.");
+        }
+
+        if (UseDirectPInvoke || DirectPInvokes.Length > 0 || DirectPInvokeLists.Length > 0)
+        {
+            if (!UseStaticLinking)
+                throw new LogAsErrorException($"'{nameof(UseDirectPInvoke)}', '{nameof(DirectPInvokes)}', and '{nameof(DirectPInvokeLists)}' can only be used with '{nameof(UseStaticLinking)}=true'.");
+
+            foreach (var directPInvokeList in DirectPInvokeLists)
+            {
+                if (!File.Exists(directPInvokeList.GetMetadata("FullPath")))
+                    throw new LogAsErrorException($"Could not find file '{directPInvokeList}'.");
+            }
         }
 
         if (UseStaticLinking && (parsedOutputType == MonoAotOutputType.Library))
@@ -611,6 +663,25 @@ public class MonoAOTCompiler : Microsoft.Build.Utilities.Task
             aotArgs.Add($"static");
         }
 
+        if (UseDirectPInvoke)
+        {
+            aotArgs.Add($"direct-pinvoke");
+        }
+
+        if (DirectPInvokes.Length > 0)
+        {
+            var directPInvokesSB = new StringBuilder("direct-pinvokes=");
+            Array.ForEach(DirectPInvokes, directPInvokeItem => directPInvokesSB.Append($"{directPInvokeItem.ItemSpec};"));
+            aotArgs.Add(directPInvokesSB.ToString());
+        }
+
+        if (DirectPInvokeLists.Length > 0)
+        {
+            var directPInvokeListsSB = new StringBuilder("direct-pinvoke-lists=");
+            Array.ForEach(DirectPInvokeLists, directPInvokeListItem => directPInvokeListsSB.Append($"{directPInvokeListItem.GetMetadata("FullPath")};"));
+            aotArgs.Add(directPInvokeListsSB.ToString());
+        }
+
         if (UseDwarfDebug)
         {
             aotArgs.Add($"dwarfdebug");