AT_ASSERT(str == expected);
}
-#if !defined FB_XPLAT_BUILD
-TEST(LiteInterpreterTest, ModuleInfoBasic) {
- Module m("M");
- m.define(R"JIT(
- def forward(self, x):
- return 2 * x
- )JIT");
-
- std::stringstream ss;
- m._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::unordered_set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- AT_ASSERT(module_debug_info_set.count("top(M)::<unknown>.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, NotSaveModuleInfo) {
- Module m("M");
- m.define(R"JIT(
- def forward(self, x):
- return x + 5
- )JIT");
-
- std::stringstream ss;
- m._save_for_mobile(ss);
- mobile::Module bc = _load_for_mobile(ss);
-
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- AT_ASSERT(
- module_info.empty() ||
- (module_info.find("debug_handle") != std::string::npos));
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-}
-
-TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return 2 * x + 5
- )JIT");
- Module b("B");
- b.register_module("A0", a);
- b.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(x) + 1
- )JIT");
-
- std::stringstream ss;
- b._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(B)::<unknown>.A0(A)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(B)::<unknown>.A0(A)::forward.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return x + 1
- )JIT");
- Module b("B");
- b.define(R"JIT(
- def forward(self, x):
- return x + 2
- )JIT");
- Module c("C");
- c.register_module("A0", a);
- c.register_module("B0", b);
- c.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(x) + self.B0.forward(x)
- )JIT");
-
- std::stringstream ss;
- c._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.A0(A)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.B0(B)::forward.aten::add"));
-}
-
-TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
- auto runtime_bytecode_version = _get_runtime_bytecode_version();
- AT_ASSERT(
- runtime_bytecode_version ==
- caffe2::serialize::kMaxSupportedBytecodeVersion);
-}
-
-/**
- * The test below is disarmed for FB internal xplat builds since
- * BUCK requires us to pass in the script_module_v4.ptl file in
- * as a resource dependency of the build rule for this file, and
- * we would need to access it via the C++ Resources API instead
- * of directly reading from disk (which is what the open source
- * build/run does).
- */
-TEST(LiteInterpreterTest, GetByteCodeVersion) {
- std::string filePath(__FILE__);
- auto test_model_file_v4 =
- filePath.substr(0, filePath.find_last_of("/\\") + 1);
- test_model_file_v4.append("script_module_v4.ptl");
-
- auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
- AT_ASSERT(version_v4 == 4);
-}
-#endif // !defined(FB_XPLAT_BUILD)
-
namespace {
void compareModelOutput(
"torch/csrc/jit/testing/hooks_for_testing.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
"torch/csrc/utils/variadic.cpp",
-] + libtorch_profiler_sources
+]
core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/tensorexpr/external_functions_codegen.cpp",
]
-libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources)
+libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources)
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
namespace {
+C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
+ const std::string& debug_handles_string) {
+ return "Debug info for handle(s): " + debug_handles_string +
+ ", was not found.";
+}
+
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
const DebugInfoTuple& source_callstack,
const std::string& caller_name) {
const std::string& top_module_type_name) const {
const auto it = callstack_ptr_map_.find(debug_handle);
if (it == callstack_ptr_map_.end()) {
- return "Module info for handle, " + std::to_string(debug_handle) +
- ", not found.";
+ return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
return (getStackTraceWithModuleHierarchy(
{it->second}, "top", top_module_type_name))
const std::string& top_module_type_name) const {
const auto it = callstack_ptr_map_.find(debug_handle);
if (it == callstack_ptr_map_.end()) {
- return "Debug info for handle, " + std::to_string(debug_handle) +
- ", not found.";
+ return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
return (getStackTraceWithModuleHierarchy(
{it->second}, "top", top_module_type_name))
debug_handles_string += std::to_string(debug_handle);
}
debug_handles_string += "}";
- debug_handles_string =
- "Debug info for handles: " + debug_handles_string + ", was not found.";
+ debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
return {debug_handles_string, debug_handles_string};
}
return (getStackTraceWithModuleHierarchy(
auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements();
c10::optional<std::vector<IValue>> debug_handles;
+ bool has_debug_handles{false};
if (reader_->hasRecord("mobile_debug_handles.pkl")) {
debug_handles =
readArchive("mobile_debug_handles", mcu).toTuple()->elements();
+ has_debug_handles = true;
}
parseMethods(bvals, debug_handles, *mcu);
auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
+ m.setHasDebugHandles(has_debug_handles);
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
m.setDebugTable(std::move(debug_table));
auto inst_with_handle = code_->instructions_with_handles_.at(pc);
Instruction inst = inst_with_handle.instruction;
DebugHandle debug_handle = inst_with_handle.debug_handle;
+ // If no valid debug handle found then just log pc.
+ // This is possible when we did not save debug handles
+ debug_handle = debug_handle == -1 ? pc : debug_handle;
// std::cout << "RUNNING " << pc << " "
// << code_->instructions_with_handles_[pc].instruction;
// We really need to change this part, so in the next step for profiling support
// for delegates, the first thing will be to rewrite how profiling is done
// for lite interpreter.
-std::string Module::get_forward_method_debug_info(size_t pc) const {
- auto debug_handle = find_method("forward")->get_debug_handle(pc);
+std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
return getDebugTable().getModuleHierarchyInfo(
debug_handle, getTopModuleTypeName(*this));
}
const std::vector<at::Tensor> parameters() const;
const std::map<std::string, at::Tensor> named_parameters() const;
- std::string get_forward_method_debug_info(size_t pc) const;
+ std::string get_forward_method_debug_info(int64_t debug_handle) const;
std::string getModuleHierarchy(const int64_t debug_handle) const;
std::string getCallStack(const int64_t debug_handle) const;
/// Enables "training" mode.
return debug_table_;
}
+ void setHasDebugHandles(bool has_debug_handles) {
+ has_debug_handles_ = has_debug_handles;
+ }
+
+ bool hasDebugHandles() const {
+ return has_debug_handles_;
+ }
+
private:
c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
MobileDebugTable debug_table_;
+ bool has_debug_handles_;
};
} // namespace mobile
} // namespace jit
#include <string>
#include <vector>
-namespace profiler = torch::autograd::profiler;
namespace torch {
namespace jit {
namespace mobile {
if (with_modules || with_stack) {
auto post_processing = [this, with_stack, with_modules](
std::vector<profiler::KinetoEvent>& events) {
+ std::string no_debug_info("Model was not saved with debug information");
for (auto& e : events) {
if (with_modules) {
// Since KinetoEvents's module hierarchy takes vector of strings we
// just construct a temporary vector using one string element
- e.moduleHierarchy(std::vector<std::string>(
- {this->m_.getModuleHierarchy(e.debugHandle())}));
+ if (this->m_.hasDebugHandles()) {
+ e.moduleHierarchy(std::vector<std::string>(
+ {this->m_.getModuleHierarchy(e.debugHandle())}));
+ } else {
+ e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
+ }
} else if (with_stack) {
// Since KinetoEvents's stack trace takes vector of strings we just
// construct a temporary vector using one string element
- e.stack(std::vector<std::string>(
- {this->m_.getCallStack(e.debugHandle())}));
+ if (this->m_.hasDebugHandles()) {
+ e.stack(std::vector<std::string>(
+ {this->m_.getCallStack(e.debugHandle())}));
+ } else {
+ e.stack(std::vector<std::string>({no_debug_info}));
+ }
}
}
};
trace_file_name_ = fname;
}
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+ disableProfiler() {
+ TORCH_CHECK(
+ !profiler_result_,
+ "KinetoEdgeCPUProfiler already disabled. "
+ "To get list of events use getProfilerResults()");
+ profiler_result_ = profiler::disableProfiler();
+ return profiler_result_;
+}
+
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+ getProfilerResult() {
+ TORCH_CHECK(
+ profiler_result_,
+ "KinetoEdgeCPUProfiler has not been disabled. "
+ "use disableProfiler() API first, which returns the ProfilerResult.");
+ return profiler_result_;
+}
+
KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() {
- profiler::disableProfiler()->save(trace_file_name_);
+ if (!trace_file_name_.empty()) {
+ if (profiler_result_) {
+ profiler_result_->save(trace_file_name_);
+ } else {
+ profiler::disableProfiler()->save(trace_file_name_);
+ }
+ }
}
} // namespace mobile
} // namespace jit
#include <torch/csrc/autograd/profiler_kineto.h>
#include <torch/csrc/jit/mobile/module.h>
+namespace profiler = torch::autograd::profiler;
namespace torch {
namespace jit {
namespace mobile {
const bool with_flops = false,
const bool with_modules = false);
+ const std::unique_ptr<profiler::ProfilerResult>& disableProfiler();
+ const std::unique_ptr<profiler::ProfilerResult>& getProfilerResult();
+
~KinetoEdgeCPUProfiler();
private:
*/
const mobile::Module& m_;
std::string trace_file_name_;
+ std::unique_ptr<profiler::ProfilerResult> profiler_result_;
};
} // namespace mobile
} // namespace jit