From 219ba6575b682a9b61476da041c2220142d20e3b Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Thu, 12 Aug 2021 17:35:02 -0700
Subject: [PATCH] add autowrap_functions kwarg to fx.Tracer (#62106)

Summary:
Implements feature request https://github.com/pytorch/pytorch/issues/62021

Test it out with

```python
from torch import fx
from torch import nn


def fx_int(x):
    return int(x)


class MyModule(nn.Module):
    def forward(self, x):
        return fx_int(x.shape[0] / 2)


tracer = fx.Tracer(autowrap_functions=(fx_int,))  # or remove kwarg to demonstrate symbolic trace error
tracer.trace(MyModule())
```

First time contributor, so please advise if I could have done anything to make lives easier for next time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62106

Reviewed By: SplitInfinity, driazati

Differential Revision: D30080834

Pulled By: jamesr66a

fbshipit-source-id: 68fadf8c881ea7930e7afd62b642874010fe4903
---
 test/test_fx.py             | 29 ++++++++++++++++++++++++++++-
 torch/fx/_symbolic_trace.py | 10 ++++++++--
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/test/test_fx.py b/test/test_fx.py
index af6a5b3..2573572 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -68,11 +68,17 @@ class SimpleTest(torch.nn.Module):
 def a_non_torch_leaf(a, b):
     return a + b
 
+# Used for test_autowrap_function. Autowrapped functions need to be global
+def fx_int(x: float) -> int:
+    return int(x)
+
+def fx_int_x2(x: float) -> int:
+    return int(x) * 2
+
 # used in test_pytree. It's all the way out here because pickling a GraphModule
 # that uses Point errors out if Point is local to the function
 Point = namedtuple('Point', ['x', 'y'])
 
-
 # Test wrap() passing both a function name as well as a function
 # directly
 def a_lifted_leaf(a, b):
@@ -857,6 +863,27 @@ class TestFX(JitTestCase):
         traced = torch.fx.symbolic_trace(IHaveATensorConstant())
         torch.jit.script(traced)
 
+    def test_autowrap_functions(self):
+        class AutowrapFnTest(torch.nn.Module):
+            def forward(self, x):
+                return fx_int(x.shape[0] / 2)
+
+        class AutowrapFnTest2(torch.nn.Module):
+            def forward(self, x):
+                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
+
+        # Check function(s) are wrapped
+        # `int` would normally throw a TypeError as argument can't be `Proxy`
+        tracer = Tracer(autowrap_functions=(fx_int,))
+        graph = tracer.trace(AutowrapFnTest())
+        traced = GraphModule(tracer.root, graph, 'test')
+        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
+        tracer_2.trace(AutowrapFnTest2())
+
+        # Test scriptability
+        traced_scripted = torch.jit.script(traced)
+        self.assertEqual(traced_scripted(torch.rand(4)), 2)
+
     def test_torch_fx_len(self):
         class FXLenTest(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py
index 3098bea..25f739e 100644
--- a/torch/fx/_symbolic_trace.py
+++ b/torch/fx/_symbolic_trace.py
@@ -172,7 +172,7 @@ class Tracer(TracerBase):
     # documentation. We need it so that Sphinx doesn't leak `math`s path from the
     # build environment (e.g. `<module 'math' from '/leaked/path').
 
-    """Tracer(autowrap_modules=(math,), enable_cpatching=False)
+    """Tracer(autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False)
 
     ``Tracer`` is the class that implements the symbolic tracing functionality
     of ``torch.fx.symbolic_trace``.
@@ -183,6 +183,7 @@ class Tracer(TracerBase):
     in the docstrings of the methods on this class.
     """
     def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ),
+                 autowrap_functions: Tuple[Callable, ...] = (),
                  enable_cpatching: bool = False) -> None:
         # This method's signature is overridden by the first line of this class'
         # docstring. If this method's signature is modified, the signature that
@@ -194,10 +195,14 @@ class Tracer(TracerBase):
 
         Args:
 
-            autowrap_modules (List[ModuleType]): defaults to `[math]`,
+            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                 Python modules whose functions should be wrapped automatically
                 without needing to use fx.wrap().
 
+            autowrap_function (Tuple[Callable, ...]): defaults to `()`,
+                Python functions that should be wrapped automatically without
+                needing to use fx.wrap().
+
             enable_cpatching (bool): defaults to `False`,
                 Allows you to enable/disable monkeypatching of torch functions at the
                 C-level (which captures functins like randn).
@@ -220,6 +225,7 @@ class Tracer(TracerBase):
         self._autowrap_function_ids: Set[int] = {
             id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)}
+        self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
 
         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing
-- 
2.7.4
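For reference, a minimal end-to-end sketch of the new kwarg outside the test harness (assuming this patch is applied; `fx_int` and `MyModule` are the illustrative names from the summary above, and `gm` is just a local variable name):

```python
import torch
from torch import fx, nn


# Autowrapped functions need to be defined at module (global) scope.
def fx_int(x: float) -> int:
    return int(x)


class MyModule(nn.Module):
    def forward(self, x):
        # During tracing, x.shape[0] / 2 is a Proxy; a bare int() call would
        # raise a TypeError, but the autowrapped fx_int is recorded as a
        # call_function node instead of being evaluated eagerly.
        return fx_int(x.shape[0] / 2)


tracer = fx.Tracer(autowrap_functions=(fx_int,))
graph = tracer.trace(MyModule())
gm = fx.GraphModule(tracer.root, graph)
print(gm.graph)                              # fx_int shows up as a call_function node
print(gm(torch.rand(4)))                     # 2
print(torch.jit.script(gm)(torch.rand(4)))   # the wrapped call also survives scripting: 2
```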