[Pytorch lite predictor] Use KinetoEdgeCPUProfiler for operator profiling. (#63367)
authorKimish Patel <kimishpatel@fb.com>
Tue, 31 Aug 2021 03:53:50 +0000 (20:53 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 03:54:51 +0000 (20:54 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63367

This diff changes the way operator profiling is done in lite predictor
benchmarking binary.
Instead of using custom callbacks it uses KinetoEdgeCPUProfiler to profile
events and then generate operator level metric from it.
Since KinetoEvents do not contain cpu clock time, now we report only wallclock
time.
This unifies various profiling effort that we have for benchmarking purpose. In
production we will still use observer based mechanism, but the advantage of
using kineto profiler is that we get few other things for free, such as:
- chrome trace generation.
- operator level memory profiling (to be added)
- flop counts (to be added)

Furthermore possible we can use python post processing script to parse chrome
trace and generate output similar to torch.profiler. (To be done)

Test Plan:
aibench run
Model without debug info:
https://www.internalfb.com/intern/aibench/details/219598441154763
Model with debug info and `--print_module_info true` (see Operator summary has now module hierarchy information).
https://www.internalfb.com/intern/aibench/details/617154236292985

Reviewed By: raziel

Differential Revision: D30327514

fbshipit-source-id: 3bb2f2daaaedfb04bd6f5d9c91292783f9c4344f

test/cpp/jit/test_lite_interpreter.cpp
tools/build_variables.bzl
torch/csrc/jit/mobile/debug_info.cpp
torch/csrc/jit/mobile/import.cpp
torch/csrc/jit/mobile/interpreter.cpp
torch/csrc/jit/mobile/module.cpp
torch/csrc/jit/mobile/module.h
torch/csrc/jit/mobile/profiler_edge.cpp
torch/csrc/jit/mobile/profiler_edge.h

index 3bd2bec..8fb5fe2 100644 (file)
@@ -455,171 +455,6 @@ TEST(LiteInterpreterTest, BuiltinFunction) {
   AT_ASSERT(str == expected);
 }
 
-#if !defined FB_XPLAT_BUILD
-TEST(LiteInterpreterTest, ModuleInfoBasic) {
-  Module m("M");
-  m.define(R"JIT(
-    def forward(self, x):
-      return 2 * x
-  )JIT");
-
-  std::stringstream ss;
-  m._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::unordered_set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  AT_ASSERT(module_debug_info_set.count("top(M)::<unknown>.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, NotSaveModuleInfo) {
-  Module m("M");
-  m.define(R"JIT(
-    def forward(self, x):
-      return x + 5
-  )JIT");
-
-  std::stringstream ss;
-  m._save_for_mobile(ss);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      AT_ASSERT(
-          module_info.empty() ||
-          (module_info.find("debug_handle") != std::string::npos));
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-}
-
-TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return 2 * x + 5
-  )JIT");
-  Module b("B");
-  b.register_module("A0", a);
-  b.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(x) + 1
-  )JIT");
-
-  std::stringstream ss;
-  b._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(B)::<unknown>.A0(A)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(B)::<unknown>.A0(A)::forward.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
-  Module a("A");
-  a.define(R"JIT(
-    def forward(self, x):
-      return x + 1
-  )JIT");
-  Module b("B");
-  b.define(R"JIT(
-    def forward(self, x):
-      return x + 2
-  )JIT");
-  Module c("C");
-  c.register_module("A0", a);
-  c.register_module("B0", b);
-  c.define(R"JIT(
-    def forward(self, x):
-      return self.A0.forward(x) + self.B0.forward(x)
-  )JIT");
-
-  std::stringstream ss;
-  c._save_for_mobile(ss, {}, true);
-  mobile::Module bc = _load_for_mobile(ss);
-
-  std::set<std::string> module_debug_info_set;
-  size_t pc = 0;
-  while (true) {
-    try {
-      std::string module_info = bc.get_forward_method_debug_info(pc);
-      if (!module_info.empty() &&
-          (module_info.find("debug_handle") == std::string::npos)) {
-        module_debug_info_set.insert(module_info);
-      }
-      ++pc;
-    } catch (const std::exception& e) {
-      break;
-    }
-  }
-
-  AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.A0(A)::forward.aten::add"));
-  AT_ASSERT(module_debug_info_set.count(
-      "top(C)::<unknown>.B0(B)::forward.aten::add"));
-}
-
-TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
-  auto runtime_bytecode_version = _get_runtime_bytecode_version();
-  AT_ASSERT(
-      runtime_bytecode_version ==
-      caffe2::serialize::kMaxSupportedBytecodeVersion);
-}
-
-/**
- * The test below is disarmed for FB internal xplat builds since
- * BUCK requires us to pass in the script_module_v4.ptl file in
- * as a resource dependency of the build rule for this file, and
- * we would need to access it via the C++ Resources API instead
- * of directly reading from disk (which is what the open source
- * build/run does).
- */
-TEST(LiteInterpreterTest, GetByteCodeVersion) {
-  std::string filePath(__FILE__);
-  auto test_model_file_v4 =
-      filePath.substr(0, filePath.find_last_of("/\\") + 1);
-  test_model_file_v4.append("script_module_v4.ptl");
-
-  auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
-  AT_ASSERT(version_v4 == 4);
-}
-#endif // !defined(FB_XPLAT_BUILD)
-
 namespace {
 
 void compareModelOutput(
index dd89981..e0c43d2 100644 (file)
@@ -319,7 +319,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/testing/hooks_for_testing.cpp",
     "torch/csrc/utils/tensor_flatten.cpp",
     "torch/csrc/utils/variadic.cpp",
-] + libtorch_profiler_sources
+]
 
 core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [
     "torch/csrc/jit/backends/backend_debug_info.cpp",
@@ -337,7 +337,7 @@ core_sources_full = core_sources_full_mobile + [
     "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp",
 ]
 
-libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources)
+libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources)
 
 # These files are the only ones that are supported on Windows.
 libtorch_distributed_base_sources = [
index 41ce3c6..a75ffe1 100644 (file)
@@ -13,6 +13,12 @@ namespace jit {
 
 namespace {
 
+C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
+    const std::string& debug_handles_string) {
+  return "Debug info for handle(s): " + debug_handles_string +
+      ", was not found.";
+}
+
 std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
     const DebugInfoTuple& source_callstack,
     const std::string& caller_name) {
@@ -152,8 +158,7 @@ std::string MobileDebugTable::getModuleHierarchyInfo(
     const std::string& top_module_type_name) const {
   const auto it = callstack_ptr_map_.find(debug_handle);
   if (it == callstack_ptr_map_.end()) {
-    return "Module info for handle, " + std::to_string(debug_handle) +
-        ", not found.";
+    return debugHandlesNotFoundMessage(std::to_string(debug_handle));
   }
   return (getStackTraceWithModuleHierarchy(
               {it->second}, "top", top_module_type_name))
@@ -172,8 +177,7 @@ std::string MobileDebugTable::getSourceDebugString(
     const std::string& top_module_type_name) const {
   const auto it = callstack_ptr_map_.find(debug_handle);
   if (it == callstack_ptr_map_.end()) {
-    return "Debug info for handle, " + std::to_string(debug_handle) +
-        ", not found.";
+    return debugHandlesNotFoundMessage(std::to_string(debug_handle));
   }
   return (getStackTraceWithModuleHierarchy(
               {it->second}, "top", top_module_type_name))
@@ -208,8 +212,7 @@ std::pair<std::string, std::string> MobileDebugTable::
       debug_handles_string += std::to_string(debug_handle);
     }
     debug_handles_string += "}";
-    debug_handles_string =
-        "Debug info for handles: " + debug_handles_string + ", was not found.";
+    debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
     return {debug_handles_string, debug_handles_string};
   }
   return (getStackTraceWithModuleHierarchy(
index 6a54810..99be225 100644 (file)
@@ -517,12 +517,15 @@ mobile::Module BytecodeDeserializer::deserialize(
   auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements();
 
   c10::optional<std::vector<IValue>> debug_handles;
+  bool has_debug_handles{false};
   if (reader_->hasRecord("mobile_debug_handles.pkl")) {
     debug_handles =
         readArchive("mobile_debug_handles", mcu).toTuple()->elements();
+    has_debug_handles = true;
   }
   parseMethods(bvals, debug_handles, *mcu);
   auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
+  m.setHasDebugHandles(has_debug_handles);
 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
   MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
   m.setDebugTable(std::move(debug_table));
index 02e7c35..ab558cd 100644 (file)
@@ -57,6 +57,9 @@ bool InterpreterState::run(Stack& stack) {
       auto inst_with_handle = code_->instructions_with_handles_.at(pc);
       Instruction inst = inst_with_handle.instruction;
       DebugHandle debug_handle = inst_with_handle.debug_handle;
+      // If no valid debug handle found then just log pc.
+      // This is possible when we did not save debug handles
+      debug_handle = debug_handle == -1 ? pc : debug_handle;
 
       // std::cout << "RUNNING " << pc << " "
       //           << code_->instructions_with_handles_[pc].instruction;
index c04d9f7..c74ca13 100644 (file)
@@ -145,8 +145,7 @@ std::string Module::getCallStack(const int64_t debug_handle) const {
 // We really need to change this part, so in the next step for profiling support
 // for delegates, the first thing will be to rewrite how profiling is done
 // for lite interpreter.
-std::string Module::get_forward_method_debug_info(size_t pc) const {
-  auto debug_handle = find_method("forward")->get_debug_handle(pc);
+std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
   return getDebugTable().getModuleHierarchyInfo(
       debug_handle, getTopModuleTypeName(*this));
index 73637aa..6102aa5 100644 (file)
@@ -78,7 +78,7 @@ class TORCH_API Module {
   }
   const std::vector<at::Tensor> parameters() const;
   const std::map<std::string, at::Tensor> named_parameters() const;
-  std::string get_forward_method_debug_info(size_t pc) const;
+  std::string get_forward_method_debug_info(int64_t debug_handle) const;
   std::string getModuleHierarchy(const int64_t debug_handle) const;
   std::string getCallStack(const int64_t debug_handle) const;
   /// Enables "training" mode.
@@ -115,11 +115,20 @@ class TORCH_API Module {
     return debug_table_;
   }
 
+  void setHasDebugHandles(bool has_debug_handles) {
+    has_debug_handles_ = has_debug_handles;
+  }
+
+  bool hasDebugHandles() const {
+    return has_debug_handles_;
+  }
+
  private:
   c10::intrusive_ptr<c10::ivalue::Object> object_;
   std::unordered_map<std::string, std::string> metadata_;
   std::shared_ptr<CompilationUnit> cu_;
   MobileDebugTable debug_table_;
+  bool has_debug_handles_;
 };
 } // namespace mobile
 } // namespace jit
index bcd5a62..162e43f 100644 (file)
@@ -2,7 +2,6 @@
 #include <string>
 #include <vector>
 
-namespace profiler = torch::autograd::profiler;
 namespace torch {
 namespace jit {
 namespace mobile {
@@ -27,17 +26,26 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
   if (with_modules || with_stack) {
     auto post_processing = [this, with_stack, with_modules](
                                std::vector<profiler::KinetoEvent>& events) {
+      std::string no_debug_info("Model was not saved with debug information");
       for (auto& e : events) {
         if (with_modules) {
           // Since KinetoEvents's module hierarchy takes vector of strings we
           // just construct a temporary vector using one string element
-          e.moduleHierarchy(std::vector<std::string>(
-              {this->m_.getModuleHierarchy(e.debugHandle())}));
+          if (this->m_.hasDebugHandles()) {
+            e.moduleHierarchy(std::vector<std::string>(
+                {this->m_.getModuleHierarchy(e.debugHandle())}));
+          } else {
+            e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
+          }
         } else if (with_stack) {
           // Since KinetoEvents's stack trace takes vector of strings we just
           // construct a temporary vector using one string element
-          e.stack(std::vector<std::string>(
-              {this->m_.getCallStack(e.debugHandle())}));
+          if (this->m_.hasDebugHandles()) {
+            e.stack(std::vector<std::string>(
+                {this->m_.getCallStack(e.debugHandle())}));
+          } else {
+            e.stack(std::vector<std::string>({no_debug_info}));
+          }
         }
       }
     };
@@ -55,8 +63,33 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
   trace_file_name_ = fname;
 }
 
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+    disableProfiler() {
+  TORCH_CHECK(
+      !profiler_result_,
+      "KinetoEdgeCPUProfiler already disabled. "
+      "To get list of events use getProfilerResults()");
+  profiler_result_ = profiler::disableProfiler();
+  return profiler_result_;
+}
+
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+    getProfilerResult() {
+  TORCH_CHECK(
+      profiler_result_,
+      "KinetoEdgeCPUProfiler has not been disabled. "
+      "use disableProfiler() API first, which returns the ProfilerResult.");
+  return profiler_result_;
+}
+
 KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() {
-  profiler::disableProfiler()->save(trace_file_name_);
+  if (!trace_file_name_.empty()) {
+    if (profiler_result_) {
+      profiler_result_->save(trace_file_name_);
+    } else {
+      profiler::disableProfiler()->save(trace_file_name_);
+    }
+  }
 }
 } // namespace mobile
 } // namespace jit
index a245034..ef37e01 100644 (file)
@@ -2,6 +2,7 @@
 #include <torch/csrc/autograd/profiler_kineto.h>
 #include <torch/csrc/jit/mobile/module.h>
 
+namespace profiler = torch::autograd::profiler;
 namespace torch {
 namespace jit {
 namespace mobile {
@@ -53,6 +54,9 @@ class TORCH_API KinetoEdgeCPUProfiler {
       const bool with_flops = false,
       const bool with_modules = false);
 
+  const std::unique_ptr<profiler::ProfilerResult>& disableProfiler();
+  const std::unique_ptr<profiler::ProfilerResult>& getProfilerResult();
+
   ~KinetoEdgeCPUProfiler();
 
  private:
@@ -62,6 +66,7 @@ class TORCH_API KinetoEdgeCPUProfiler {
    */
   const mobile::Module& m_;
   std::string trace_file_name_;
+  std::unique_ptr<profiler::ProfilerResult> profiler_result_;
 };
 } // namespace mobile
 } // namespace jit