From 0c4e4e588e2e3308c659f741e7bc5cabf0975c09 Mon Sep 17 00:00:00 2001 From: James Reed Date: Tue, 31 Aug 2021 22:20:41 -0700 Subject: [PATCH] [FX] Rename reduce functions back to their old, public names (#64324) Summary: Unfortunately pickle serializes the names of these functions. Also put them under backward-compatibility enforcement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324 Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/ Reviewed By: SplitInfinity, TailofJune Differential Revision: D30684185 Pulled By: jamesr66a fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f --- ..._back_compat-fx_backcompat_function_signatures.expect | 3 +++ torch/fx/graph_module.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a73fde7..20d392f 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -28,6 +28,9 @@ torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.m torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode +torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module +torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module +torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True) torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 89685bf..ca82d49 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -96,7 +96,8 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer): return '\n'.join(import_strs) -def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: +@compatibility(is_backward_compatible=True) +def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: # BC: attribute name was changed from `code` to `_code` to facilitate # making `code` into a property and adding a docstring to it fn_src = body.get('_code') or body['code'] @@ -104,14 +105,15 @@ def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mo return _deserialize_graph_module(forward, body) -def _reduce_package_graph_module( +@compatibility(is_backward_compatible=True) +def reduce_package_graph_module( importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str ) -> torch.nn.Module: forward = importer.import_module(generated_module_name).forward return _deserialize_graph_module(forward, body) - -def _reduce_deploy_graph_module( +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( importer: PackageImporter, body: Dict[Any, Any], import_block: str ) -> torch.nn.Module: ns = dict() @@ -626,7 +628,7 @@ class {module_name}(torch.nn.Module): python_code = self.recompile() import_block = _format_import_block(python_code.globals, importer) - return (_reduce_deploy_graph_module, (dict_without_graph, import_block)) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) def __reduce_package__(self, exporter: PackageExporter): dict_without_graph = self.__dict__.copy() @@ -638,7 +640,7 @@ class {module_name}(torch.nn.Module): import_block = _format_import_block(python_code.globals, exporter.importer) module_code = import_block + self.code exporter.save_source_string(generated_module_name, module_code) - return (_reduce_package_graph_module, (dict_without_graph, generated_module_name)) + return (reduce_package_graph_module, (dict_without_graph, generated_module_name)) def __reduce__(self): """ @@ -652,7 +654,7 @@ class {module_name}(torch.nn.Module): python_code = self.recompile() import_block = _format_import_block(python_code.globals, sys_importer) del dict_without_graph['_graph'] - return (_reduce_graph_module, (dict_without_graph, import_block)) + return (reduce_graph_module, (dict_without_graph, import_block)) # because __reduce__ is defined for serialization, # we need to define deepcopy otherwise it will call __reduce__ -- 2.7.4