From 54f2eb6e7e313d84abfe7e1f3781998979732be2 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 13 Aug 2021 21:37:57 -0700 Subject: [PATCH] [Pytorch Profiler] Add support for adding module hierarchy to (#61792) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61792 KinetoEvent This PR adds module hierarchy information to events. What is module hierarchy information attached to events? During profiling a TorchScript module, when events are added, we ask JIT what is the module hierarchy associated with the node being executed. At the time of execution of that node, there might be multiple frames in the stack of interpreter. For each frame, we find corresponding node and the corresponding module hierarchy is queried. Module hierarchy corresponding to the node is associated with node's InlinedCallStack. InlinedCallStack of node tracks the path via which the node is inlined. Thus during the inlining process we annotate module information corresponding to the CallMethod nodes being inlined. With this PR, chrome trace will contain additional metadata: "Module Hierarchy". This can look like this: TOP(ResNet)::forward.SELF(ResNet)::_forward_impl.layer1(Sequential)::forward.0(BasicBlock)::forward.conv1(Conv2d)::forward.SELF(Conv2d)::_conv_forward It contains module instance, type name and the method name in the callstack. Test Plan: test_profiler Imported from OSS Reviewed By: raziel, ilia-cher Differential Revision: D29745442 fbshipit-source-id: dc8dfaf7c5b8ab256ff0b2ef1e5ec265ca366528 --- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 2 +- test/test_profiler.py | 64 ++++++++++++ torch/_C/_autograd.pyi | 3 +- torch/autograd/profiler.py | 13 ++- torch/autograd/profiler_legacy.py | 7 +- torch/csrc/autograd/init.cpp | 8 +- torch/csrc/autograd/profiler_kineto.cpp | 19 +++- torch/csrc/autograd/profiler_kineto.h | 16 +++ torch/csrc/autograd/profiler_legacy.h | 9 +- torch/csrc/jit/runtime/interpreter.cpp | 107 +++++++++++++++++++++ torch/csrc/jit/runtime/interpreter.h | 1 + .../rpc/server_process_global_profiler.py | 1 + torch/profiler/profiler.py | 9 ++ 13 files changed, 245 insertions(+), 14 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 2eed50c..c8ef529 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -177,7 +177,7 @@ class AllreduceNCCLTest : public NCCLTest { // Make sure enabling profile does not make any issue. Note, in single // process multi-device mode we do not expect any events be populated for // collective operations, since profiling for that mode is not supported. - enableProfilerLegacy({ProfilerState::CPU}); + enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU)); auto results = pg_->allreduce(tensors_); disableProfilerLegacy(); return results; diff --git a/test/test_profiler.py b/test/test_profiler.py index 8770639..28d9671 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -319,6 +319,70 @@ class TestProfiler(TestCase): ] ) + @unittest.skipIf(not kineto_available(), "Kineto is required") + def test_module_hierarchy(self): + class A(nn.Module): + def __init__(self): + super(A, self).__init__() + + def my_new_method(self, x): + return x * 3 + + def forward_impl_(self, x, y): + return self.my_new_method(x) + y + + def forward(self, x, y): + y = y - 2 + return self.forward_impl_(x, y) + + class B(nn.Module): + def __init__(self): + super(B, self).__init__() + + def forward(self, x): + return x + 2 + + class C(nn.Module): + def __init__(self): + super(C, self).__init__() + self.A0 = A() + self.B0 = B() + + def call_b(self, x): + return self.B0.forward(x) + + def forward(self, x, y): + return self.A0.forward(x, y) + self.call_b(x) + + model = C() + model = torch.jit.script(model) + input_a = torch.rand(128, 128) + input_b = torch.rand(128, 128) + op_to_module_hierarchy = {} + op_to_module_hierarchy["aten::sub"] = ["TOP(C)::forward.A0(A)::forward."] + op_to_module_hierarchy["aten::mul"] = [ + "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method."] + op_to_module_hierarchy["aten::add"] = [ + "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.", + "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.", "TOP(C)::forward."] + with TemporaryFileName(mode="w+") as fname: + with profile(activities=[torch.profiler.ProfilerActivity.CPU], with_modules=True,) as prof: + model(input_a, input_b) + prof.export_chrome_trace(fname) + with io.open(fname, 'r') as f: + trace = json.load(f) + assert "traceEvents" in trace + events = trace["traceEvents"] + found_memory_events = False + for evt in events: + assert "name" in evt + if "args" in evt: + op_name = evt["name"] + if "Module Hierarchy" in evt["args"]: + hierarchy = evt["args"]["Module Hierarchy"] + if op_name in op_to_module_hierarchy: + assert hierarchy in op_to_module_hierarchy[op_name] + def test_high_level_trace(self): """Checks that python side high level events are recorded. """ diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index cd9b0da..6468eb5 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -40,7 +40,8 @@ class ProfilerConfig: report_input_shapes: bool, profile_memory: bool, with_stack: bool, - with_flops: bool + with_flops: bool, + with_modules: bool ) -> None: ... ... diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index ae5e090..ab95fdb 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -71,6 +71,13 @@ class profile(object): with_stack (bool, optional): record source information (file and line number) for the ops. + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + use_kineto (bool, optional): experimental, enable profiling with Kineto profiler. use_cpu (bool, optional): profile CPU events; setting to ``False`` requires @@ -118,6 +125,7 @@ class profile(object): with_flops=False, profile_memory=False, with_stack=False, + with_modules=False, use_kineto=False, use_cpu=True): self.enabled: bool = enabled @@ -131,6 +139,7 @@ class profile(object): self.record_shapes |= self.with_flops self.profile_memory = profile_memory self.with_stack = with_stack + self.with_modules = with_modules self.use_cpu = use_cpu self.kineto_results: Optional[_ProfilerResult] = None @@ -165,7 +174,8 @@ class profile(object): self.record_shapes, self.profile_memory, self.with_stack, - self.with_flops) + self.with_flops, + self.with_modules) def __enter__(self): if not self.enabled: @@ -557,6 +567,7 @@ class emit_nvtx(object): self.record_shapes, False, False, + False, False), set() ) diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index 623e13a..445decf 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -24,7 +24,8 @@ class profile(object): record_shapes=False, with_flops=False, profile_memory=False, - with_stack=False): + with_stack=False, + with_modules=False): self.enabled: bool = enabled if not self.enabled: return @@ -36,6 +37,7 @@ class profile(object): self.record_shapes |= self.with_flops self.profile_memory = profile_memory self.with_stack = with_stack + self.with_modules = with_modules if self.use_cuda and not torch.cuda.is_available(): warn("CUDA is not available, disabling CUDA profiling") @@ -52,7 +54,8 @@ class profile(object): self.record_shapes, self.profile_memory, self.with_stack, - self.with_flops) + self.with_flops, + self.with_modules) def __enter__(self): if not self.enabled: diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index ffe7e83..dc51241 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -73,7 +73,13 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .value("CUDA", ActivityType::CUDA); py::class_(m, "ProfilerConfig") - .def(py::init()); + .def(py::init()); py::class_(m, "ProfilerEvent") .def("kind", &LegacyEvent::kindStr) diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index b9af532..3b5b511 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -43,7 +43,7 @@ inline int64_t getTimeUs() { } std::string shapesToStr(const std::vector>& shapes); -std::string stacksToStr(const std::vector& stacks); +std::string stacksToStr(const std::vector& stacks, const char* delim); std::string dtypesToStr(const std::vector& types); std::vector inputTypes(const at::RecordFunction& fn); @@ -110,6 +110,9 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState { if (ctx->stack && !ctx->stack->empty()) { kineto_events_.back().stack(*ctx->stack); } + if (ctx->module_hierarchy) { + kineto_events_.back().moduleHierarchy(*ctx->module_hierarchy); + } if (ctx->extraArgs && !ctx->extraArgs->empty()) { kineto_events_.back().flops(computeFlops(std::string(fn.name().str()), *ctx->extraArgs)); } @@ -228,7 +231,10 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState { activity.addMetadata("Input Dims", shapesToStr(kineto_event.shapes())); } if (kineto_event.hasStack()) { - activity.addMetadata("Call stack", stacksToStr(kineto_event.stack())); + activity.addMetadata("Call stack", stacksToStr(kineto_event.stack(), ";")); + } + if (kineto_event.hasModuleHierarchy()) { + activity.addMetadata("Module Hierarchy", stacksToStr(kineto_event.moduleHierarchy(), ".")); } if (kineto_event.hasTypes()) { activity.addMetadata("Input type", dtypesToStr(kineto_event.dtypes())); @@ -326,6 +332,10 @@ void pushProfilingCallbacks() { } ctx_ptr->stack = callstackStr(cs); } + if (config.with_modules && + fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { + ctx_ptr->module_hierarchy = jit::currentModuleHierarchy(); + } #endif if (config.state == ProfilerState::KINETO_GPU_FALLBACK) { try { @@ -416,12 +426,12 @@ std::string dtypesToStr(const std::vector& types) { } } -std::string stacksToStr(const std::vector& stacks) { +std::string stacksToStr(const std::vector& stacks, const char* delim) { std::ostringstream oss; std::transform( stacks.begin(), stacks.end(), - std::ostream_iterator(oss, ";"), + std::ostream_iterator(oss, delim), [](std::string s) -> std::string { #ifdef _WIN32 // replace the windows backslash with forward slash @@ -430,7 +440,6 @@ std::string stacksToStr(const std::vector& stacks) { return s; }); auto rc = oss.str(); - rc.pop_back(); return "\"" + rc + "\""; } diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index a948ad3..8c14ae4 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -1,6 +1,7 @@ #pragma once #include +#include #ifdef USE_KINETO // skip Kineto dependency on mobile @@ -38,6 +39,7 @@ struct KinetoObserverContext : public at::ObserverContext { uint64_t fwdThreadId; uint8_t recFunScope; c10::optional> stack; + c10::optional> module_hierarchy; // Extra arguments for computing op flops c10::optional> extraArgs; CUDAEventStub cuda_event_start_ = nullptr; @@ -147,6 +149,19 @@ struct TORCH_API KinetoEvent { return *this; } + bool hasModuleHierarchy() const { + return module_hierarchy_ != c10::nullopt; + } + + const std::vector& moduleHierarchy() const { + return *module_hierarchy_; + } + + KinetoEvent& moduleHierarchy(const std::vector& module_hierarchy) { + module_hierarchy_ = module_hierarchy; + return *this; + } + std::string name() const { return name_; } @@ -248,6 +263,7 @@ struct TORCH_API KinetoEvent { uint8_t activity_type_ = 0; c10::optional>> shapes_; c10::optional> stack_; + c10::optional> module_hierarchy_; c10::optional> dtypes_; uint64_t flops_ = 0; diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 4c751e8..363a42d 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -408,23 +408,26 @@ enum class C10_API_ENUM ProfilerState { }; struct TORCH_API ProfilerConfig { - ProfilerConfig( + explicit ProfilerConfig( ProfilerState state, bool report_input_shapes = false, bool profile_memory = false, bool with_stack = false, - bool with_flops = false) + bool with_flops = false, + bool with_modules = false) : state(state), report_input_shapes(report_input_shapes), profile_memory(profile_memory), with_stack(with_stack), - with_flops(with_flops) {} + with_flops(with_flops), + with_modules(with_modules) {} ~ProfilerConfig() = default; ProfilerState state; bool report_input_shapes; bool profile_memory; bool with_stack; bool with_flops; + bool with_modules; // Returns IValues corresponding to ProfilerConfig struct, to be used for // serialization. diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 79e6267..a095e4a 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #ifdef USE_RPC #include @@ -784,6 +785,105 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } public: + // One way to avoid overhead of forming string would be to return + // a vector of frame.function, i.e. CodeImpl* + // This is not exactly clean as it will expose, internal details of + // interpreter. But this way we hold onto graph/node and Function and + // we can create module hierarchy string for each event in autograd + // profiler at the end, when consolidating events. + // At the moment overhead does not seem exhorbitantly large. + // Another option would be return vector of (string, InlinedCallstackPtrs) + // string would contain function name and typename of self + // Format of the returned vector of strings: + // For each frame, the corresponding module name, type and function name + // are in following format: + // (module type):: + // Special keys for module-instance-name: + // - TOP: for top level module + // - SELF: When method/function of the frame is associated with + // previous frame's module instance + // - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out + // - CALL_FUNCTION: call to free function + std::vector moduleHierarchy() const { + std::vector module_function_list; + std::string module_hierarchy("TOP"); + for (size_t i = 0; i < frames.size(); ++i) { + const Frame& frame = frames[i]; + std::string fn_name = frame.function->function_name_; + // For each frame, type of the class with which the function is + // associated, is queried here. And the type name is added to + // module hierarchy. + const auto& g = frame.function->graph_; + std::string g_self_type; + if (g && g->inputs().size() > 0) { + const auto& g_self_type_ptr = + g->inputs()[0]->type()->cast(); + if (g_self_type_ptr) { + g_self_type = g_self_type_ptr->name()->qualifiedName(); + g_self_type = g_self_type.substr(g_self_type.find_last_of('.') + 1); + } + } + module_hierarchy.append("(") + .append(g_self_type) + .append(")::") + .append(fn_name); + module_function_list.emplace_back(std::move(module_hierarchy)); + + size_t pc = frame.pc; + // CALL nodes have already advanced the pc, so + // undo that to report the call node + if (i + 1 < frames.size()) { + --pc; + } + + Node* node = frame.function->instructions_source_[pc]; + if (node->callstack()) { + for (const auto& p : (*node->callstack())->vec()) { + fn_name = std::get<0>(p)->name(); + const auto& opt_module_info = std::get<2>(p); + if (opt_module_info.has_value()) { + const auto& module_instance_info = opt_module_info.value(); + module_hierarchy = utils::get_module_info(module_instance_info); + module_hierarchy.append("::").append(fn_name); + } else { + // This is likely a call to free function, not associated with + // any class + module_hierarchy = "::"; + module_hierarchy.append(fn_name); + } + module_function_list.emplace_back(std::move(module_hierarchy)); + } + } + + module_hierarchy = std::string(); + // If this node is of type callMethod then the following frame + // will contain the op being executed. + // For such callMethod node, we add the object instance name + // associated with it, since the following frame will not have it. + if (node->kind() == prim::CallMethod) { + std::string class_instance_name; + if (node->input(0)->node()->kind() == prim::GetAttr) { + class_instance_name = node->input(0)->node()->s(attr::name); + } else if ( + node->owningGraph()->inputs().size() > 0 && + node->input(0) == node->owningGraph()->inputs()[0]) { + class_instance_name = "SELF"; + } else { + class_instance_name = "INSTANCE_NAME_UNKNOWN"; + } + module_hierarchy = std::move(class_instance_name); + } else if (node->kind() == prim::CallFunction) { + auto function_constant = node->input(0)->node(); + auto fun_type = + function_constant->output()->type()->expect(); + auto fun_name = fun_type->function()->name(); + module_hierarchy = "CALL_FUNCTION::"; + module_hierarchy.append(fun_name); + } + } + return module_function_list; + } + std::vector callstack() const { std::vector entries; for (const auto i : c10::irange(frames.size())) { @@ -848,6 +948,13 @@ std::vector currentCallstack() { return std::vector(); } +std::vector currentModuleHierarchy() { + if (tls_int_state_ptr_) { + return tls_int_state_ptr_->moduleHierarchy(); + } + return std::vector(); +} + std::ostream& operator<<(std::ostream& out, const Code& code) { out << *code.pImpl->graph_ << "\n"; code.pImpl->dump(out); diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 4728177..80720ea 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -154,6 +154,7 @@ TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext( // current (TLS) TorchScript interpreter callstack TORCH_API std::vector currentCallstack(); +TORCH_API std::vector currentModuleHierarchy(); } // namespace jit } // namespace torch diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 02d633e..4634313 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -117,6 +117,7 @@ class _server_process_global_profile(profile): self.record_shapes, self.profile_memory, False, + False, False) _enable_server_process_global_profiler(profiler_config) return self diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f77405d..20bdaa3 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -116,6 +116,12 @@ class profile(object): with_stack (bool): record source information (file and line number) for the ops. with_flops (bool): use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. use_cuda (bool): .. deprecated:: 1.8.1 use ``activities`` instead. @@ -210,6 +216,7 @@ class profile(object): profile_memory: bool = False, with_stack: bool = False, with_flops: bool = False, + with_modules: bool = False, # deprecated: use_cuda: Optional[bool] = None): if activities: @@ -238,6 +245,7 @@ class profile(object): self.with_flops = with_flops self.profile_memory = profile_memory self.with_stack = with_stack + self.with_modules = with_modules self.step_num = 0 self.current_action = self.schedule(self.step_num) self.profiler: Optional[prof.profile] = None @@ -426,6 +434,7 @@ class profile(object): with_flops=self.with_flops, profile_memory=self.profile_memory, with_stack=self.with_stack, + with_modules=self.with_modules, use_kineto=True, ) self.profiler._prepare_trace() -- 2.7.4