[Pytorch Profiler] Add debug_handles to KinetoEvent (#62228)
authorKimish Patel <kimishpatel@fb.com>
Sat, 14 Aug 2021 04:37:57 +0000 (21:37 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 14 Aug 2021 04:40:14 +0000 (21:40 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62228

This diff adds debug handles to events and provides a way to use
RECORD_FUNCTIONs that will pass debug_handles down to profiler, which
will record it in the events.

Why add debug_handles?
For pytorch mobile, with lite interpreter, we generate debug handles
that can be used for lazily symbolicate exception traces to model level
stack trace. Similar to the model level stack trace you get in
TorchScript models. The debug_handles also enable getting module
hierarchy for lite interpreter model, support for which was added to
KinetoProfiler in previous diffs.

Followup plan:
1. Enabled scope callbacks such that lite interpreter can use it to
profiler only top level ops.
2. Enable post processing callbacks that take KinetoEvents and populate
module hierarchy using debug handles.

This will let us use KinetoProfiler for lite interpter use cases on
mobile. Aim is to use RAII guard to similarly generate chrome trace for
mobile usecases as well, although only for top level ops.

Test Plan:
test_misc : RecordDebugHandles.Basic

Imported from OSS

Reviewed By: ilia-cher

Differential Revision: D29935899

fbshipit-source-id: 4f06dc411b6b5fe0ffaebdd26d3274c96f8f389b

aten/src/ATen/record_function.h
test/cpp/jit/test_misc.cpp
torch/csrc/autograd/profiler_kineto.cpp
torch/csrc/autograd/profiler_kineto.h

index 2e1a8f1..80c3ca9 100644 (file)
@@ -265,6 +265,16 @@ struct TORCH_API RecordFunction {
     return state_->needs_outputs;
   }
 
+  int64_t debugHandle() const {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called debugHandle() on inactive RecordFunction");
+    return state_->debug_handle_;
+  }
+
+  void setDebugHandle(int64_t debug_handle) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setDebugHandle() on inactive RecordFunction");
+    state_->debug_handle_ = debug_handle;
+  }
+
  private:
 
   // Allows the modification of some internal states for callbacks.
@@ -326,6 +336,12 @@ struct TORCH_API RecordFunction {
     // events can complete in different threads or follow a future-like pattern
     // of use.
     bool is_async_{false};
+
+    // Debug handles are used for lazy annotation of module hierarchy
+    // and callstack.
+    // This is specifically is useful for mobile runtime, where generated
+    // debug handles can be lazily symbolicated using debug information
+    int64_t debug_handle_{-1};
   };
 
   std::unique_ptr<State> state_;
@@ -472,6 +488,26 @@ class TORCH_API RecordFunctionCallback {
   RECORD_FUNCTION_WITH_SCOPE( \
     at::RecordScope::USER_SCOPE, fn, inputs)
 
+// Helper macro to pass in debug handle that is used to
+// post process events
+#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(  \
+    scope, fn, debug_handle, inputs, ...)           \
+    at::RecordFunction guard(scope);                \
+    if (guard.isActive()) {                         \
+      guard.setDebugHandle(debug_handle);           \
+      if (guard.needsInputs()) {                    \
+        guard.before(fn, inputs, ##__VA_ARGS__);    \
+      } else {                                      \
+        guard.before(fn, ##__VA_ARGS__);            \
+      }                                             \
+    }
+
+// Helper macros to record user_scope events with debug handles
+#define RECORD_USER_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(         \
+    fn, debug_handle, inputs)                                   \
+    RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(                  \
+        at::RecordScope::USER_SCOPE, fn, debug_handle, inputs)
+
 // Notes:
 //  - two types of callbacks are provided: thread local and global
 //     - thread local callbacks are added/removed only for the given thread
index 8d59edc..5ee8816 100644 (file)
@@ -2468,6 +2468,43 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
 }
 
+TEST(RecordDebugHandles, Basic) {
+  // Enable the profiler in this thread
+  const std::set<torch::autograd::profiler::ActivityType> activities(
+      {torch::autograd::profiler::ActivityType::CPU});
+  torch::autograd::profiler::prepareProfiler(
+      torch::autograd::profiler::ProfilerConfig(
+          torch::autograd::profiler::ProfilerState::KINETO, false, false),
+      activities);
+  torch::autograd::profiler::enableProfiler(
+      torch::autograd::profiler::ProfilerConfig(
+          torch::autograd::profiler::ProfilerState::KINETO, false, false),
+      activities);
+  {
+    RECORD_USER_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
+    float x{5.9999}, y{2.1212};
+    float z = x / y;
+  }
+  {
+    RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
+    float x{5.9999}, y{2.1212};
+    float z = x / y;
+  }
+  auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
+  const auto& kineto_events = profiler_results_ptr->events();
+  size_t my_events{0};
+  for (const auto& e : kineto_events) {
+    if (e.name() == "my_function") {
+      ASSERT_EQ(e.debugHandle(), 42);
+      my_events++;
+    } else if (e.name() == "not_my_function") {
+      ASSERT_EQ(e.debugHandle(), -1);
+      my_events++;
+    }
+  }
+  ASSERT_EQ(my_events, 2);
+}
+
 TEST(IValueKWargsTest, Basic) {
   const auto text = R"(
     def foo(a : int, b : int, c : int = 4):
index 9995237..e92461a 100644 (file)
@@ -100,7 +100,8 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
           .sequenceNr(ctx->sequenceNr)
           .fwdThreadId(ctx->fwdThreadId)
           .scope(ctx->recFunScope)
-          .setAsync(fn.isAsync());
+          .setAsync(fn.isAsync())
+          .debugHandle(ctx->debug_handle);
       if (ctx->shapes && !ctx->shapes->empty()) {
         kineto_events_.back().shapes(*ctx->shapes);
       }
@@ -306,6 +307,7 @@ void pushProfilingCallbacks() {
           auto ctx_ptr = std::make_unique<KinetoObserverContext>();
           ctx_ptr->correlationId = corr_id;
           ctx_ptr->startThreadId = at::RecordFunction::currentThreadId();
+          ctx_ptr->debug_handle = fn.debugHandle();
 
           if (config.report_input_shapes) {
             ctx_ptr->shapes = inputSizes(fn);
index 8c14ae4..8a878a0 100644 (file)
@@ -44,6 +44,7 @@ struct KinetoObserverContext : public at::ObserverContext {
   c10::optional<std::unordered_map<std::string, c10::IValue>> extraArgs;
   CUDAEventStub cuda_event_start_ = nullptr;
   CUDAEventStub cuda_event_end_ = nullptr;
+  int64_t debug_handle;
 };
 
 struct TORCH_API KinetoEvent {
@@ -162,6 +163,15 @@ struct TORCH_API KinetoEvent {
     return *this;
   }
 
+  KinetoEvent& debugHandle(int64_t debug_handle) {
+    debug_handle_ = debug_handle;
+    return *this;
+  }
+
+  int64_t debugHandle() const {
+    return debug_handle_;
+  }
+
   std::string name() const {
     return name_;
   }
@@ -277,6 +287,7 @@ struct TORCH_API KinetoEvent {
   int64_t device_resource_id_ = 0;
   int64_t nbytes_ = 0;
   bool is_async_{false};
+  int64_t debug_handle_{-1};
 
   CUDAEventStub cuda_event_start_ = nullptr;
   CUDAEventStub cuda_event_end_ = nullptr;