torch._C._jit_set_emit_module_hook(self.emitModuleHook)
def emitModuleHook(self, module):
+ import zipfile
+
def copy_structure_and_params(m):
c = torch.jit.ScriptModule()
for name, v in m._get_parameters():
# disable the hook while we parse code, otherwise we will re-enter the hook
with self.disableModuleHook():
try:
- pp, constant_table = module._python_print()
+ if len(module.code) == 0:
+ # short-circuit if this is an empty module
+ return
+ # save the module to a buffer
+ buffer = io.BytesIO()
+ torch.jit.save(module, buffer)
+
+ # copy the data in the buffer so we can restore it later. This
+ # is because py2 and py3 have different semantics with zipfile
+ # and it's easier to just work with a fresh copy each time.
+ buffer_copy = buffer.getvalue()
+
+ # crack open the zip format to get at the main module code
+ archive = zipfile.ZipFile(buffer)
+ main_module = archive.open('archive/code/archive.py')
+ main_module_code = ""
+ for line in main_module:
+ main_module_code += line.decode()
except RuntimeError as e:
se = str(e)
if "could not export python function" not in se and \
raise
else:
return
- ppv = "op_version_set = 0\n{}".format(pp)
- sm = copy_structure_and_params(module)
- torch._C._jit_import_methods(sm, ppv, constant_table)
- pp2, _ = sm._python_print()
- if pp != pp2:
- self.assertMultiLineEqual(pp, pp2)
+
+ # import the model again (from a the copy we made of the original)
+ buffer2 = io.BytesIO(buffer_copy)
+ imported = torch.jit.load(buffer2)
+
+ # save it again
+ saved_module_buffer_2 = io.BytesIO()
+ torch.jit.save(imported, saved_module_buffer_2)
+
+ saved_module_buffer_2.seek(0)
+ archive2 = zipfile.ZipFile(saved_module_buffer_2)
+ main_module_2 = archive2.open('archive/code/archive.py')
+
+ main_module_2_code = ""
+ for line in main_module_2:
+ main_module_2_code += line.decode()
+
+ self.assertMultiLineEqual(main_module_code, main_module_2_code)
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
buffer = io.BytesIO()