if kineto_available():
self._test_profiler_tracing(True)
+ def test_profiler_fwd_bwd_link(self):
+ with _profile(use_kineto=True) as prof:
+ t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
+ z = torch.add(t1, t2)
+ y = torch.ones(1)
+ loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
+ loss.backward()
+ with TemporaryFileName(mode="w+") as fname:
+ prof.export_chrome_trace(fname)
+ with io.open(fname, 'r') as f:
+ j = json.load(f)
+ events = j["traceEvents"]
+ ts_to_name = {}
+ flow_s_to_ts = {}
+ flow_f_to_ts = {}
+ for e in events:
+ if e["ph"] == "X":
+ ts_to_name[e["ts"]] = e["name"]
+ if "cat" in e and "name" in e and e["cat"] == "forward_backward" and e["name"] == "fwd_bwd":
+ if e["ph"] == "s":
+ flow_s_to_ts[e["id"]] = e["ts"]
+ elif e["ph"] == "f":
+ flow_f_to_ts[e["id"]] = e["ts"]
+ self.assertTrue(len(flow_s_to_ts) == 2)
+ self.assertTrue(len(flow_f_to_ts) == 2)
+ self.assertTrue(1 in flow_s_to_ts.keys())
+ self.assertTrue(1 in flow_f_to_ts.keys())
+ self.assertTrue(2 in flow_s_to_ts.keys())
+ self.assertTrue(2 in flow_f_to_ts.keys())
+ s_ts_1 = flow_s_to_ts[1]
+ f_ts_1 = flow_f_to_ts[1]
+ s_ts_2 = flow_s_to_ts[2]
+ f_ts_2 = flow_f_to_ts[2]
+ self.assertTrue(all([ts in ts_to_name.keys() for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]]))
+ self.assertTrue(ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits")
+ self.assertTrue(ts_to_name[s_ts_2] == "aten::add")
+
if __name__ == '__main__':
run_tests()
std::string dtypesToStr(const std::vector<std::string>& types);
std::vector<std::string> inputTypes(const at::RecordFunction& fn);
+// Assumption: Total threads number will not exceed 2^16-1, and total ops will not exceed 2^48 -1.
+static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
+ return (((tid) << 48) | ((seqNr) & (((uint64_t)1 << 48) - 1)));
+}
+
struct KinetoThreadLocalState : public ProfilerThreadLocalState {
explicit KinetoThreadLocalState(const ProfilerConfig& config)
: ProfilerThreadLocalState(config) {
void finalizeCPUTrace() {
TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
+ // startThreadId_seqNum to pointer of activity.
+ // Low-16bits of startThreadId and low-48bits seqNum are concatenated into one uint64_t variable as key.
+ std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> tidSeq2activity;
+ uint64_t fwd_bwd_link_id = 1;
+
for (size_t idx = 0; idx < cpu_trace->activities.size(); ++idx) {
auto& kineto_event = kineto_events_[idx];
auto& activity = cpu_trace->activities[idx];
activity.addMetadata(
"Sequence number",
std::to_string(kineto_event.sequenceNr()));
+ generateForwardBackwardLink(kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
+ }
+ }
+ }
+
+ void generateForwardBackwardLink(const KinetoEvent &kineto_event,
+ uint64_t &fwd_bwd_link_id,
+ libkineto::GenericTraceActivity &activity,
+ std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> &tidSeq2activity) {
+ if (kineto_event.fwdThreadId() > 0) {
+ // act is backward op.
+ uint64_t key = getForwardThreadKey(kineto_event.fwdThreadId(), kineto_event.sequenceNr());
+ auto iter = tidSeq2activity.find(key);
+ if (iter != tidSeq2activity.end()) {
+ libkineto::GenericTraceActivity* fwd = iter->second;
+ activity.flow.linkedActivity = fwd; // Only destination side set this, to distinguish with start side.
+ activity.flow.id = fwd->flow.id = fwd_bwd_link_id;
+ activity.flow.type = fwd->flow.type = libkineto::kLinkFwdBwd;
+ ++fwd_bwd_link_id;
+ }
+ }
+ else if (kineto_event.startThreadId() != 0) {
+ // act is forward op.
+ uint64_t key = getForwardThreadKey(kineto_event.startThreadId(), kineto_event.sequenceNr());
+ // Assumption: Among all ops with same sequence number,
+ // the one with biggest start time is most likely launching backward op.
+ auto iter = tidSeq2activity.find(key);
+ if (iter == tidSeq2activity.end()) {
+ tidSeq2activity[key] = &activity;
+ }
+ else {
+ // Now the sequence number is only incremented on creating a "Node" object for backward pass,
+ // by calling "at::sequence_number::get_and_increment()".
+ // Among all ops with same sequence number, the one with biggest startTime is the one launching backward op.
+ if (activity.startTime >= iter->second->startTime) {
+ tidSeq2activity[key] = &activity;
+ }
}
}
}