From b12150608e7ee3e29bfd0fed6b97891f97513bf1 Mon Sep 17 00:00:00 2001
From: Yinghai Lu
Date: Wed, 8 Sep 2021 13:50:46 -0700
Subject: [PATCH] [fx] make const fold code more pythonic (#64451)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64451

No functional change.

Test Plan:
```
buck test caffe2/test:fx_const_fold
```

Reviewed By: jfix71, RoshanPAN, houseroad

Differential Revision: D30718255

fbshipit-source-id: 95f98561c7f33fcc6c839db68683c85eb152c949
---
 test/fx/test_fx_const_fold.py       | 31 +++++++++++++++++++++++++++++++
 torch/fx/experimental/const_fold.py |  5 +++--
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py
index 1da8109..a05b9ae 100644
--- a/test/fx/test_fx_const_fold.py
+++ b/test/fx/test_fx_const_fold.py
@@ -105,6 +105,37 @@ class TestConstFold(unittest.TestCase):
         fold_result = mod_folded(in_x, in_y)
         self.assertTrue(torch.equal(fold_result, base_result))
 
+    def test_const_fold_basic_placeholder_reordered(self):
+        """
+        Test code path where a placeholder comes after a normal op node in FX.
+        """
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * 2 + y
+
+        mod = ConstFoldTestModule()
+        mod = torch.fx.symbolic_trace(mod)
+        yy = None
+        for n in mod.graph.nodes:
+            if n.op == "placeholder" and n.target == "y":
+                yy = n
+            elif yy is not None and n.op == "call_function":
+                yy.prepend(n)
+                break
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
+
+        self.assertTrue(mod_folded.const_subgraph_module is None)
+        # Now run both folded and non-folded to check results equal.
+        in_x = torch.tensor([[-0.45]])
+        in_y = torch.tensor([[0.45]])
+        base_result = mod(in_x, in_y)
+        fold_result = mod_folded(in_x, in_y)
+        self.assertTrue(torch.equal(fold_result, base_result))
+
     def test_const_fold_noop(self):
         r"""
         Check that a graph with no constant folding is handled correctly.
diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py
index 2ecf88b..b1446ef 100644
--- a/torch/fx/experimental/const_fold.py
+++ b/torch/fx/experimental/const_fold.py
@@ -197,8 +197,8 @@ def split_const_subgraphs(
     for node in split.submod_1.graph.nodes:
         if node.op != "placeholder":
             continue
-        is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name.keys()
-        is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name.keys()
+        is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name
+        is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name
         if not is_folded_attr and not is_unfolded_attr:
             ph_idx += 1
             continue
@@ -208,6 +208,7 @@
             if is_folded_attr
             else submod_1_input_idx_to_unfolded_attr_name[ph_idx]
         )
+
         if is_folded_attr:
             assert not hasattr(mod_traced, const_output_name)
             # Use a dummy param, which will be overwritten when we run const folding.
-- 
2.7.4
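
Note on the `const_fold.py` change above: it drops the redundant `.keys()` calls, since `key in d` on a dict already tests key membership; `key in d.keys()` does the same work through an extra attribute lookup. A minimal sketch of the equivalence (the dict contents here are illustrative, not taken from the patch):

```python
# `in` on a dict tests key membership directly; `.keys()` adds nothing.
submod_1_input_idx_to_folded_attr_name = {0: "folded_attr_0"}  # hypothetical contents

ph_idx = 0
is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name
assert is_folded_attr == (ph_idx in submod_1_input_idx_to_folded_attr_name.keys())
print(is_folded_attr)  # True
```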