From: Jordan Fix Date: Fri, 10 Sep 2021 06:49:22 +0000 (-0700) Subject: [const_fold] Keep around node.meta for replaced folded ops (#64782) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~313 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=be091950d0b1407ef1beab55646c096ea6e75be0;p=platform%2Fupstream%2Fpytorch.git [const_fold] Keep around node.meta for replaced folded ops (#64782) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64782 Previously, get_attrs that were added to the graph did not retain node.meta after folding. Add such support, and improve coverage in general here. Test Plan: Added test coverage. Reviewed By: protonu Differential Revision: D30852704 fbshipit-source-id: ece87a61c69b2e68982964c6adc4dde14dae12c7 --- diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py index a05b9ae..67f795b 100644 --- a/test/fx/test_fx_const_fold.py +++ b/test/fx/test_fx_const_fold.py @@ -1,6 +1,8 @@ +import operator import unittest import torch +import torch.fx from torch.fx.experimental import const_fold @@ -337,3 +339,56 @@ class TestConstFold(unittest.TestCase): fold_result = mod_folded(in_x) base_result = mod(in_x) self.assertTrue(torch.equal(fold_result, base_result)) + + def test_retain_node_meta(self): + r""" + Perform constant folding conversion, and validate that node meta is retained. + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr = torch.nn.Parameter(torch.randn(2, 3)) + + def forward(self, x): + a = self.attr + self.attr + return x - a + + mod = ConstFoldTestModule() + gm = torch.fx.symbolic_trace(mod) + + # Add a count for each node to check after we const fold. + for idx, node in enumerate(gm.graph.nodes): + if node.op != "output": + node.meta["meta_idx"] = idx + + # Pre-folding: + # idx 0: placeholder + # idx 1: get_attr (will no longer be used, hence removed) + # idx 2: add (will be folded into a get_attr) + # idx 3: sub + + gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) + self._verify_const_fold_mod(gm_folded) + + # Post-folding: + # idx 0: placeholder + # idx 2: get_attr (replaced original add; original get_attr was removed) + # idx 3: sub + + # Check the expected indices are still here. + for node in gm_folded.graph.nodes: + if node.op == "placeholder": + self.assertEqual(node.meta["meta_idx"], 0) + elif node.op == "get_attr": + self.assertEqual(node.meta["meta_idx"], 2) + elif node.op == "call_function" and node.target == operator.sub: + self.assertEqual(node.meta["meta_idx"], 3) + else: + self.assertEqual(node.op, "output") + + # Now run both folded and non-folded to check results equal. + in_x = torch.randn(2, 3) + fold_result = gm_folded(in_x) + base_result = mod(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index b1446ef..8469329 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -1,9 +1,9 @@ import operator +import re from typing import Dict, Set, List, Optional, Union import torch.fx from torch.fx.passes.split_module import split_module -import re def _make_tuple(x): @@ -218,7 +218,9 @@ def split_const_subgraphs( torch.nn.Parameter(torch.randn(1)), ) with split.submod_1.graph.inserting_before(node): - node.replace_all_uses_with(split.submod_1.graph.get_attr(const_output_name)) + new_node = split.submod_1.graph.get_attr(const_output_name) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) split.submod_1.graph.erase_node(node) ph_idx += 1