add autowrap_functions kwarg to fx.Tracer (#62106)
authorAlexander Soare <alexander.soare159@gmail.com>
Fri, 13 Aug 2021 00:35:02 +0000 (17:35 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 00:38:25 +0000 (17:38 -0700)
Summary:
Implements feature request https://github.com/pytorch/pytorch/issues/62021

Test it out with

```python
from torch import fx
from torch import nn

def fx_int(x):
    return int(x)

class MyModule(nn.Module):
    def forward(self, x):
        return fx_int(x.shape[0] / 2)

tracer = fx.Tracer(autowrap_functions=(fx_int,))  # or remove kwarg to demonstrate symbolic trace error
tracer.trace(MyModule())
```

First time contributor, so please advise if I could have done anything to make lives easier for next time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62106

Reviewed By: SplitInfinity, driazati

Differential Revision: D30080834

Pulled By: jamesr66a

fbshipit-source-id: 68fadf8c881ea7930e7afd62b642874010fe4903

test/test_fx.py
torch/fx/_symbolic_trace.py

index af6a5b3..2573572 100644 (file)
@@ -68,11 +68,17 @@ class SimpleTest(torch.nn.Module):
 def a_non_torch_leaf(a, b):
     return a + b
 
+# Used for test_autowrap_function. Autowrapped functions need to be global
+def fx_int(x: float) -> int:
+    return int(x)
+
+def fx_int_x2(x: float) -> int:
+    return int(x) * 2
+
 # used in test_pytree. It's all the way out here because pickling a GraphModule
 # that uses Point errors out if Point is local to the function
 Point = namedtuple('Point', ['x', 'y'])
 
-
 # Test wrap() passing both a function name as well as a function
 # directly
 def a_lifted_leaf(a, b):
@@ -857,6 +863,27 @@ class TestFX(JitTestCase):
         traced = torch.fx.symbolic_trace(IHaveATensorConstant())
         torch.jit.script(traced)
 
+    def test_autowrap_functions(self):
+        class AutowrapFnTest(torch.nn.Module):
+            def forward(self, x):
+                return fx_int(x.shape[0] / 2)
+
+        class AutowrapFnTest2(torch.nn.Module):
+            def forward(self, x):
+                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
+
+        # Check function(s) are wrapped
+        # `int` would normally throw a TypeError as argument can't be `Proxy`
+        tracer = Tracer(autowrap_functions=(fx_int,))
+        graph = tracer.trace(AutowrapFnTest())
+        traced = GraphModule(tracer.root, graph, 'test')
+        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
+        tracer_2.trace(AutowrapFnTest2())
+
+        # Test scriptability
+        traced_scripted = torch.jit.script(traced)
+        self.assertEqual(traced_scripted(torch.rand(4)), 2)
+
     def test_torch_fx_len(self):
         class FXLenTest(torch.nn.Module):
             def forward(self, x):
index 3098bea..25f739e 100644 (file)
@@ -172,7 +172,7 @@ class Tracer(TracerBase):
     # documentation. We need it so that Sphinx doesn't leak `math`s path from the
     # build environment (e.g. `<module 'math' from '/leaked/path').
 
-    """Tracer(autowrap_modules=(math,), enable_cpatching=False)
+    """Tracer(autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False)
 
     ``Tracer`` is the class that implements the symbolic tracing functionality
     of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
@@ -183,6 +183,7 @@ class Tracer(TracerBase):
     in the docstrings of the methods on this class.
     """
     def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ),
+                 autowrap_functions: Tuple[Callable, ...] = (),
                  enable_cpatching: bool = False,
                  param_shapes_constant: bool = False) -> None:
         # This method's signature is overridden by the first line of this class'
@@ -194,10 +195,14 @@ class Tracer(TracerBase):
 
         Args:
 
-            autowrap_modules (List[ModuleType]): defaults to `[math]`,
+            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                 Python modules whose functions should be wrapped automatically
                 without needing to use fx.wrap().
 
+            autowrap_function (Tuple[Callable, ...]): defaults to `()`,
+                Python functions that should be wrapped automatically without
+                needing to use fx.wrap().
+
             enable_cpatching (bool): defaults to `False`,
                 Allows you to enable/disable monkeypatching of torch functions at the
                 C-level (which captures functins like randn).
@@ -220,6 +225,7 @@ class Tracer(TracerBase):
         self._autowrap_function_ids: Set[int] = {
             id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)}
+        self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
 
         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing