[fx] make const fold code more pythonic (#64451)
authorYinghai Lu <yinghai@fb.com>
Wed, 8 Sep 2021 20:50:46 +0000 (13:50 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 20:55:10 +0000 (13:55 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64451

No functional change.

Test Plan:
```
buck test caffe2/test:fx_const_fold
```

Reviewed By: jfix71, RoshanPAN, houseroad

Differential Revision: D30718255

fbshipit-source-id: 95f98561c7f33fcc6c839db68683c85eb152c949

test/fx/test_fx_const_fold.py
torch/fx/experimental/const_fold.py

index 1da8109..a05b9ae 100644 (file)
@@ -105,6 +105,37 @@ class TestConstFold(unittest.TestCase):
         fold_result = mod_folded(in_x, in_y)
         self.assertTrue(torch.equal(fold_result, base_result))
 
+    def test_const_fold_basic_placeholder_reordered(self):
+        """
+        Test code path where placeholder comes after normal op node in FX
+        """
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * 2 + y
+
+        mod = ConstFoldTestModule()
+        mod = torch.fx.symbolic_trace(mod)
+        yy = None
+        for n in mod.graph.nodes:
+            if n.op == "placeholder" and n.target == "y":
+                yy = n
+            elif yy is not None and n.op == "call_function":
+                yy.prepend(n)
+                break
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
+
+        self.assertTrue(mod_folded.const_subgraph_module is None)
+        # Now run both folded and non-folded to check results equal.
+        in_x = torch.tensor([[-0.45]])
+        in_y = torch.tensor([[0.45]])
+        base_result = mod(in_x, in_y)
+        fold_result = mod_folded(in_x, in_y)
+        self.assertTrue(torch.equal(fold_result, base_result))
+
     def test_const_fold_noop(self):
         r"""
         Check that a graph with no constant folding is handled correctly.
index 2ecf88b..b1446ef 100644 (file)
@@ -197,8 +197,8 @@ def split_const_subgraphs(
     for node in split.submod_1.graph.nodes:
         if node.op != "placeholder":
             continue
-        is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name.keys()
-        is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name.keys()
+        is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name
+        is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name
         if not is_folded_attr and not is_unfolded_attr:
             ph_idx += 1
             continue
@@ -208,6 +208,7 @@ def split_const_subgraphs(
             if is_folded_attr
             else submod_1_input_idx_to_unfolded_attr_name[ph_idx]
         )
+
         if is_folded_attr:
             assert not hasattr(mod_traced, const_output_name)
             # Use a dummy param, which will be overwritten when we run const folding.