def a_non_torch_leaf(a, b):
    return a + b
+# Used for test_autowrap_functions. Autowrapped functions need to be global
+def fx_int(x: float) -> int:
+    return int(x)
+
+def fx_int_x2(x: float) -> int:
+    return int(x) * 2
+
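+# (a note on why: the tracer's autowrap pass matches candidate functions by
+# id when scanning module/frame globals, so a locally-defined function would
+# never be found -- explanatory assumption, not asserted by these tests)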
# used in test_pytree. It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])
-
# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)
+    def test_autowrap_functions(self):
+        class AutowrapFnTest(torch.nn.Module):
+            def forward(self, x):
+                return fx_int(x.shape[0] / 2)
+
+        class AutowrapFnTest2(torch.nn.Module):
+            def forward(self, x):
+                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
+
+        # Check that the functions are wrapped. Without wrapping, `int`
+        # would throw a TypeError because its argument would be a `Proxy`.
+        tracer = Tracer(autowrap_functions=(fx_int,))
+        graph = tracer.trace(AutowrapFnTest())
+        traced = GraphModule(tracer.root, graph, 'test')
+        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
+        tracer_2.trace(AutowrapFnTest2())
+
+        # Test scriptability
+        traced_scripted = torch.jit.script(traced)
+        self.assertEqual(traced_scripted(torch.rand(4)), 2)
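+
+        # Illustrative sketch (an assumption about the mechanism, not
+        # asserted by this test): an autowrapped function is recorded as a
+        # single `call_function` node instead of being traced through, which
+        # is why `int` never sees a `Proxy`. Roughly:
+        #
+        #     assert any(node.op == 'call_function' and node.target is fx_int
+        #                for node in graph.nodes)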
+
    def test_torch_fx_len(self):
        class FXLenTest(torch.nn.Module):
            def forward(self, x):
    # documentation. We need it so that Sphinx doesn't leak `math`'s path from the
    # build environment (e.g. `<module 'math' from '/leaked/path'>`).
- """Tracer(autowrap_modules=(math,), enable_cpatching=False)
+ """Tracer(autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False)
    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is
    equivalent to ``Tracer().trace(m)``; the behaviors that can be overridden
    are described in the docstrings of the methods on this class.
    """
    def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ),
+                 autowrap_functions: Tuple[Callable, ...] = (),
                 enable_cpatching: bool = False,
                 param_shapes_constant: bool = False) -> None:
        # This method's signature is overridden by the first line of this
        # class' docstring.
        Args:
-            autowrap_modules (List[ModuleType]): defaults to `[math]`,
+            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                Python modules whose functions should be wrapped automatically
                without needing to use fx.wrap().
+            autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
+                Python functions that should be wrapped automatically without
+                needing to use fx.wrap().
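+                For example, `Tracer(autowrap_functions=(my_leaf_fn,))`
+                records calls to `my_leaf_fn` as `call_function` nodes
+                rather than tracing into its body (`my_leaf_fn` is a
+                placeholder name used only for illustration).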
+
            enable_cpatching (bool): defaults to `False`,
                Allows you to enable/disable monkeypatching of torch functions at the
                C-level (which captures functions like randn).
        self._autowrap_function_ids: Set[int] = {
            id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
            if not name.startswith("_") and callable(value)}
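+        # Functions passed via `autowrap_functions` are added to the same
+        # id-based lookup set as the `autowrap_modules` candidates above.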
+        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
        # Python modules to apply autowrap to at the start, in addition to
        # modules we see while tracing