Make frozen symbol name customizable in torch deploy. (#63817)
authorZhengxu Chen <zhxchen17@fb.com>
Thu, 26 Aug 2021 03:09:12 +0000 (20:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 03:10:35 +0000 (20:10 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63817

ghstack-source-id: 136699671

Test Plan: eyes

Reviewed By: wconstab

Differential Revision: D29571559

fbshipit-source-id: 8e3caa4932ef8d7c8559f264f0e9bb5474ad2237

torch/csrc/deploy/interpreter/freeze.py

index 24fa709..3153174 100644 (file)
@@ -35,17 +35,13 @@ MAIN_INCLUDES = """#include <Python.h>
 
 """
 
-MAIN_PREFIX = """
+MAIN_PREFIX_TEMPLATE = """
 // Compiled standard library modules. These should be appended to the existing
 // `PyImport_FrozenModules` that ships with CPython.
-struct _frozen _PyImport_FrozenModules_torch[] = {
+struct _frozen {}[] = {{
 """
 
-FAKE_PREFIX = """
-// Compiled standard library modules. These should be appended to the existing
-// `PyImport_FrozenModules` that ships with CPython.
-struct _frozen _PyImport_FrozenModules[] = {
-"""
+FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules")
 
 MAIN_SUFFIX = """\
     {0, 0, 0} /* sentinel */
@@ -133,7 +129,7 @@ class Freezer:
         for f in bytecode_files:
             f.close()
 
-    def write_main(self, install_root, oss):
+    def write_main(self, install_root, oss, symbol_name):
         """
         Write the `main.c` file containing a table enumerating all the
         frozen modules.
@@ -143,7 +139,7 @@ class Freezer:
             for m in self.frozen_modules:
                 outfp.write(f"extern unsigned char {m.c_name}[];\n")
 
-            outfp.write(MAIN_PREFIX)
+            outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name))
             for m in self.frozen_modules:
                 outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n')
             outfp.write(MAIN_SUFFIX)
@@ -246,6 +242,11 @@ parser.add_argument("paths", nargs="*", help="Paths to freeze.")
 parser.add_argument("--verbose", action="store_true", help="Print debug logs")
 parser.add_argument("--install_dir", help="Root directory for all output files")
 parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")
+parser.add_argument(
+    "--symbol_name",
+    help="The name of the frozen module array symbol to generate",
+    default="_PyImport_FrozenModules_torch",
+)
 
 args = parser.parse_args()
 
@@ -264,4 +265,4 @@ for p in args.paths:
         f.compile_path(path, path)
 
 f.write_bytecode(args.install_dir)
-f.write_main(args.install_dir, args.oss)
+f.write_main(args.install_dir, args.oss, args.symbol_name)