From: Bradley Davis Date: Tue, 17 Aug 2021 16:55:25 +0000 (-0700) Subject: [fx] persist `tracer_cls` on `fx.Graph` when deep copying (#63353) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~967 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=011fdc3b7e7bcd2528a2b0bf20d1c0793f6125e1;p=platform%2Fupstream%2Fpytorch.git [fx] persist `tracer_cls` on `fx.Graph` when deep copying (#63353) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63353 Custom deepcopy method copies all nodes but does not copy the tracer_cls attribute Reviewed By: houseroad Differential Revision: D30349424 fbshipit-source-id: 3e98bdac8a8a992eb0b4ec67fe80bb2e5cf3884d --- diff --git a/test/test_fx.py b/test/test_fx.py index cf69143..f0a3291 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -779,6 +779,19 @@ class TestFX(JitTestCase): copied = copy.deepcopy(traced) copied.graph.lint() + def test_deepcopy_graph_with_tracer_cls(self): + class TestTracer(Tracer): + def is_leaf_module(self, module, name): + return True + + g = Graph(tracer_cls=TestTracer) + x = g.placeholder("x") + g.output(x) + + h = copy.deepcopy(g) + self.assertIsNotNone(h._tracer_cls) + self.assertTrue(g._tracer_cls == h._tracer_cls) + def test_unpack_list_better_error(self): class SomeArgs(torch.nn.Module): def forward(self, a, b): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index a8d657d..88c7b54 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -357,7 +357,7 @@ class Graph: nodes or other parts of the Graph from a custom GraphModule implementation """ memo = memo if memo else {} - g = Graph() + g = Graph(tracer_cls=self._tracer_cls) output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) assert isinstance(output_vals, tuple) output_val, old_output_val = output_vals