Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63353
Custom deepcopy method copies all nodes but does not copy the tracer_cls attribute
Reviewed By: houseroad
Differential Revision:
D30349424
fbshipit-source-id:
3e98bdac8a8a992eb0b4ec67fe80bb2e5cf3884d
copied = copy.deepcopy(traced)
copied.graph.lint()
+ def test_deepcopy_graph_with_tracer_cls(self):
+ class TestTracer(Tracer):
+ def is_leaf_module(self, module, name):
+ return True
+
+ g = Graph(tracer_cls=TestTracer)
+ x = g.placeholder("x")
+ g.output(x)
+
+ h = copy.deepcopy(g)
+ self.assertIsNotNone(h._tracer_cls)
+ self.assertTrue(g._tracer_cls == h._tracer_cls)
+
def test_unpack_list_better_error(self):
class SomeArgs(torch.nn.Module):
def forward(self, a, b):
nodes or other parts of the Graph from a custom GraphModule implementation
"""
memo = memo if memo else {}
- g = Graph()
+ g = Graph(tracer_cls=self._tracer_cls)
output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
assert isinstance(output_vals, tuple)
output_val, old_output_val = output_vals