[FX] Rename reduce functions back to their old, public names (#64324)
authorJames Reed <jamesreed@fb.com>
Wed, 1 Sep 2021 05:20:41 +0000 (22:20 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 05:36:11 +0000 (22:36 -0700)
Summary:
Unfortunately pickle serializes the names of these functions. Also put them under backward-compatibility enforcement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324

Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/

Reviewed By: SplitInfinity, TailofJune

Differential Revision: D30684185

Pulled By: jamesr66a

fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f

test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect
torch/fx/graph_module.py

index a73fde7..20d392f 100644 (file)
@@ -28,6 +28,9 @@ torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.m
 torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
 torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool
 torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode
+torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module
+torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module
+torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module
 torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
 torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
 torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
index 89685bf..ca82d49 100644 (file)
@@ -96,7 +96,8 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):
     return '\n'.join(import_strs)
 
 
-def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
+@compatibility(is_backward_compatible=True)
+def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
     # BC: attribute name was changed from `code` to `_code` to facilitate
     # making `code` into a property and adding a docstring to it
     fn_src = body.get('_code') or body['code']
@@ -104,14 +105,15 @@ def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mo
     return _deserialize_graph_module(forward, body)
 
 
-def _reduce_package_graph_module(
+@compatibility(is_backward_compatible=True)
+def reduce_package_graph_module(
     importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
 ) -> torch.nn.Module:
     forward = importer.import_module(generated_module_name).forward
     return _deserialize_graph_module(forward, body)
 
-
-def _reduce_deploy_graph_module(
+@compatibility(is_backward_compatible=True)
+def reduce_deploy_graph_module(
     importer: PackageImporter, body: Dict[Any, Any], import_block: str
 ) -> torch.nn.Module:
     ns = dict()
@@ -626,7 +628,7 @@ class {module_name}(torch.nn.Module):
 
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, importer)
-        return (_reduce_deploy_graph_module, (dict_without_graph, import_block))
+        return (reduce_deploy_graph_module, (dict_without_graph, import_block))
 
     def __reduce_package__(self, exporter: PackageExporter):
         dict_without_graph = self.__dict__.copy()
@@ -638,7 +640,7 @@ class {module_name}(torch.nn.Module):
         import_block = _format_import_block(python_code.globals, exporter.importer)
         module_code = import_block + self.code
         exporter.save_source_string(generated_module_name, module_code)
-        return (_reduce_package_graph_module, (dict_without_graph, generated_module_name))
+        return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
 
     def __reduce__(self):
         """
@@ -652,7 +654,7 @@ class {module_name}(torch.nn.Module):
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, sys_importer)
         del dict_without_graph['_graph']
-        return (_reduce_graph_module, (dict_without_graph, import_block))
+        return (reduce_graph_module, (dict_without_graph, import_block))
 
     # because __reduce__ is defined for serialization,
     # we need to define deepcopy otherwise it will call __reduce__