[FX] Fix python code generation for wrapped getattr() with default value (#64271)
authorPatrick Hu <patrickhu@fb.com>
Wed, 1 Sep 2021 17:49:39 +0000 (10:49 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 18:30:27 +0000 (11:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64271

Closes #60417

Modified emit_node() in fx/graph.py to generate getattr() call with default value when len(node.args) != 2 instead of accessing the attribute.
Added test_torch_fx_getattr() in test/test_fx.py.

Test Plan:
pytest test/test_fx.py

Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D30671265

fbshipit-source-id: f2db9ea47e0cb247547e200684f715aab006c374

test/test_fx.py
torch/fx/graph.py

index f4e4ab2..5220f67 100644 (file)
@@ -98,6 +98,8 @@ wrap(a_lifted_leaf2)
 
 wrap('len')
 
+wrap('getattr')
+
 @wrap
 def wrapped_via_decorator(a):
     return a + 1
@@ -942,6 +944,14 @@ class TestFX(JitTestCase):
         self.assertEqual(traced2(inp), inp + 3.0)
         self.assertIs(len, builtins.len)
 
+    def test_torch_fx_getattr(self):
+        class FXGetattrTest(torch.nn.Module):
+            def forward(self, x):
+                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
+
+        traced = symbolic_trace(FXGetattrTest())
+        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
+
     def test_sqrt(self):
         class Sqrt1(torch.nn.Module):
             def forward(self, x):
index 29ffc41..65e93d0 100644 (file)
@@ -949,11 +949,13 @@ class Graph:
                     return
                 qualified_name = _get_qualified_name(node.target)
                 global_name = add_global(qualified_name, node.target)
+                # special case for getattr: node.args could be 2-argument or 3-argument
+                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                 if global_name == 'getattr' and \
                    isinstance(node.args, tuple) and \
                    isinstance(node.args[1], str) and \
-                   node.args[1].isidentifier():
-                    # pretty print attribute access
+                   node.args[1].isidentifier() and \
+                   len(node.args) == 2:
                     body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                     return
                 body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')