[FX] Fix GraphModule deepcopy to use deepcopied graph (#63090)
authorJames Reed <jamesreed@fb.com>
Wed, 18 Aug 2021 20:16:01 +0000 (13:16 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 20:17:14 +0000 (13:17 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63090

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D30252471

Pulled By: jamesr66a

fbshipit-source-id: cafd7d7917935a5ea6ffa2a7fe9e9b2a9578b3e3

test/test_fx.py
torch/fx/graph_module.py

index f0a3291..1708634 100644 (file)
@@ -1943,6 +1943,25 @@ class TestFX(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
             traced_graph = MyTracer().trace(CallsModWithDict())
 
+    def test_module_deepcopy_edit_nodes(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return torch.relu(x)
+
+        traced1 = symbolic_trace(Foo())
+        copied = copy.deepcopy(traced1)
+
+        for node in copied.graph.nodes:
+            if node.target == torch.relu:
+                node.target = torch.neg
+
+        copied.recompile()
+        traced1.recompile()
+
+        x = torch.randn(15, 15)
+        torch.testing.assert_allclose(traced1(x), torch.relu(x))
+        torch.testing.assert_allclose(copied(x), torch.neg(x))
+
     def test_direct_param_use(self):
         class TransposeTest(torch.nn.Module):
             def __init__(self):
index 0cbbd93..85479f0 100644 (file)
@@ -615,7 +615,7 @@ class {module_name}(torch.nn.Module):
     def __deepcopy__(self, memo):
         fake_mod = torch.nn.Module()
         fake_mod.__dict__ = copy.deepcopy(self.__dict__)
-        return GraphModule(fake_mod, self.graph)
+        return GraphModule(fake_mod, fake_mod.__dict__['_graph'])
 
     def __copy__(self):
         return GraphModule(self, self.graph)