make test module hook use save/load (#18284)
author Michael Suo <suo@fb.com>
Wed, 3 Apr 2019 01:06:07 +0000 (18:06 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 01:09:52 +0000 (18:09 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18284
ghimport-source-id: 5a92c03fda19072ffb6afd40e0f56806716c7be6

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18296 [jit] Add namespacing for ScriptClasses
* **#18284 [jit] make test module hook use save/load**
* #18211 [jit] Turn script_type_parser into a class
* #18148 [jit] python interop for script classes

Instead of python-printing and comparing strings (which does not capture
dependency information, etc.), use save/load on in-memory buffers and
compare the main module contents inside the buffer.
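
As a rough standalone sketch of the round trip the hook now performs (the helper `main_module_code` and the toy module `M` are illustrative only, and the `archive/code/archive.py` path follows the serializer's zip layout at the time of this change):

    import io
    import zipfile

    import torch

    def main_module_code(script_module):
        # Serialize the ScriptModule into an in-memory zip archive.
        buf = io.BytesIO()
        torch.jit.save(script_module, buf)
        buf.seek(0)
        # Read the main module's code back out of the archive.
        with zipfile.ZipFile(buf) as archive:
            return archive.read('archive/code/archive.py').decode()

    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x + 1

    m = M()
    code_before = main_module_code(m)

    # Save, load from a fresh buffer, and compare the main module code.
    buf = io.BytesIO()
    torch.jit.save(m, buf)
    buf.seek(0)
    reloaded = torch.jit.load(buf)
    assert main_module_code(reloaded) == code_before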

Reviewed By: ailzhang

Differential Revision: D14581129

fbshipit-source-id: 52264ae9ce076775ab3fd1a0c32c8d6f6677a903

test/test_jit.py

index 22ce30a..fef1c79 100644
@@ -281,6 +281,8 @@ class JitTestCase(TestCase):
         torch._C._jit_set_emit_module_hook(self.emitModuleHook)
 
     def emitModuleHook(self, module):
+        import zipfile
+
         def copy_structure_and_params(m):
             c = torch.jit.ScriptModule()
             for name, v in m._get_parameters():
@@ -294,7 +296,24 @@ class JitTestCase(TestCase):
         # disable the hook while we parse code, otherwise we will re-enter the hook
         with self.disableModuleHook():
             try:
-                pp, constant_table = module._python_print()
+                if len(module.code) == 0:
+                    # short-circuit if this is an empty module
+                    return
+                # save the module to a buffer
+                buffer = io.BytesIO()
+                torch.jit.save(module, buffer)
+
+                # copy the data in the buffer so we can restore it later. This
+                # is because py2 and py3 have different semantics with zipfile
+                # and it's easier to just work with a fresh copy each time.
+                buffer_copy = buffer.getvalue()
+
+                # crack open the zip format to get at the main module code
+                archive = zipfile.ZipFile(buffer)
+                main_module = archive.open('archive/code/archive.py')
+                main_module_code = ""
+                for line in main_module:
+                    main_module_code += line.decode()
             except RuntimeError as e:
                 se = str(e)
                 if "could not export python function" not in se and \
@@ -302,12 +321,24 @@ class JitTestCase(TestCase):
                     raise
                 else:
                     return
-            ppv = "op_version_set = 0\n{}".format(pp)
-            sm = copy_structure_and_params(module)
-            torch._C._jit_import_methods(sm, ppv, constant_table)
-            pp2, _ = sm._python_print()
-            if pp != pp2:
-                self.assertMultiLineEqual(pp, pp2)
+
+            # import the model again (from the copy we made of the original)
+            buffer2 = io.BytesIO(buffer_copy)
+            imported = torch.jit.load(buffer2)
+
+            # save it again
+            saved_module_buffer_2 = io.BytesIO()
+            torch.jit.save(imported, saved_module_buffer_2)
+
+            saved_module_buffer_2.seek(0)
+            archive2 = zipfile.ZipFile(saved_module_buffer_2)
+            main_module_2 = archive2.open('archive/code/archive.py')
+
+            main_module_2_code = ""
+            for line in main_module_2:
+                main_module_2_code += line.decode()
+
+            self.assertMultiLineEqual(main_module_code, main_module_2_code)
 
     def getExportImportCopy(self, m, also_test_file=True, map_location=None):
         buffer = io.BytesIO()