From c6505cc3837eb903f98163e40fad638a1cfeb502 Mon Sep 17 00:00:00 2001 From: Patrick Hu Date: Wed, 1 Sep 2021 10:49:39 -0700 Subject: [PATCH] [FX] Fix python code generation for wrapped getattr() with default value (#64271) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64271 Closes #60417 Modified emit_node() in fx/graph.py to generate getattr() call with default value when len(node.args) != 2 instead of accessing the attribute. Added test_torch_fx_getattr() in test/test_fx.py. Test Plan: pytest test/test_fx.py Imported from OSS Reviewed By: jamesr66a Differential Revision: D30671265 fbshipit-source-id: f2db9ea47e0cb247547e200684f715aab006c374 --- test/test_fx.py | 10 ++++++++++ torch/fx/graph.py | 6 ++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index f4e4ab20..5220f67 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -98,6 +98,8 @@ wrap(a_lifted_leaf2) wrap('len') +wrap('getattr') + @wrap def wrapped_via_decorator(a): return a + 1 @@ -942,6 +944,14 @@ class TestFX(JitTestCase): self.assertEqual(traced2(inp), inp + 3.0) self.assertIs(len, builtins.len) + def test_torch_fx_getattr(self): + class FXGetattrTest(torch.nn.Module): + def forward(self, x): + return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3])) + + traced = symbolic_trace(FXGetattrTest()) + self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) + def test_sqrt(self): class Sqrt1(torch.nn.Module): def forward(self, x): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 29ffc41..65e93d0 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -949,11 +949,13 @@ class Graph: return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value if global_name == 'getattr' and \ isinstance(node.args, tuple) and \ isinstance(node.args[1], str) and \ - node.args[1].isidentifier(): - # pretty print attribute access + node.args[1].isidentifier() and \ + len(node.args) == 2: body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') return body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') -- 2.7.4