correlate forward and backward op (#62553)
authorTeng Gao <tegao@microsoft.com>
Tue, 21 Sep 2021 13:38:37 +0000 (06:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 21 Sep 2021 14:28:29 +0000 (07:28 -0700)
Summary:
Use startThreadId+seqNumber of forward-op and fwdThreadId+seqNumber of backward-op to correlate pair of them.
third_party/kineto should be updated accordingly: https://github.com/pytorch/kineto/pull/372

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62553

Reviewed By: malfet

Differential Revision: D30125728

Pulled By: gdankel

fbshipit-source-id: 9877a54392ba043d0eac56ce5b7bbf244277fa7e

test/test_profiler.py
torch/csrc/autograd/profiler_kineto.cpp

index 25695a8..8b9428e 100644 (file)
@@ -709,5 +709,42 @@ class TestProfiler(TestCase):
         if kineto_available():
             self._test_profiler_tracing(True)
 
+    def test_profiler_fwd_bwd_link(self):
+        with _profile(use_kineto=True) as prof:
+            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
+            z = torch.add(t1, t2)
+            y = torch.ones(1)
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
+            loss.backward()
+        with TemporaryFileName(mode="w+") as fname:
+            prof.export_chrome_trace(fname)
+            with io.open(fname, 'r') as f:
+                j = json.load(f)
+                events = j["traceEvents"]
+                ts_to_name = {}
+                flow_s_to_ts = {}
+                flow_f_to_ts = {}
+                for e in events:
+                    if e["ph"] == "X":
+                        ts_to_name[e["ts"]] = e["name"]
+                    if "cat" in e and "name" in e and e["cat"] == "forward_backward" and e["name"] == "fwd_bwd":
+                        if e["ph"] == "s":
+                            flow_s_to_ts[e["id"]] = e["ts"]
+                        elif e["ph"] == "f":
+                            flow_f_to_ts[e["id"]] = e["ts"]
+                self.assertTrue(len(flow_s_to_ts) == 2)
+                self.assertTrue(len(flow_f_to_ts) == 2)
+                self.assertTrue(1 in flow_s_to_ts.keys())
+                self.assertTrue(1 in flow_f_to_ts.keys())
+                self.assertTrue(2 in flow_s_to_ts.keys())
+                self.assertTrue(2 in flow_f_to_ts.keys())
+                s_ts_1 = flow_s_to_ts[1]
+                f_ts_1 = flow_f_to_ts[1]
+                s_ts_2 = flow_s_to_ts[2]
+                f_ts_2 = flow_f_to_ts[2]
+                self.assertTrue(all([ts in ts_to_name.keys() for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]]))
+                self.assertTrue(ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits")
+                self.assertTrue(ts_to_name[s_ts_2] == "aten::add")
+
 if __name__ == '__main__':
     run_tests()
index ce96f84..0d63881 100644 (file)
@@ -47,6 +47,11 @@ std::string stacksToStr(const std::vector<std::string>& stacks, const char* deli
 std::string dtypesToStr(const std::vector<std::string>& types);
 std::vector<std::string> inputTypes(const at::RecordFunction& fn);
 
+// Assumption: Total threads number will not exceed 2^16-1, and total ops will not exceed 2^48 -1.
+static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
+  return (((tid) << 48) | ((seqNr) & (((uint64_t)1 << 48) - 1)));
+}
+
 struct KinetoThreadLocalState : public ProfilerThreadLocalState {
   explicit KinetoThreadLocalState(const ProfilerConfig& config)
     : ProfilerThreadLocalState(config) {
@@ -232,6 +237,11 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
 
   void finalizeCPUTrace() {
     TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
+    // startThreadId_seqNum to pointer of activity.
+    // Low-16bits of startThreadId and low-48bits seqNum are concatenated into one uint64_t variable as key.
+    std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> tidSeq2activity;
+    uint64_t fwd_bwd_link_id = 1;
+
     for (size_t idx = 0; idx < cpu_trace->activities.size(); ++idx) {
       auto& kineto_event = kineto_events_[idx];
       auto& activity = cpu_trace->activities[idx];
@@ -258,6 +268,43 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
         activity.addMetadata(
             "Sequence number",
             std::to_string(kineto_event.sequenceNr()));
+        generateForwardBackwardLink(kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
+      }
+    }
+  }
+
+  void generateForwardBackwardLink(const KinetoEvent &kineto_event,
+    uint64_t &fwd_bwd_link_id,
+    libkineto::GenericTraceActivity &activity,
+    std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> &tidSeq2activity) {
+    if (kineto_event.fwdThreadId() > 0) {
+      // act is backward op.
+      uint64_t key = getForwardThreadKey(kineto_event.fwdThreadId(), kineto_event.sequenceNr());
+      auto iter = tidSeq2activity.find(key);
+      if (iter != tidSeq2activity.end()) {
+        libkineto::GenericTraceActivity* fwd = iter->second;
+        activity.flow.linkedActivity = fwd; // Only destination side set this, to distinguish with start side.
+        activity.flow.id = fwd->flow.id = fwd_bwd_link_id;
+        activity.flow.type = fwd->flow.type = libkineto::kLinkFwdBwd;
+        ++fwd_bwd_link_id;
+      }
+    }
+    else if (kineto_event.startThreadId() != 0) {
+      // act is forward op.
+      uint64_t key = getForwardThreadKey(kineto_event.startThreadId(), kineto_event.sequenceNr());
+      // Assumption: Among all ops with same sequence number,
+      // the one with biggest start time is most likely launching backward op.
+      auto iter = tidSeq2activity.find(key);
+      if (iter == tidSeq2activity.end()) {
+        tidSeq2activity[key] = &activity;
+      }
+      else {
+        // Now the sequence number is only incremented on creating a "Node" object for backward pass,
+        // by calling "at::sequence_number::get_and_increment()".
+        // Among all ops with same sequence number, the one with biggest startTime is the one launching backward op.
+        if (activity.startTime >= iter->second->startTime) {
+          tidSeq2activity[key] = &activity;
+        }
       }
     }
   }