[fx] persist `tracer_cls` on `fx.Graph` when deep copying (#63353)
authorBradley Davis <bradleyhd@fb.com>
Tue, 17 Aug 2021 16:55:25 +0000 (09:55 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 16:57:48 +0000 (09:57 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63353

Custom deepcopy method copies all nodes but does not copy the tracer_cls attribute

Reviewed By: houseroad

Differential Revision: D30349424

fbshipit-source-id: 3e98bdac8a8a992eb0b4ec67fe80bb2e5cf3884d

test/test_fx.py
torch/fx/graph.py

index cf69143..f0a3291 100644 (file)
@@ -779,6 +779,19 @@ class TestFX(JitTestCase):
         copied = copy.deepcopy(traced)
         copied.graph.lint()
 
+    def test_deepcopy_graph_with_tracer_cls(self):
+        class TestTracer(Tracer):
+            def is_leaf_module(self, module, name):
+                return True
+
+        g = Graph(tracer_cls=TestTracer)
+        x = g.placeholder("x")
+        g.output(x)
+
+        h = copy.deepcopy(g)
+        self.assertIsNotNone(h._tracer_cls)
+        self.assertTrue(g._tracer_cls == h._tracer_cls)
+
     def test_unpack_list_better_error(self):
         class SomeArgs(torch.nn.Module):
             def forward(self, a, b):
index a8d657d..88c7b54 100644 (file)
@@ -357,7 +357,7 @@ class Graph:
         nodes or other parts of the Graph from a custom GraphModule implementation
         """
         memo = memo if memo else {}
-        g = Graph()
+        g = Graph(tracer_cls=self._tracer_cls)
         output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
         assert isinstance(output_vals, tuple)
         output_val, old_output_val = output_vals