From 4e37a015c7ed7ac4a4f57057cb726f53b94e693b Mon Sep 17 00:00:00 2001 From: James Reed Date: Tue, 24 Aug 2021 13:44:52 -0700 Subject: [PATCH] [FX] Fix _replicate_for_data_parallel (#63821) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63821 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D30502115 Pulled By: jamesr66a fbshipit-source-id: 0f004f95def6e1ba21ccbeab40cb0a739a0ad20c --- test/test_fx.py | 15 +++++++++++++++ torch/fx/graph_module.py | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/test/test_fx.py b/test/test_fx.py index c55e97d..27f64e1 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2296,6 +2296,21 @@ class TestFX(JitTestCase): r"Call using an FX-traced Module, line .* of the " r"traced Module's generated forward function:") + def test_graph_module_replicate_for_dp(self): + class Foo(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + gm = torch.fx.symbolic_trace(Foo()) + + x = torch.randn(5, 3) + out = gm(x) + + replica = gm._replicate_for_data_parallel() + out_replica = replica(x) + + torch.testing.assert_allclose(out_replica, out) + def test_ast_rewriter_rewrites_assert(self): class M(torch.nn.Module): def forward(self, x: torch.Tensor, y: int, z: int): diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index b87aeaa..c918573 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -656,6 +656,11 @@ class {module_name}(torch.nn.Module): orig_str = super().__str__() return '\n'.join([orig_str, self._code]) + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + # workarounds for issues in __torch_function__ # WAR for __torch_function__ not handling tensor lists, -- 2.7.4