Add a record scope around autograd::engine::evaluate_function (#63619)
authorRohan Varma <rvarm1@fb.com>
Wed, 1 Sep 2021 19:28:23 +0000 (12:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 19:32:30 +0000 (12:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63619

Adds a RECORD_FUNCTION with the function that is being valuate as part
of backwards execution. This has been useful in picking up some operations
in the backwards pass that otherwise would not show up, for example custom cpp
functions that use custom C++ code.
ghstack-source-id: 137041723

Test Plan:
CI

benchmark:
buck run mode/opt //scripts/rvarm1/ddp:bench

Reviewed By: albanD

Differential Revision: D30439492

fbshipit-source-id: 955917770cdf2a2edb0303223ace710b668ba388

test/test_autograd.py
torch/csrc/autograd/engine.cpp

index 364d488..8b3c8bd 100644 (file)
@@ -3005,6 +3005,9 @@ class TestAutograd(TestCase):
         found_bwd_add = found_bwd_sum = False
         found_empty = False
         for e in p.function_events:
+            # Ignore record_function user scope.
+            if "autograd::engine::evaluate_function" in e.name:
+                continue
             if e.name == "aten::add":
                 add_seq_nr = e.sequence_nr
                 self.assertFalse(found_add)
index acd7971..4ea002a 100644 (file)
@@ -419,7 +419,18 @@ auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
           // callbacks.
           GraphTaskGuard guard(local_graph_task);
           NodeGuard ndguard(task.fn_);
-          evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
+          {
+            RECORD_FUNCTION(
+                c10::str(
+                    "autograd::engine::evaluate_function: ",
+                    task.fn_.get()->name()),
+                std::vector<c10::IValue>());
+            evaluate_function(
+                local_graph_task,
+                task.fn_.get(),
+                task.inputs_,
+                local_graph_task->cpu_ready_queue_);
+          }
         } catch (std::exception& e) {
           thread_on_exception(local_graph_task, task.fn_, e);
         }