[fx] store Tracer class on Graph and GraphModule for package deserialization [v2...
author Bradley Davis <bradleyhd@fb.com>
Fri, 13 Aug 2021 00:27:08 +0000 (17:27 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 00:28:50 +0000 (17:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63121

Re-introducing this diff with one small change: the Tracer class is no longer stored on a GraphModule when that Tracer class is not defined at module level, because locally defined classes cannot be pickled.

Previous, reverted Pull Request: https://github.com/pytorch/pytorch/pull/62497

Reviewed By: houseroad

Differential Revision: D30252776

fbshipit-source-id: 42d2bc846e4b32d00563419c38c02b63cd0986e6

test/package/package_a/test_all_leaf_modules_tracer.py [new file with mode: 0644]
test/package/test_package_fx.py
torch/fx/_symbolic_trace.py
torch/fx/graph.py
torch/fx/graph_module.py

diff --git a/test/package/package_a/test_all_leaf_modules_tracer.py b/test/package/package_a/test_all_leaf_modules_tracer.py
new file mode 100644 (file)
index 0000000..ca8d8a0
--- /dev/null
@@ -0,0 +1,6 @@
+from torch.fx import Tracer
+
+
+class TestAllLeafModulesTracer(Tracer):
+    def is_leaf_module(self, m, qualname):
+        return True
index 7ad05a1..7f31014 100644 (file)
@@ -121,6 +121,45 @@ class TestPackageFX(PackageTestCase):
         packaged_dependency = pi.import_module("package_a.subpackage")
         self.assertTrue(packaged_dependency is not package_a.subpackage)
 
+    def test_package_fx_custom_tracer(self):
+        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
+        from package_a.test_module import SimpleTest, ModWithTwoSubmodsAndTensor
+
+        class SpecialGraphModule(torch.fx.GraphModule):
+            def __init__(self, root, graph, info):
+                super().__init__(root, graph)
+                self.info = info
+
+        sub_module = SimpleTest()
+        module = ModWithTwoSubmodsAndTensor(
+            torch.ones(3),
+            sub_module,
+            sub_module,
+        )
+        tracer = TestAllLeafModulesTracer()
+        graph = tracer.trace(module)
+
+        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)
+
+        gm = SpecialGraphModule(module, graph, "secret")
+        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)
+
+        f = BytesIO()
+        with PackageExporter(f) as pe:
+            pe.intern("**")
+            pe.save_pickle("model", "model.pkl", gm)
+        f.seek(0)
+
+        pi = PackageImporter(f)
+        loaded_gm = pi.load_pickle("model", "model.pkl")
+        self.assertEqual(
+            type(loaded_gm).__class__.__name__, SpecialGraphModule.__class__.__name__
+        )
+        self.assertEqual(loaded_gm.info, "secret")
+
+        input_x = torch.randn(3)
+        self.assertTrue(torch.allclose(loaded_gm(input_x), gm(input_x)))
+
 
 if __name__ == "__main__":
     run_tests()
index 56f925f..3098bea 100644 (file)
@@ -532,7 +532,9 @@ class Tracer(TracerBase):
         else:
             self.root = torch.nn.Module()
             fn = root
-        self.graph = Graph()
+
+        tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
+        self.graph = Graph(tracer_cls=tracer_cls)
 
         # When we encounter a Tensor value that's not a parameter, we look if it
         # is some other attribute on the model. Construct a dict mapping Tensor
index b7960f8..a8d657d 100644 (file)
@@ -2,7 +2,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_na
 import torch.utils._pytree as pytree
 from . import _pytree as fx_pytree
 
-from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet
+from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
 from dataclasses import dataclass
 from contextlib import contextmanager
 import copy
@@ -16,6 +16,7 @@ import warnings
 
 if TYPE_CHECKING:
     from .graph_module import GraphModule  # noqa: F401
+    from ._symbolic_trace import Tracer   # noqa: F401
 
 
 # Mapping of builtins to their `typing` equivalent.
@@ -282,7 +283,7 @@ class Graph:
 
     For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
     """
-    def __init__(self, owning_module: Optional["GraphModule"] = None):
+    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
         """
         Construct an empty Graph.
         """
@@ -293,6 +294,7 @@ class Graph:
         self._graph_namespace = _Namespace()
         self._owners = 0
         self._owning_module = owning_module
+        self._tracer_cls = tracer_cls
         self._pytree_info: Optional[_PyTreeInfo] = None
 
     @property
index 09d1b6d..0cbbd93 100644 (file)
@@ -71,26 +71,25 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod
     return _deserialize_graph_module(forward, body)
 
 
-def reduce_package_graph_module(importer: PackageImporter,
-                                body: Dict[Any, Any],
-                                generated_module_name: str) -> torch.nn.Module:
+def reduce_package_graph_module(
+    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
+) -> torch.nn.Module:
     forward = importer.import_module(generated_module_name).forward
     return _deserialize_graph_module(forward, body)
 
 
-def reduce_deploy_graph_module(importer: PackageImporter,
-                               body: Dict[Any, Any],
-                               import_block: str,
-                               tracer_cls: Type) -> torch.nn.Module:
+def reduce_deploy_graph_module(
+    importer: PackageImporter, body: Dict[Any, Any], import_block: str
+) -> torch.nn.Module:
     ns = dict()
     ns["__builtins__"] = importer.patched_builtins
     fn_src = body.get('_code')
     assert fn_src is not None
     forward = _forward_from_src(import_block + fn_src, ns)
-    return _deserialize_graph_module(forward, body, tracer_cls)
+    return _deserialize_graph_module(forward, body)
 
 
-def _deserialize_graph_module(forward, body: Dict[Any, Any], tracer_cls: Type = None) -> torch.nn.Module:
+def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
     """
     Deserialize a GraphModule given the dictionary of the original module,
     using the code to reconstruct the graph. We delete the actual graph before
@@ -107,10 +106,13 @@ def _deserialize_graph_module(forward, body: Dict[Any, Any], tracer_cls: Type =
     # Try to retrieve the forward source in a backward-compatible way
     CodeOnlyModule.forward = forward
 
+    tracer_cls = body.get('_tracer_cls')
     if tracer_cls is None:
         from ._symbolic_trace import Tracer
         tracer_cls = Tracer
 
+    graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
+
     # This is a workaround for a mypy linter issue related to
     # passing base class as an argument - https://github.com/python/mypy/issues/5865.
     cls_tracer : Any = tracer_cls
@@ -122,7 +124,22 @@ def _deserialize_graph_module(forward, body: Dict[Any, Any], tracer_cls: Type =
             return True
 
     com = CodeOnlyModule(body)
-    return GraphModule(com, KeepModules().trace(com))
+
+    graph = KeepModules().trace(com)
+
+    # Manually set Tracer class on the reconstructed Graph, to avoid
+    # referencing the private local subclass KeepModules.
+    graph._tracer_cls = tracer_cls
+    gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
+
+    # The GraphModule constructor only retains attributes referenced by the graph.
+    # In this case, our goal is to return a GraphModule as close as possible to the one
+    # put into the package. If any additional attributes were present in body,
+    # we should keep them.
+    for k, v in body.items():
+        if not hasattr(gm, k):
+            setattr(gm, k, v)
+    return gm
 
 # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
 # This installs empty Modules where none exist yet if they are subpaths of target
@@ -251,6 +268,15 @@ class GraphModule(torch.nn.Module):
 
         self.graph = graph
 
+        # Store the Tracer class responsible for creating a Graph separately as part of the
+        # GraphModule state, except when the Tracer is defined in a local namespace.
+        # Locally defined Tracers are not pickleable. This is needed because torch.package will
+        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
+        # to re-create the Graph during deserialization.
+        self._tracer_cls = None
+        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
+            self._tracer_cls = self.graph._tracer_cls
+
     # TorchScript breaks trying to compile the graph setter because of the
     # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
     #
@@ -548,22 +574,25 @@ class {module_name}(torch.nn.Module):
 
     # Passing Tracer as argument allows subclasses extending fx.GraphModule
     # define their own Tracer (extending fx.Tracer).
-    def __reduce_deploy__(self, importer: Importer, tracer_cls: Type = None):
+    def __reduce_deploy__(self, importer: Importer):
         dict_without_graph = self.__dict__.copy()
+        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
+        del dict_without_graph['_graph']
+
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, importer)
-        del dict_without_graph['_graph']
-        return (reduce_deploy_graph_module, (dict_without_graph, import_block, tracer_cls))
+        return (reduce_deploy_graph_module, (dict_without_graph, import_block))
 
     def __reduce_package__(self, exporter: PackageExporter):
+        dict_without_graph = self.__dict__.copy()
+        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
+        del dict_without_graph['_graph']
+
         generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, exporter.importer)
         module_code = import_block + self.code
         exporter.save_source_string(generated_module_name, module_code)
-
-        dict_without_graph = self.__dict__.copy()
-        del dict_without_graph['_graph']
         return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
 
     def __reduce__(self):