        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))
+    def test_const_fold_basic_placeholder_reordered(self):
+        """
+        Test code path where placeholder comes after normal op node in FX
+        """
+
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * 2 + y
+
+        mod = ConstFoldTestModule()
+        mod = torch.fx.symbolic_trace(mod)
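+        # Move the first call_function node (the mul) in front of the "y"
+        # placeholder, so the graph has a placeholder after a normal op node.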
+        yy = None
+        for n in mod.graph.nodes:
+            if n.op == "placeholder" and n.target == "y":
+                yy = n
+            elif yy is not None and n.op == "call_function":
+                yy.prepend(n)
+                break
+
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
+
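+        # Nothing here is fully constant (x and y are both runtime inputs),
+        # so no const subgraph should have been split out.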
+        self.assertTrue(mod_folded.const_subgraph_module is None)
+        # Now run both folded and non-folded to check results equal.
+        in_x = torch.tensor([[-0.45]])
+        in_y = torch.tensor([[0.45]])
+        base_result = mod(in_x, in_y)
+        fold_result = mod_folded(in_x, in_y)
+        self.assertTrue(torch.equal(fold_result, base_result))
+
    def test_const_fold_noop(self):
        r"""
        Check that a graph with no constant folding is handled correctly.
        for node in split.submod_1.graph.nodes:
            if node.op != "placeholder":
                continue
-            is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name.keys()
-            is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name.keys()
+            is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name
+            is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name
            if not is_folded_attr and not is_unfolded_attr:
                ph_idx += 1
                continue

            const_output_name = (
                submod_1_input_idx_to_folded_attr_name[ph_idx]
                if is_folded_attr
                else submod_1_input_idx_to_unfolded_attr_name[ph_idx]
            )
+
            if is_folded_attr:
                assert not hasattr(mod_traced, const_output_name)
                # Use a dummy param, which will be overwritten when we run const folding.