wrap('len')
+wrap('getattr')
+
@wrap
def wrapped_via_decorator(a):
return a + 1
self.assertEqual(traced2(inp), inp + 3.0)
self.assertIs(len, builtins.len)
+ def test_torch_fx_getattr(self):
+ class FXGetattrTest(torch.nn.Module):
+ def forward(self, x):
+ return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
+
+ traced = symbolic_trace(FXGetattrTest())
+ self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
+
def test_sqrt(self):
class Sqrt1(torch.nn.Module):
def forward(self, x):
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
+ # special case for getattr: node.args could be 2-argument or 3-argument
+ # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
- node.args[1].isidentifier():
- # pretty print attribute access
+ node.args[1].isidentifier() and \
+ len(node.args) == 2:
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
return
body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')