}
#if !defined FB_XPLAT_BUILD
-TEST(LiteInterpreterTest, ModuleInfoBasic) {
- Module m("M");
- m.define(R"JIT(
- def forward(self, x):
- return 2 * x
- )JIT");
-
- std::stringstream ss;
- m._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::unordered_set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- AT_ASSERT(module_debug_info_set.count("top(M)::<unknown>.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, NotSaveModuleInfo) {
- Module m("M");
- m.define(R"JIT(
- def forward(self, x):
- return x + 5
- )JIT");
-
- std::stringstream ss;
- m._save_for_mobile(ss);
- mobile::Module bc = _load_for_mobile(ss);
-
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- AT_ASSERT(
- module_info.empty() ||
- (module_info.find("debug_handle") != std::string::npos));
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-}
-
-TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return 2 * x + 5
- )JIT");
- Module b("B");
- b.register_module("A0", a);
- b.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(x) + 1
- )JIT");
-
- std::stringstream ss;
- b._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(B)::<unknown>.A0(A)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(B)::<unknown>.A0(A)::forward.aten::mul"));
-}
-
-TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return x + 1
- )JIT");
- Module b("B");
- b.define(R"JIT(
- def forward(self, x):
- return x + 2
- )JIT");
- Module c("C");
- c.register_module("A0", a);
- c.register_module("B0", b);
- c.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(x) + self.B0.forward(x)
- )JIT");
-
- std::stringstream ss;
- c._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.A0(A)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.B0(B)::forward.aten::add"));
-}
-
TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
auto runtime_bytecode_version = _get_runtime_bytecode_version();
AT_ASSERT(
AT_ASSERT(result.status = ModelCompatibilityStatus::ERROR);
}
-#if !defined FB_XPLAT_BUILD
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST(LiteInterpreterTest, SequentialModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return x + 1
- )JIT");
- Module b("B");
- b.define(R"JIT(
- def forward(self, x):
- return x + 2
- )JIT");
- Module c("C");
- c.register_module("A0", a);
- c.register_module("B0", b);
- c.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(self.B0.forward(x))
- )JIT");
-
- std::stringstream ss;
- c._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- // class A(nn.Module):
- // def __init__(self):
- // super(A, self).__init__()
-
- // def forward(self, x):
- // return x + 1
-
- // class B(nn.Module):
- // def __init__(self):
- // super(B, self).__init__()
-
- // def forward(self, x):
- // return x + 2
-
- // class C(nn.Module):
- // def __init__(self):
- // super(C, self).__init__()
- // self.A0 = A()
- // self.B0 = B()
-
- // def forward(self, x):
- // return self.A0.forward(self.B0.forward(x))
-
- AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.prim::Return"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.A0(A)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.B0(B)::forward.aten::add"));
-}
-
-TEST(LiteInterpreterTest, HierarchyModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return x + 1
- )JIT");
- Module b("B");
- b.register_module("A0", a);
- b.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(x) + 1
- )JIT");
- Module c("C");
- c.register_module("B0", b);
- c.define(R"JIT(
- def forward(self, x):
- return self.B0.forward(x) + 1
- )JIT");
-
- std::stringstream ss;
- c._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- // There are 3 module information strings here.
- // "top(C).forward": for the add operator in top.
- // "top(C).B0(B).forward": for the add operator in B0.
- // "top(C).B0(B).forward.A0(A).forward": for the add operator in A0.
- AT_ASSERT(module_debug_info_set.count("top(C)::<unknown>.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.B0(B)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(C)::<unknown>.B0(B)::forward.A0(A)::forward.aten::add"));
-}
-
-TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) {
- Module a("A");
- a.define(R"JIT(
- def forward(self, x):
- return x + 5
- )JIT");
- Module b("B");
- b.register_module("A0", a);
- b.register_module("A1", a);
- b.define(R"JIT(
- def forward(self, x):
- return self.A0.forward(x) + self.A1.forward(x)
- )JIT");
-
- std::stringstream ss;
- b._save_for_mobile(ss, {}, true);
- mobile::Module bc = _load_for_mobile(ss);
-
- std::set<std::string> module_debug_info_set;
- size_t pc = 0;
- while (true) {
- try {
- std::string module_info = bc.get_forward_method_debug_info(pc);
- if (!module_info.empty() &&
- (module_info.find("debug_handle") == std::string::npos)) {
- module_debug_info_set.insert(module_info);
- }
- ++pc;
- } catch (const std::exception& e) {
- break;
- }
- }
-
- // class A(nn.Module):
- // def __init__(self):
- // super(A, self).__init__()
-
- // def forward(self, x):
- // return x + 5
-
- // class B(nn.Module):
- // def __init__(self):
- // super(B, self).__init__()
- // self.A0 = A()
- // self.A1 = A()
-
- // def forward(self, x):
- // return self.A0.forward(x) + self.A1.forward(x)
-
- // There are 3 module information strings here.
- // "top(B).forward": for the add operator in top.
- // "top(B).A0(A).forward": for the add operator in A0.
- // "top(B).A1(A).forward": for the add operator in A1.
-
- AT_ASSERT(module_debug_info_set.count("top(B)::<unknown>.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(B)::<unknown>.A0(A)::forward.aten::add"));
- AT_ASSERT(module_debug_info_set.count(
- "top(B)::<unknown>.A1(A)::forward.aten::add"));
-}
-#endif // !defined(FB_XPLAT_BUILD)
-
TEST(LiteInterpreterTest, Eval) {
std::vector<torch::jit::IValue> inputs;
namespace {
+C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
+ const std::string& debug_handles_string) {
+ return "Debug info for handle(s): " + debug_handles_string +
+ ", was not found.";
+}
+
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
const DebugInfoTuple& source_callstack,
const std::string& caller_name) {
const std::string& top_module_type_name) const {
const auto it = callstack_ptr_map_.find(debug_handle);
if (it == callstack_ptr_map_.end()) {
- return "Module info for handle, " + std::to_string(debug_handle) +
- ", not found.";
+ return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
return (getStackTraceWithModuleHierarchy(
{it->second}, "top", top_module_type_name))
const std::string& top_module_type_name) const {
const auto it = callstack_ptr_map_.find(debug_handle);
if (it == callstack_ptr_map_.end()) {
- return "Debug info for handle, " + std::to_string(debug_handle) +
- ", not found.";
+ return debugHandlesNotFoundMessage(std::to_string(debug_handle));
}
return (getStackTraceWithModuleHierarchy(
{it->second}, "top", top_module_type_name))
debug_handles_string += std::to_string(debug_handle);
}
debug_handles_string += "}";
- debug_handles_string =
- "Debug info for handles: " + debug_handles_string + ", was not found.";
+ debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
return {debug_handles_string, debug_handles_string};
}
return (getStackTraceWithModuleHierarchy(
#include <string>
#include <vector>
-namespace profiler = torch::autograd::profiler;
namespace torch {
namespace jit {
namespace mobile {
if (with_modules || with_stack) {
auto post_processing = [this, with_stack, with_modules](
std::vector<profiler::KinetoEvent>& events) {
+ std::string no_debug_info("Model was not saved with debug information");
for (auto& e : events) {
if (with_modules) {
// Since KinetoEvents's module hierarchy takes vector of strings we
// just construct a temporary vector using one string element
- e.moduleHierarchy(std::vector<std::string>(
- {this->m_.getModuleHierarchy(e.debugHandle())}));
+ if (this->m_.hasDebugHandles()) {
+ e.moduleHierarchy(std::vector<std::string>(
+ {this->m_.getModuleHierarchy(e.debugHandle())}));
+ } else {
+ e.moduleHierarchy(std::vector<std::string>({no_debug_info}));
+ }
} else if (with_stack) {
// Since KinetoEvents's stack trace takes vector of strings we just
// construct a temporary vector using one string element
- e.stack(std::vector<std::string>(
- {this->m_.getCallStack(e.debugHandle())}));
+ if (this->m_.hasDebugHandles()) {
+ e.stack(std::vector<std::string>(
+ {this->m_.getCallStack(e.debugHandle())}));
+ } else {
+ e.stack(std::vector<std::string>({no_debug_info}));
+ }
}
}
};
trace_file_name_ = fname;
}
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+ disableProfiler() {
+ TORCH_CHECK(
+ !profiler_result_,
+ "KinetoEdgeCPUProfiler already disabled. "
+ "To get list of events use getProfilerResults()");
+ profiler_result_ = profiler::disableProfiler();
+ return profiler_result_;
+}
+
+const std::unique_ptr<profiler::ProfilerResult>& KinetoEdgeCPUProfiler::
+ getProfilerResult() {
+ TORCH_CHECK(
+ profiler_result_,
+ "KinetoEdgeCPUProfiler has not been disabled. "
+ "use disableProfiler() API first, which returns the ProfilerResult.");
+ return profiler_result_;
+}
+
KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() {
- profiler::disableProfiler()->save(trace_file_name_);
+ if (!trace_file_name_.empty()) {
+ if (profiler_result_) {
+ profiler_result_->save(trace_file_name_);
+ } else {
+ profiler::disableProfiler()->save(trace_file_name_);
+ }
+ }
}
} // namespace mobile
} // namespace jit