ASSERT_EQ(my_events, 2);
}
+TEST(RecordDebugHandles, ScopedCallbacks) {
+ // Enable the profiler in this thread
+ torch::autograd::profiler::prepareProfiler(
+ torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::KINETO, false, false),
+ {torch::autograd::profiler::ActivityType::CPU});
+ torch::autograd::profiler::enableProfiler(
+ torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::KINETO, false, false),
+ {torch::autograd::profiler::ActivityType::CPU});
+
+ {
+ auto a = torch::rand({128, 128});
+ auto b = torch::rand({128, 128});
+ auto c = a + b;
+ }
+ auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
+ ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
+
+ // Enable the profiler in this thread
+ torch::autograd::profiler::prepareProfiler(
+ torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::KINETO, false, false),
+ {torch::autograd::profiler::ActivityType::CPU});
+ torch::autograd::profiler::enableProfiler(
+ torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::KINETO, false, false),
+ {torch::autograd::profiler::ActivityType::CPU},
+ {at::RecordScope::USER_SCOPE});
+ {
+ auto a = torch::rand({128, 128});
+ auto b = torch::rand({128, 128});
+ auto c = a + b;
+ }
+ profiler_results_ptr = torch::autograd::profiler::disableProfiler();
+ ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
+
+ torch::autograd::profiler::prepareProfiler(
+ torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::KINETO, false, false),
+ {torch::autograd::profiler::ActivityType::CPU});
+ torch::autograd::profiler::enableProfiler(
+ torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::KINETO, false, false),
+ {torch::autograd::profiler::ActivityType::CPU},
+ {at::RecordScope::USER_SCOPE});
+ {
+ RECORD_USER_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
+ auto a = torch::rand({128, 128});
+ auto b = torch::rand({128, 128});
+ auto c = a + b;
+ }
+ {
+ RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
+ auto a = torch::rand({128, 128});
+ auto b = torch::rand({128, 128});
+ auto c = a + b;
+ }
+ profiler_results_ptr = torch::autograd::profiler::disableProfiler();
+ const auto& kineto_events = profiler_results_ptr->events();
+ for (const auto& e : kineto_events) {
+ if (e.name() == "my_function") {
+ ASSERT_EQ(e.debugHandle(), 42);
+ } else if (e.name() == "not_my_function") {
+ ASSERT_EQ(e.debugHandle(), -1);
+ }
+ }
+ ASSERT_TRUE(profiler_results_ptr->events().size() == 2);
+}
+
TEST(IValueKWargsTest, Basic) {
const auto text = R"(
def foo(a : int, b : int, c : int = 4):
return static_cast<KinetoThreadLocalState*>(state);
}
-void pushProfilingCallbacks() {
+void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
}
})
.needsInputs(state_ptr->config().report_input_shapes)
- .needsIds(true));
+ .needsIds(true)
+ .scopes(scopes));
state_ptr->setCallbackHandle(handle);
}
void enableProfiler(
const ProfilerConfig& config,
- const std::set<ActivityType>& activities) {
+ const std::set<ActivityType>& activities,
+ const std::unordered_set<at::RecordScope>& scopes) {
if (config.state != ProfilerState::NVTX) {
TORCH_CHECK(
config.state == ProfilerState::KINETO ||
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
if (activities.count(ActivityType::CPU) || config.state == ProfilerState::NVTX) {
- pushProfilingCallbacks();
+ pushProfilingCallbacks(scopes);
}
#ifdef USE_KINETO