Back out "Revert D30327514: [Pytorch lite predictor] Use KinetoEdgeCPUProfiler for...
authorKimish Patel <kimishpatel@fb.com>
Wed, 1 Sep 2021 19:38:39 +0000 (12:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 20:29:35 +0000 (13:29 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64307

Original commit changeset: 0b2aa7c57d08

Restores original changes.
This diff changes the way operator profiling is done in lite predictor
benchmarking binary.
Instead of using custom callbacks it uses KinetoEdgeCPUProfiler to profile
events and then generate operator level metric from it.
Since KinetoEvents do not contain cpu clock time, now we report only wallclock
time.
This unifies various profiling effort that we have for benchmarking purpose. In
production we will still use observer based mechanism, but the advantage of
using kineto profiler is that we get few other things for free, such as:
chrome trace generation.
operator level memory profiling (to be added)
flop counts (to be added)
Furthermore possible we can use python post processing script to parse chrome
trace and generate output similar to torch.profiler. (To be done)

Furthermore removes some tests from test_lite_interpreter.cpp which were testing module hierarchy in debug info. They should be covered by test_mobile_profiler.cpp.

Test Plan:
aibench run
Model without debug info:
https://www.internalfb.com/intern/aibench/details/219598441154763
Model with debug info and --print_module_info true (see Operator summary has now module hierarchy information).
https://www.internalfb.com/intern/aibench/details/617154236292985

Reviewed By: raziel

Differential Revision: D30680354

fbshipit-source-id: b6ba0d59c510c13d13d9935b1d8051cc82ffa4e9

test/cpp/jit/test_lite_interpreter.cpp
tools/build_variables.bzl
torch/csrc/jit/mobile/debug_info.cpp
torch/csrc/jit/mobile/import.cpp
torch/csrc/jit/mobile/interpreter.cpp
torch/csrc/jit/mobile/module.cpp
torch/csrc/jit/mobile/module.h
torch/csrc/jit/mobile/profiler_edge.cpp
torch/csrc/jit/mobile/profiler_edge.h

index 3bd2bec..26100b3 100644 (file)
@@ -456,144 +456,6 @@ TEST(LiteInterpreterTest, BuiltinFunction) {
 }
 
 #if !defined FB_XPLAT_BUILD
-TEST(LiteInterpreterTest, ModuleInfoBasic) {
-  Module m("M");
-  m.define(R"JIT(
-    def forward(self, x):
-      return 2 * x
-  )JIT");
-
-  std::stringstream ss;
-  m._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::unordered_set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  AT_ASSERT(module_debug_info_set.count("top(M)::<unknown>.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, NotSaveModuleInfo) {
-  Module m("M");
-  m.define(R"JIT(
-    def forward(self, x):
-      return x + 5
-  )JIT");
-
-  std::stringstream ss;
-  m._save_for_mobile(ss);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      AT_ASSERT(
-          module_info.empty() ||
-          (module_info.find("debug_handle") != std::string::npos));
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-}
-
-TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return 2 * x + 5
-  )JIT");
-  Module b("B");
-  b.register_module("A0", a);
-  b.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(x) + 1
-  )JIT");
-
-  std::stringstream ss;
-  b._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(B)::<unknown>.A0(A)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(B)::<unknown>.A0(A)::forward.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return x + 1
-  )JIT");
-  Module b("B");
-  b.define(R"JIT(
-    def forward(self, x):
-      return x + 2
-  )JIT");
-  Module c("C");
-  c.register_module("A0", a);
-  c.register_module("B0", b);
-  c.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(x) + self.B0.forward(x)
-  )JIT");
-
-  std::stringstream ss;
-  c._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.A0(A)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.B0(B)::forward.aten::add"));
-}
-
 TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
   auto runtime_bytecode_version = _get_runtime_bytecode_version();
   AT_ASSERT(
@@ -795,187 +657,6 @@ TEST(LiteInterpreterTest, isCompatibleFail) {
   AT_ASSERT(result.status = ModelCompatibilityStatus::ERROR);
 }
 
-#if !defined FB_XPLAT_BUILD
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST(LiteInterpreterTest, SequentialModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return x + 1
-  )JIT");
-  Module b("B");
-  b.define(R"JIT(
-    def forward(self, x):
-      return x + 2
-  )JIT");
-  Module c("C");
-  c.register_module("A0", a);
-  c.register_module("B0", b);
-  c.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(self.B0.forward(x))
-  )JIT");
-
-  std::stringstream ss;
-  c._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  // class A(nn.Module):
-  //   def __init__(self):
-  //     super(A, self).__init__()
-
-  //   def forward(self, x):
-  //     return x + 1
-
-  // class B(nn.Module):
-  //   def __init__(self):
-  //     super(B, self).__init__()
-
-  //   def forward(self, x):
-  //     return x + 2
-
-  // class C(nn.Module):
-  //   def __init__(self):
-  //     super(C, self).__init__()
-  //     self.A0 = A()
-  //     self.B0 = B()
-
-  //   def forward(self, x):
-  //     return self.A0.forward(self.B0.forward(x))
-
-  AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.prim::Return"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.A0(A)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.B0(B)::forward.aten::add"));
-}
-
-TEST(LiteInterpreterTest, HierarchyModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return x + 1
-  )JIT");
-  Module b("B");
-  b.register_module("A0", a);
-  b.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(x) + 1
-  )JIT");
-  Module c("C");
-  c.register_module("B0", b);
-  c.define(R"JIT(
-    def forward(self, x):
-      return self.B0.forward(x) + 1
-  )JIT");
-
-  std::stringstream ss;
-  c._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  // There are 3 module information strings here.
-  // "top(C).forward": for the add operator in top.
-  // "top(C).B0(B).forward": for the add operator in B0.
-  // "top(C).B0(B).forward.A0(A).forward": for the add operator in A0.
-  AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.B0(B)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.B0(B)::forward.A0(A)::forward.aten::add"));
-}
-
-TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return x + 5
-  )JIT");
-  Module b("B");
-  b.register_module("A0", a);
-  b.register_module("A1", a);
-  b.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(x) + self.A1.forward(x)
-  )JIT");
-
-  std::stringstream ss;
-  b._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  // class A(nn.Module):
-  //   def __init__(self):
-  //     super(A, self).__init__()
-
-  //   def forward(self, x):
-  //     return x + 5
-
-  // class B(nn.Module):
-  //   def __init__(self):
-  //     super(B, self).__init__()
-  //     self.A0 = A()
-  //     self.A1 = A()
-
-  //   def forward(self, x):
-  //     return self.A0.forward(x) + self.A1.forward(x)
-
-  // There are 3 module information strings here.
-  // "top(B).forward": for the add operator in top.
-  // "top(B).A0(A).forward": for the add operator in A0.
-  // "top(B).A1(A).forward": for the add operator in A1.
-
-  AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(B)::<unknown>.A0(A)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(B)::<unknown>.A1(A)::forward.aten::add"));
-}
-#endif // !defined(FB_XPLAT_BUILD)
-
 TEST(LiteInterpreterTest, Eval) {
   std::vector<torch::jit::IValue> inputs;
 
index 34846b5..c473157 100644 (file)
@@ -319,7 +319,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/testing/hooks_for_testing.cpp",
     "torch/csrc/utils/tensor_flatten.cpp",
     "torch/csrc/utils/variadic.cpp",
-] + libtorch_profiler_sources
+]
 
 core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [
     "torch/csrc/jit/backends/backend_debug_info.cpp",
@@ -337,7 +337,7 @@ core_sources_full = core_sources_full_mobile + [
     "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp",
 ]
 
-libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources)
+libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources)
 
 # These files are the only ones that are supported on Windows.
 libtorch_distributed_base_sources = [
index 41ce3c6..a75ffe1 100644 (file)
@@ -13,6 +13,12 @@ namespace jit {
 
 namespace {
 
+C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
+    const std::string& debug_handles_string) {
+  return "Debug info for handle(s): " + debug_handles_string +
+      ", was not found.";
+}
+
 std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
     const DebugInfoTuple& source_callstack,
     const std::string& caller_name) {
@@ -152,8 +158,7 @@ std::string MobileDebugTable::getModuleHierarchyInfo(
     const std::string& top_module_type_name) const {
   const auto it = callstack_ptr_map_.find(debug_handle);
   if (it == callstack_ptr_map_.end()) {
-    return "Module info for handle, " + std::to_string(debug_handle) +
-        ", not found.";
+    return debugHandlesNotFoundMessage(std::to_string(debug_handle));
   }
   return (getStackTraceWithModuleHierarchy(
               {it->second}, "top", top_module_type_name))
@@ -172,8 +177,7 @@ std::string MobileDebugTable::getSourceDebugString(
     const std::string& top_module_type_name) const {
   const auto it = callstack_ptr_map_.find(debug_handle);
   if (it == callstack_ptr_map_.end()) {
-    return "Debug info for handle, " + std::to_string(debug_handle) +
-        ", not found.";
+    return debugHandlesNotFoundMessage(std::to_string(debug_handle));
   }
   return (getStackTraceWithModuleHierarchy(
               {it->second}, "top", top_module_type_name))
@@ -208,8 +212,7 @@ std::pair<std::string, std::string> MobileDebugTable::
       debug_handles_string += std::to_string(debug_handle);
     }
     debug_handles_string += "}";
-    debug_handles_string =
-        "Debug info for handles: " + debug_handles_string + ", was not found.";
+    debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
     return {debug_handles_string, debug_handles_string};
   }
   return (getStackTraceWithModuleHierarchy(
index 6a54810..99be225 100644 (file)
@@ -517,12 +517,15 @@ mobile::Module BytecodeDeserializer::deserialize(
   auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements();
 
   c10::optional<std::vector<IValue>> debug_handles;
+  bool has_debug_handles{false};
   if (reader_->hasRecord("mobile_debug_handles.pkl")) {
     debug_handles =
         readArchive("mobile_debug_handles", mcu).toTuple()->elements();
+    has_debug_handles = true;
   }
   parseMethods(bvals, debug_handles, *mcu);
   auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
+  m.setHasDebugHandles(has_debug_handles);
 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
   MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
   m.setDebugTable(std::move(debug_table));
index 02e7c35..ab558cd 100644 (file)
@@ -57,6 +57,9 @@ bool InterpreterState::run(Stack& stack) {
       auto inst_with_handle = code_->instructions_with_handles_.at(pc);
       Instruction inst = inst_with_handle.instruction;
       DebugHandle debug_handle = inst_with_handle.debug_handle;
+      // If no valid debug handle found then just log pc.
+      // This is possible when we did not save debug handles
+      debug_handle = debug_handle == -1 ? pc : debug_handle;
 
       // std::cout << "RUNNING " << pc << " "
       //           << code_->instructions_with_handles_[pc].instruction;
index c04d9f7..c74ca13 100644 (file)
@@ -145,8 +145,7 @@ std::string Module::getCallStack(const int64_t debug_handle) const {
 // We really need to change this part, so in the next step for profiling support
 // for delegates, the first thing will be to rewrite how profiling is done
 // for lite interpreter.
-std::string Module::get_forward_method_debug_info(size_t pc) const {
-  auto debug_handle = find_method("forward")->get_debug_handle(pc);
+std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
   return getDebugTable().getModuleHierarchyInfo(
       debug_handle, getTopModuleTypeName(*this));
index 73637aa..6102aa5 100644 (file)
@@ -78,7 +78,7 @@ class TORCH_API Module {
   }
   const std::vector<at::Tensor> parameters() const;
   const std::map<std::string, at::Tensor> named_parameters() const;
-  std::string get_forward_method_debug_info(size_t pc) const;
+  std::string get_forward_method_debug_info(int64_t debug_handle) const;
   std::string getModuleHierarchy(const int64_t debug_handle) const;
   std::string getCallStack(const int64_t debug_handle) const;
   /// Enables "training" mode.
@@ -115,11 +115,20 @@ class TORCH_API Module {
     return debug_table_;
   }
 
+  void setHasDebugHandles(bool has_debug_handles) {
+    has_debug_handles_ = has_debug_handles;
+  }
+
+  bool hasDebugHandles() const {
+    return has_debug_handles_;
+  }
+
  private:
   c10::intrusive_ptr<c10::ivalue::Object> object_;
   std::unordered_map<std::string, std::string> metadata_;
   std::shared_ptr<CompilationUnit> cu_;
   MobileDebugTable debug_table_;
+  bool has_debug_handles_;
 };
 } // namespace mobile
 } // namespace jit
index bcd5a62..162e43f 100644 (file)
@@ -2,7 +2,6 @@
 #include <string>
 #include <vector>
 
-namespace profiler = torch::autograd::profiler;
 namespace torch {
 namespace jit {
 namespace mobile {
@@ -27,17 +26,26 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
   if (with_modules || with_stack) {
     auto post_processing = [this, with_stack, with_modules](
                                std::vector<profiler::KinetoEvent>& events) {
+      std::string no_debug_info("Model was not saved with debug information");
       for (auto& e : events) {
         if (with_modules) {
           // Since KinetoEvents's module hierarchy takes vector of strings we
           // just construct a temporary vector using one string element
-          e.moduleHierarchy(std::vector<std::string>(
-              {this->m_.getModuleHierarchy(e.debugHandle())}));
+          if (this->m_.hasDebugHandles()) {
+            e.moduleHierarchy(std::vector<std::string>(
+                {this->m_.getModuleHierarchy(e.debugHandle())}));
+          } else {
+            e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
+          }
         } else if (with_stack) {
           // Since KinetoEvents's stack trace takes vector of strings we just
           // construct a temporary vector using one string element
-          e.stack(std::vector<std::string>(
-              {this->m_.getCallStack(e.debugHandle())}));
+          if (this->m_.hasDebugHandles()) {
+            e.stack(std::vector<std::string>(
+                {this->m_.getCallStack(e.debugHandle())}));
+          } else {
+            e.stack(std::vector<std::string>({no_debug_info}));
+          }
         }
       }
     };
@@ -55,8 +63,33 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
   trace_file_name_ = fname;
 }
 
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+    disableProfiler() {
+  TORCH_CHECK(
+      !profiler_result_,
+      "KinetoEdgeCPUProfiler already disabled. "
+      "To get list of events use getProfilerResults()");
+  profiler_result_ = profiler::disableProfiler();
+  return profiler_result_;
+}
+
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+    getProfilerResult() {
+  TORCH_CHECK(
+      profiler_result_,
+      "KinetoEdgeCPUProfiler has not been disabled. "
+      "use disableProfiler() API first, which returns the ProfilerResult.");
+  return profiler_result_;
+}
+
 KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() {
-  profiler::disableProfiler()->save(trace_file_name_);
+  if (!trace_file_name_.empty()) {
+    if (profiler_result_) {
+      profiler_result_->save(trace_file_name_);
+    } else {
+      profiler::disableProfiler()->save(trace_file_name_);
+    }
+  }
 }
 } // namespace mobile
 } // namespace jit
index a245034..ef37e01 100644 (file)
@@ -2,6 +2,7 @@
 #include <torch/csrc/autograd/profiler_kineto.h>
 #include <torch/csrc/jit/mobile/module.h>
 
+namespace profiler = torch::autograd::profiler;
 namespace torch {
 namespace jit {
 namespace mobile {
@@ -53,6 +54,9 @@ class TORCH_API KinetoEdgeCPUProfiler {
       const bool with_flops = false,
       const bool with_modules = false);
 
+  const std::unique_ptr<profiler::ProfilerResult>& disableProfiler();
+  const std::unique_ptr<profiler::ProfilerResult>& getProfilerResult();
+
   ~KinetoEdgeCPUProfiler();
 
  private:
@@ -62,6 +66,7 @@ class TORCH_API KinetoEdgeCPUProfiler {
    */
   const mobile::Module& m_;
   std::string trace_file_name_;
+  std::unique_ptr<profiler::ProfilerResult> profiler_result_;
 };
 } // namespace mobile
 } // namespace jit