From 011fdc3b7e7bcd2528a2b0bf20d1c0793f6125e1 Mon Sep 17 00:00:00 2001 From: Bradley Davis Date: Tue, 17 Aug 2021 09:55:25 -0700 Subject: [PATCH] [fx] persist `tracer_cls` on `fx.Graph` when deep copying (#63353) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63353 Custom deepcopy method copies all nodes but does not copy the tracer_cls attribute Reviewed By: houseroad Differential Revision: D30349424 fbshipit-source-id: 3e98bdac8a8a992eb0b4ec67fe80bb2e5cf3884d --- test/test_fx.py | 13 +++++++++++++ torch/fx/graph.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/test/test_fx.py b/test/test_fx.py index cf69143..f0a3291 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -779,6 +779,19 @@ class TestFX(JitTestCase): copied = copy.deepcopy(traced) copied.graph.lint() + def test_deepcopy_graph_with_tracer_cls(self): + class TestTracer(Tracer): + def is_leaf_module(self, module, name): + return True + + g = Graph(tracer_cls=TestTracer) + x = g.placeholder("x") + g.output(x) + + h = copy.deepcopy(g) + self.assertIsNotNone(h._tracer_cls) + self.assertTrue(g._tracer_cls == h._tracer_cls) + def test_unpack_list_better_error(self): class SomeArgs(torch.nn.Module): def forward(self, a, b): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index a8d657d..88c7b54 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -357,7 +357,7 @@ class Graph: nodes or other parts of the Graph from a custom GraphModule implementation """ memo = memo if memo else {} - g = Graph() + g = Graph(tracer_cls=self._tracer_cls) output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) assert isinstance(output_vals, tuple) output_val, old_output_val = output_vals -- 2.7.4