self.assertEqual(27, traced(2))
self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
+ def test_profiler_ranges_side_effect(self):
+ g = torch.fx.Graph()
+ handle = g.call_function(torch.ops.profiler._record_function_enter, ('test_range',))
+ g.call_function(torch.ops.profiler._record_function_exit, (handle,))
+ g.output(None)
+
+ found_targets = {}
+ for node in g.nodes:
+ if node.op == 'call_function':
+ found_targets.setdefault(node.target)
+ self.assertEqual(
+ found_targets.keys(), [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit])
+
+ g.eliminate_dead_code()
+ found_targets = {}
+ for node in g.nodes:
+ if node.op == 'call_function':
+ found_targets.setdefault(node.target)
+ self.assertEqual(
+ found_targets.keys(), [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit])
+
def test_ast_rewriter_wrapped_via_decorator(self):
class F(torch.nn.Module):
def forward(self, x):
BaseArgumentTypes
]]
-_side_effectful_functions: Set[Callable] = {torch._assert}
+_side_effectful_functions: Set[Callable] = {
+ torch._assert, torch.ops.profiler._record_function_enter,
+ torch.ops.profiler._record_function_exit}
# this is fixed on master, WAR for 1.5
def _find_module_of_method(orig_method: Callable[..., Any]) -> str: