[const_fold] Keep around node.meta for replaced folded ops (#64782)
author Jordan Fix <jfix@fb.com>
Fri, 10 Sep 2021 06:49:22 +0000 (23:49 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 06:52:39 +0000 (23:52 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64782

Previously, get_attr nodes that were added to the graph during folding did not retain the node.meta of the nodes they replaced. Add such support, and improve test coverage here in general.
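
For context, the fix follows the usual FX pattern of copying node.meta onto the replacement node before rewiring its uses. A minimal standalone sketch of that pattern (the toy module and names are for illustration only, not the actual const_fold code):

    import operator

    import torch
    import torch.fx


    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1


    gm = torch.fx.symbolic_trace(M())

    # Tag the add node with some meta, as a downstream pass might.
    add_node = next(n for n in gm.graph.nodes if n.target == operator.add)
    add_node.meta["meta_idx"] = 42

    # Replace the add with a new node; copy node.meta so it is retained.
    with gm.graph.inserting_before(add_node):
        new_node = gm.graph.call_function(operator.mul, (add_node.args[0], 2))
        new_node.meta = add_node.meta.copy()
        add_node.replace_all_uses_with(new_node)
    gm.graph.erase_node(add_node)
    gm.recompile()

    assert new_node.meta["meta_idx"] == 42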

Test Plan: Added test coverage.

Reviewed By: protonu

Differential Revision: D30852704

fbshipit-source-id: ece87a61c69b2e68982964c6adc4dde14dae12c7

test/fx/test_fx_const_fold.py
torch/fx/experimental/const_fold.py

test/fx/test_fx_const_fold.py
index a05b9ae..67f795b 100644
@@ -1,6 +1,8 @@
+import operator
 import unittest
 
 import torch
+import torch.fx
 from torch.fx.experimental import const_fold
 
 
@@ -337,3 +339,56 @@ class TestConstFold(unittest.TestCase):
         fold_result = mod_folded(in_x)
         base_result = mod(in_x)
         self.assertTrue(torch.equal(fold_result, base_result))
+
+    def test_retain_node_meta(self):
+        r"""
+        Perform constant folding conversion, and validate that node meta is retained.
+        """
+
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attr = torch.nn.Parameter(torch.randn(2, 3))
+
+            def forward(self, x):
+                a = self.attr + self.attr
+                return x - a
+
+        mod = ConstFoldTestModule()
+        gm = torch.fx.symbolic_trace(mod)
+
+        # Record each node's index in its meta so we can check it is retained after const folding.
+        for idx, node in enumerate(gm.graph.nodes):
+            if node.op != "output":
+                node.meta["meta_idx"] = idx
+
+        # Pre-folding:
+        # idx 0: placeholder
+        # idx 1: get_attr (will no longer be used, hence removed)
+        # idx 2: add (will be folded into a get_attr)
+        # idx 3: sub
+
+        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
+        self._verify_const_fold_mod(gm_folded)
+
+        # Post-folding:
+        # idx 0: placeholder
+        # idx 2: get_attr (replaced original add; original get_attr was removed)
+        # idx 3: sub
+
+        # Check the expected indices are still here.
+        for node in gm_folded.graph.nodes:
+            if node.op == "placeholder":
+                self.assertEqual(node.meta["meta_idx"], 0)
+            elif node.op == "get_attr":
+                self.assertEqual(node.meta["meta_idx"], 2)
+            elif node.op == "call_function" and node.target == operator.sub:
+                self.assertEqual(node.meta["meta_idx"], 3)
+            else:
+                self.assertEqual(node.op, "output")
+
+        # Now run both folded and non-folded to check results equal.
+        in_x = torch.randn(2, 3)
+        fold_result = gm_folded(in_x)
+        base_result = mod(in_x)
+        self.assertTrue(torch.equal(fold_result, base_result))
torch/fx/experimental/const_fold.py
index b1446ef..8469329 100644
@@ -1,9 +1,9 @@
 import operator
+import re
 from typing import Dict, Set, List, Optional, Union
 
 import torch.fx
 from torch.fx.passes.split_module import split_module
-import re
 
 
 def _make_tuple(x):
@@ -218,7 +218,9 @@ def split_const_subgraphs(
                 torch.nn.Parameter(torch.randn(1)),
             )
         with split.submod_1.graph.inserting_before(node):
-            node.replace_all_uses_with(split.submod_1.graph.get_attr(const_output_name))
+            new_node = split.submod_1.graph.get_attr(const_output_name)
+            new_node.meta = node.meta.copy()
+            node.replace_all_uses_with(new_node)
         split.submod_1.graph.erase_node(node)
         ph_idx += 1