[FX} Add torch.ops.profiler._record_function_{enter,exit} as stateful ops for DCE...
authorJames Reed <jamesreed@fb.com>
Fri, 17 Sep 2021 03:31:03 +0000 (20:31 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 17 Sep 2021 04:31:54 +0000 (21:31 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65180

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D31007115

Pulled By: jamesr66a

fbshipit-source-id: 823b15db712a382a4f2a4fd409983d47bc067150

test/test_fx.py
torch/fx/node.py

index 722f5ce..97baa43 100644 (file)
@@ -2456,6 +2456,27 @@ class TestFX(JitTestCase):
         self.assertEqual(27, traced(2))
         self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
 
+    def test_profiler_ranges_side_effect(self):
+        g = torch.fx.Graph()
+        handle = g.call_function(torch.ops.profiler._record_function_enter, ('test_range',))
+        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
+        g.output(None)
+
+        found_targets = {}
+        for node in g.nodes:
+            if node.op == 'call_function':
+                found_targets.setdefault(node.target)
+        self.assertEqual(
+            found_targets.keys(), [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit])
+
+        g.eliminate_dead_code()
+        found_targets = {}
+        for node in g.nodes:
+            if node.op == 'call_function':
+                found_targets.setdefault(node.target)
+        self.assertEqual(
+            found_targets.keys(), [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit])
+
     def test_ast_rewriter_wrapped_via_decorator(self):
         class F(torch.nn.Module):
             def forward(self, x):
index 61dfba7..7ad224c 100644 (file)
@@ -24,7 +24,9 @@ Argument = Optional[Union[
     BaseArgumentTypes
 ]]
 
-_side_effectful_functions: Set[Callable] = {torch._assert}
+_side_effectful_functions: Set[Callable] = {
+    torch._assert, torch.ops.profiler._record_function_enter,
+    torch.ops.profiler._record_function_exit}
 
 # this is fixed on master, WAR for 1.5
 def _find_module_of_method(orig_method: Callable[..., Any]) -> str: