[fx_const_fold] Fix constant folding for attrs in submodule hierarchies (#64342)
author Jordan Fix <jfix@fb.com>
Wed, 8 Sep 2021 05:43:04 +0000 (22:43 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 05:44:39 +0000 (22:44 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64342

Previously we weren't handling the case where an attribute was in a module that wasn't the root.

Test Plan: Added unit test coverage.

Reviewed By: yinghai

Differential Revision: D30691730

fbshipit-source-id: b39b5cf748c4c882f315a4f32b51ad88cc7a43ed

test/fx/test_fx_const_fold.py
torch/fx/experimental/const_fold.py

index db06663..1da8109 100644 (file)
@@ -272,3 +272,37 @@ class TestConstFold(unittest.TestCase):
         fold_result = mod_folded(in_x, in_y)
         base_result = mod(in_x, in_y)
         self.assertTrue(torch.equal(fold_result, base_result))
+
+    def test_const_fold_submod_hierarchy(self):
+        r"""
+        Perform constant folding conversion, from original mod to split constant folding
+        module where one of the folded attrs comes from a submod deeper in the hierarchy
+        of the base module.
+        """
+
+        class TracedThroughModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.internal_attr = torch.nn.Parameter(torch.randn(2, 3))
+
+            def forward(self):
+                return self.internal_attr
+
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.my_mod = TracedThroughModule()
+                self.attr = torch.nn.Parameter(torch.randn(2, 3))
+
+            def forward(self, x):
+                return self.attr + self.my_mod() + x
+
+        mod = ConstFoldTestModule()
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
+        self._verify_const_fold_mod(mod_folded)
+
+        # Now run both folded and non-folded to check results equal.
+        in_x = torch.randn(2, 3)
+        fold_result = mod_folded(in_x)
+        base_result = mod(in_x)
+        self.assertTrue(torch.equal(fold_result, base_result))
index 3fa46ed..2ecf88b 100644 (file)
@@ -1,5 +1,5 @@
 import operator
-from typing import Dict, Set, List, Optional
+from typing import Dict, Set, List, Optional, Union
 
 import torch.fx
 from torch.fx.passes.split_module import split_module
@@ -67,7 +67,7 @@ class FoldedGraphModule(torch.fx.GraphModule):
 
 
 def split_const_subgraphs(
-    module: torch.nn.Module,
+    module: Union[torch.nn.Module, torch.fx.GraphModule]
 ) -> FoldedGraphModule:
     """
     Looks through `module` for any nodes that have all constant attribute inputs
@@ -76,7 +76,10 @@ def split_const_subgraphs(
     attributes on the module prior to running the non-constant portion of the
     graph.
     """
-    mod_traced = torch.fx.symbolic_trace(module)
+    if not isinstance(module, torch.fx.GraphModule):
+        mod_traced = torch.fx.symbolic_trace(module)
+    else:
+        mod_traced = module
 
     # Build up a list of const_nodes, defined as nodes that are themselves
     # get_attrs, or have all get_attr or other constant node inputs.
@@ -255,15 +258,19 @@ def split_const_subgraphs(
     # somehow a priori knowing the attrs that should be passed as args. We can
     # unconditionally do this for all placeholders because we know all
     # placeholders to submod_0 must be constants accessible via get_attr.
-    for node in split.submod_0.graph.nodes:
+    # Note that here we set the split.submod_0.graph into a new root_submod_0 with split
+    # as the root module, because we are fetching attributes directly from the root
+    # module, instead of fetching them from split.submod_0.
+    root_submod_0 = torch.fx.GraphModule(split, split.submod_0.graph)
+    for node in root_submod_0.graph.nodes:
         if node.op != "placeholder":
             continue
         in_node = next(n for n in call_submod_0_args if n.name == node.target)
         assert in_node.op == "get_attr"
-        with split.submod_0.graph.inserting_before(node):
-            node.replace_all_uses_with(split.submod_0.graph.get_attr(in_node.target))
-        split.submod_0.graph.erase_node(node)
+        with root_submod_0.graph.inserting_before(node):
+            node.replace_all_uses_with(root_submod_0.graph.get_attr(in_node.target))
+        root_submod_0.graph.erase_node(node)
 
     return FoldedGraphModule(
-        mod_traced, split.submod_1.graph, split.submod_0.graph, const_output_names
+        mod_traced, split.submod_1.graph, root_submod_0.graph, const_output_names
     )