From bc9277dca3a40d99147d4a1a3e0160a4a8e91f9f Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 30 Aug 2021 20:53:50 -0700 Subject: [PATCH] [Pytorch lite predictor] Use KinetoEdgeCPUProfiler for operator profiling. (#63367) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63367 This diff changes the way operator profiling is done in lite predictor benchmarking binary. Instead of using custom callbacks it uses KinetoEdgeCPUProfiler to profile events and then generate operator level metric from it. Since KinetoEvents do not contain cpu clock time, now we report only wallclock time. This unifies various profiling effort that we have for benchmarking purpose. In production we will still use observer based mechanism, but the advantage of using kineto profiler is that we get few other things for free, such as: - chrome trace generation. - operator level memory profiling (to be added) - flop counts (to be added) Furthermore possible we can use python post processing script to parse chrome trace and generate output similar to torch.profiler. (To be done) Test Plan: aibench run Model without debug info: https://www.internalfb.com/intern/aibench/details/219598441154763 Model with debug info and `--print_module_info true` (see Operator summary has now module hierarchy information). https://www.internalfb.com/intern/aibench/details/617154236292985 Reviewed By: raziel Differential Revision: D30327514 fbshipit-source-id: 3bb2f2daaaedfb04bd6f5d9c91292783f9c4344f --- test/cpp/jit/test_lite_interpreter.cpp | 165 -------------------------------- tools/build_variables.bzl | 4 +- torch/csrc/jit/mobile/debug_info.cpp | 15 +-- torch/csrc/jit/mobile/import.cpp | 3 + torch/csrc/jit/mobile/interpreter.cpp | 3 + torch/csrc/jit/mobile/module.cpp | 3 +- torch/csrc/jit/mobile/module.h | 11 ++- torch/csrc/jit/mobile/profiler_edge.cpp | 45 +++++++-- torch/csrc/jit/mobile/profiler_edge.h | 5 + 9 files changed, 72 insertions(+), 182 deletions(-) diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 3bd2bec..8fb5fe2 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -455,171 +455,6 @@ TEST(LiteInterpreterTest, BuiltinFunction) { AT_ASSERT(str == expected); } -#if !defined FB_XPLAT_BUILD -TEST(LiteInterpreterTest, ModuleInfoBasic) { - Module m("M"); - m.define(R"JIT( - def forward(self, x): - return 2 * x - )JIT"); - - std::stringstream ss; - m._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::unordered_set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - AT_ASSERT(module_debug_info_set.count("top(M)::.aten::mul")); -} - -TEST(LiteInterpreterTest, NotSaveModuleInfo) { - Module m("M"); - m.define(R"JIT( - def forward(self, x): - return x + 5 - )JIT"); - - std::stringstream ss; - m._save_for_mobile(ss); - mobile::Module bc = _load_for_mobile(ss); - - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - AT_ASSERT( - module_info.empty() || - (module_info.find("debug_handle") != std::string::npos)); - ++pc; - } catch (const std::exception& e) { - break; - } - } -} - -TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return 2 * x + 5 - )JIT"); - Module b("B"); - b.register_module("A0", a); - b.define(R"JIT( - def forward(self, x): - return self.A0.forward(x) + 1 - )JIT"); - - std::stringstream ss; - b._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - AT_ASSERT(module_debug_info_set.count("top(B)::.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(B)::.A0(A)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(B)::.A0(A)::forward.aten::mul")); -} - -TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return x + 1 - )JIT"); - Module b("B"); - b.define(R"JIT( - def forward(self, x): - return x + 2 - )JIT"); - Module c("C"); - c.register_module("A0", a); - c.register_module("B0", b); - c.define(R"JIT( - def forward(self, x): - return self.A0.forward(x) + self.B0.forward(x) - )JIT"); - - std::stringstream ss; - c._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - AT_ASSERT(module_debug_info_set.count("top(C)::.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.A0(A)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.B0(B)::forward.aten::add")); -} - -TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) { - auto runtime_bytecode_version = _get_runtime_bytecode_version(); - AT_ASSERT( - runtime_bytecode_version == - caffe2::serialize::kMaxSupportedBytecodeVersion); -} - -/** - * The test below is disarmed for FB internal xplat builds since - * BUCK requires us to pass in the script_module_v4.ptl file in - * as a resource dependency of the build rule for this file, and - * we would need to access it via the C++ Resources API instead - * of directly reading from disk (which is what the open source - * build/run does). - */ -TEST(LiteInterpreterTest, GetByteCodeVersion) { - std::string filePath(__FILE__); - auto test_model_file_v4 = - filePath.substr(0, filePath.find_last_of("/\\") + 1); - test_model_file_v4.append("script_module_v4.ptl"); - - auto version_v4 = _get_model_bytecode_version(test_model_file_v4); - AT_ASSERT(version_v4 == 4); -} -#endif // !defined(FB_XPLAT_BUILD) - namespace { void compareModelOutput( diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index dd89981..e0c43d2 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -319,7 +319,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/testing/hooks_for_testing.cpp", "torch/csrc/utils/tensor_flatten.cpp", "torch/csrc/utils/variadic.cpp", -] + libtorch_profiler_sources +] core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [ "torch/csrc/jit/backends/backend_debug_info.cpp", @@ -337,7 +337,7 @@ core_sources_full = core_sources_full_mobile + [ "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp", ] -libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources) +libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources) # These files are the only ones that are supported on Windows. libtorch_distributed_base_sources = [ diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index 41ce3c6..a75ffe1 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -13,6 +13,12 @@ namespace jit { namespace { +C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage( + const std::string& debug_handles_string) { + return "Debug info for handle(s): " + debug_handles_string + + ", was not found."; +} + std::pair, std::string> getStackTraceWithModuleHierarchy( const DebugInfoTuple& source_callstack, const std::string& caller_name) { @@ -152,8 +158,7 @@ std::string MobileDebugTable::getModuleHierarchyInfo( const std::string& top_module_type_name) const { const auto it = callstack_ptr_map_.find(debug_handle); if (it == callstack_ptr_map_.end()) { - return "Module info for handle, " + std::to_string(debug_handle) + - ", not found."; + return debugHandlesNotFoundMessage(std::to_string(debug_handle)); } return (getStackTraceWithModuleHierarchy( {it->second}, "top", top_module_type_name)) @@ -172,8 +177,7 @@ std::string MobileDebugTable::getSourceDebugString( const std::string& top_module_type_name) const { const auto it = callstack_ptr_map_.find(debug_handle); if (it == callstack_ptr_map_.end()) { - return "Debug info for handle, " + std::to_string(debug_handle) + - ", not found."; + return debugHandlesNotFoundMessage(std::to_string(debug_handle)); } return (getStackTraceWithModuleHierarchy( {it->second}, "top", top_module_type_name)) @@ -208,8 +212,7 @@ std::pair MobileDebugTable:: debug_handles_string += std::to_string(debug_handle); } debug_handles_string += "}"; - debug_handles_string = - "Debug info for handles: " + debug_handles_string + ", was not found."; + debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string); return {debug_handles_string, debug_handles_string}; } return (getStackTraceWithModuleHierarchy( diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 6a54810..99be225 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -517,12 +517,15 @@ mobile::Module BytecodeDeserializer::deserialize( auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements(); c10::optional> debug_handles; + bool has_debug_handles{false}; if (reader_->hasRecord("mobile_debug_handles.pkl")) { debug_handles = readArchive("mobile_debug_handles", mcu).toTuple()->elements(); + has_debug_handles = true; } parseMethods(bvals, debug_handles, *mcu); auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu); + m.setHasDebugHandles(has_debug_handles); #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_); m.setDebugTable(std::move(debug_table)); diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 02e7c35..ab558cd 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -57,6 +57,9 @@ bool InterpreterState::run(Stack& stack) { auto inst_with_handle = code_->instructions_with_handles_.at(pc); Instruction inst = inst_with_handle.instruction; DebugHandle debug_handle = inst_with_handle.debug_handle; + // If no valid debug handle found then just log pc. + // This is possible when we did not save debug handles + debug_handle = debug_handle == -1 ? pc : debug_handle; // std::cout << "RUNNING " << pc << " " // << code_->instructions_with_handles_[pc].instruction; diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index c04d9f7..c74ca13 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -145,8 +145,7 @@ std::string Module::getCallStack(const int64_t debug_handle) const { // We really need to change this part, so in the next step for profiling support // for delegates, the first thing will be to rewrite how profiling is done // for lite interpreter. -std::string Module::get_forward_method_debug_info(size_t pc) const { - auto debug_handle = find_method("forward")->get_debug_handle(pc); +std::string Module::get_forward_method_debug_info(int64_t debug_handle) const { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) return getDebugTable().getModuleHierarchyInfo( debug_handle, getTopModuleTypeName(*this)); diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index 73637aa..6102aa5 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -78,7 +78,7 @@ class TORCH_API Module { } const std::vector parameters() const; const std::map named_parameters() const; - std::string get_forward_method_debug_info(size_t pc) const; + std::string get_forward_method_debug_info(int64_t debug_handle) const; std::string getModuleHierarchy(const int64_t debug_handle) const; std::string getCallStack(const int64_t debug_handle) const; /// Enables "training" mode. @@ -115,11 +115,20 @@ class TORCH_API Module { return debug_table_; } + void setHasDebugHandles(bool has_debug_handles) { + has_debug_handles_ = has_debug_handles; + } + + bool hasDebugHandles() const { + return has_debug_handles_; + } + private: c10::intrusive_ptr object_; std::unordered_map metadata_; std::shared_ptr cu_; MobileDebugTable debug_table_; + bool has_debug_handles_; }; } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/mobile/profiler_edge.cpp b/torch/csrc/jit/mobile/profiler_edge.cpp index bcd5a62..162e43f 100644 --- a/torch/csrc/jit/mobile/profiler_edge.cpp +++ b/torch/csrc/jit/mobile/profiler_edge.cpp @@ -2,7 +2,6 @@ #include #include -namespace profiler = torch::autograd::profiler; namespace torch { namespace jit { namespace mobile { @@ -27,17 +26,26 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler( if (with_modules || with_stack) { auto post_processing = [this, with_stack, with_modules]( std::vector& events) { + std::string no_debug_info("Model was not saved with debug information"); for (auto& e : events) { if (with_modules) { // Since KinetoEvents's module hierarchy takes vector of strings we // just construct a temporary vector using one string element - e.moduleHierarchy(std::vector( - {this->m_.getModuleHierarchy(e.debugHandle())})); + if (this->m_.hasDebugHandles()) { + e.moduleHierarchy(std::vector( + {this->m_.getModuleHierarchy(e.debugHandle())})); + } else { + e.moduleHierarchy(std::vector({no_debug_info})); + } } else if (with_stack) { // Since KinetoEvents's stack trace takes vector of strings we just // construct a temporary vector using one string element - e.stack(std::vector( - {this->m_.getCallStack(e.debugHandle())})); + if (this->m_.hasDebugHandles()) { + e.stack(std::vector( + {this->m_.getCallStack(e.debugHandle())})); + } else { + e.stack(std::vector({no_debug_info})); + } } } }; @@ -55,8 +63,33 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler( trace_file_name_ = fname; } +const std::unique_ptr& KinetoEdgeCPUProfiler:: + disableProfiler() { + TORCH_CHECK( + !profiler_result_, + "KinetoEdgeCPUProfiler already disabled. " + "To get list of events use getProfilerResults()"); + profiler_result_ = profiler::disableProfiler(); + return profiler_result_; +} + +const std::unique_ptr& KinetoEdgeCPUProfiler:: + getProfilerResult() { + TORCH_CHECK( + profiler_result_, + "KinetoEdgeCPUProfiler has not been disabled. " + "use disableProfiler() API first, which returns the ProfilerResult."); + return profiler_result_; +} + KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() { - profiler::disableProfiler()->save(trace_file_name_); + if (!trace_file_name_.empty()) { + if (profiler_result_) { + profiler_result_->save(trace_file_name_); + } else { + profiler::disableProfiler()->save(trace_file_name_); + } + } } } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/mobile/profiler_edge.h b/torch/csrc/jit/mobile/profiler_edge.h index a245034..ef37e01 100644 --- a/torch/csrc/jit/mobile/profiler_edge.h +++ b/torch/csrc/jit/mobile/profiler_edge.h @@ -2,6 +2,7 @@ #include #include +namespace profiler = torch::autograd::profiler; namespace torch { namespace jit { namespace mobile { @@ -53,6 +54,9 @@ class TORCH_API KinetoEdgeCPUProfiler { const bool with_flops = false, const bool with_modules = false); + const std::unique_ptr& disableProfiler(); + const std::unique_ptr& getProfilerResult(); + ~KinetoEdgeCPUProfiler(); private: @@ -62,6 +66,7 @@ class TORCH_API KinetoEdgeCPUProfiler { */ const mobile::Module& m_; std::string trace_file_name_; + std::unique_ptr profiler_result_; }; } // namespace mobile } // namespace jit -- 2.7.4