        fold_result = mod_folded(in_x, in_y)
        base_result = mod(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))
+
+    def test_const_fold_submod_hierarchy(self):
+        r"""
+        Perform constant folding conversion, from original mod to split constant
+        folding module where one of the folded attrs comes from a submod deeper
+        in the hierarchy of the base module.
+        """
+
+        class TracedThroughModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.internal_attr = torch.nn.Parameter(torch.randn(2, 3))
+
+            def forward(self):
+                return self.internal_attr
+
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.my_mod = TracedThroughModule()
+                self.attr = torch.nn.Parameter(torch.randn(2, 3))
+
+            def forward(self, x):
+                return self.attr + self.my_mod() + x
+
+        mod = ConstFoldTestModule()
+        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
+        self._verify_const_fold_mod(mod_folded)
+
+        # Now run both folded and non-folded to check results equal.
+        in_x = torch.randn(2, 3)
+        fold_result = mod_folded(in_x)
+        base_result = mod(in_x)
+        self.assertTrue(torch.equal(fold_result, base_result))
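For readers skimming the test, here is a rough sketch of the run-time contract being exercised. This is a simplified illustration based on the docstring in the hunk below, not the actual FoldedGraphModule implementation: the folded constant subgraph runs once, its results are cached, and the non-constant remainder of the graph consumes them on every call.

# Simplified sketch of the FoldedGraphModule contract (illustration only; the
# real class lives in torch.fx.experimental.const_fold):
class FoldedSketch:
    def __init__(self, const_fn, main_fn):
        self.const_fn = const_fn  # constant subgraph (takes no runtime inputs)
        self.main_fn = main_fn    # non-constant remainder of the graph
        self._folded = None       # lazily-computed constant outputs

    def __call__(self, x):
        if self._folded is None:
            # Run the constant subgraph once and cache its outputs.
            self._folded = self.const_fn()
        return self.main_fn(self._folded, x)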
import operator
-from typing import Dict, Set, List, Optional
+from typing import Dict, Set, List, Optional, Union
import torch.fx
from torch.fx.passes.split_module import split_module
def split_const_subgraphs(
-    module: torch.nn.Module,
+    module: Union[torch.nn.Module, torch.fx.GraphModule],
) -> FoldedGraphModule:
"""
Looks through `module` for any nodes that have all constant attribute inputs
attributes on the module prior to running the non-constant portion of the
graph.
"""
-    mod_traced = torch.fx.symbolic_trace(module)
+    # Skip tracing if the caller already passed in a GraphModule.
+    if isinstance(module, torch.fx.GraphModule):
+        mod_traced = module
+    else:
+        mod_traced = torch.fx.symbolic_trace(module)
    # Build up a list of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    # Replace submod_0's placeholders with get_attrs: we cannot call into
    # submod_0 without somehow a priori knowing the attrs that should be passed
    # as args. We can unconditionally do this for all placeholders because we
    # know all placeholders to submod_0 must be constants accessible via
    # get_attr.
-    for node in split.submod_0.graph.nodes:
+    # Note: wrap split.submod_0.graph in a new GraphModule with `split` as its
+    # root, because the folded attributes are fetched directly from the root
+    # module instead of from split.submod_0 (which may not contain the submods
+    # that the dotted get_attr targets refer to).
+    root_submod_0 = torch.fx.GraphModule(split, split.submod_0.graph)
+    for node in root_submod_0.graph.nodes:
        if node.op != "placeholder":
            continue
        in_node = next(n for n in call_submod_0_args if n.name == node.target)
        assert in_node.op == "get_attr"
-        with split.submod_0.graph.inserting_before(node):
-            node.replace_all_uses_with(split.submod_0.graph.get_attr(in_node.target))
-        split.submod_0.graph.erase_node(node)
+        with root_submod_0.graph.inserting_before(node):
+            node.replace_all_uses_with(root_submod_0.graph.get_attr(in_node.target))
+        root_submod_0.graph.erase_node(node)
    return FoldedGraphModule(
-        mod_traced, split.submod_1.graph, split.submod_0.graph, const_output_names
+        mod_traced, split.submod_1.graph, root_submod_0.graph, const_output_names
    )
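Why the root-module wrapping matters, as a standalone sketch: a get_attr target such as "my_mod.internal_attr" resolves against the module that owns the graph, so it can only be fetched from a root that actually contains the `my_mod` child. The module and attribute names below are borrowed from the test above; this is an illustration, not part of the patch.

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.internal_attr = torch.nn.Parameter(torch.randn(2, 3))

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.my_mod = Inner()

root = Outer()

# Dotted targets resolve relative to the owning module: fine from the root.
print(root.get_parameter("my_mod.internal_attr").shape)  # torch.Size([2, 3])

# A module without a `my_mod` child (like split.submod_0 in the patch) cannot
# resolve the same target.
try:
    torch.nn.Module().get_parameter("my_mod.internal_attr")
except AttributeError as err:
    print("fails as expected:", err)

A side effect of the Union in the new signature: an already-traced GraphModule is accepted as-is, so split_const_subgraphs(torch.fx.symbolic_trace(mod)) no longer re-traces inside the function.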