From: James Reed
Date: Sat, 30 Mar 2019 00:06:08 +0000 (-0700)
Subject: Experimental logging/counters API (#18235)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~542
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=85f36014e2628fe291e94be8e5d156b4e6015afd;p=platform%2Fupstream%2Fpytorch.git

Experimental logging/counters API (#18235)

Summary:
This defines a generic counters API that users can utilize to provide
monitoring functionality in e.g. a production service. We expose both
counters for runtime internals and a TorchScript API to create
user-defined counters. Synopsis of the API:

- `torch/csrc/jit/script/logging.h` specifies the externally-facing API in C++
- `torch/jit/_logging.py` specifies the Python API

We use an interface, `LoggerBase`, to define the interactions between users
and a logging backend. Implementing a subclass of `LoggerBase` allows the
user to handle these events in a custom way, such as logging into a DB or
calling into an infra-specific counters API.

From the frontend perspective, we can create log events in two ways:
1. We provide an `add_stat_value(name, val)` function. This calls into the
   logger backend with a key/value pair. For example, we might call
   `add_stat_value('foo', 1)` to bump an event counter.
2. We provide a `time_point()` function to record a timestamp in
   nanoseconds. This can be used in conjunction with `add_stat_value` to
   record runtime wall-clock durations.

Examples of frontend usage can be found in `test_jit.py TestLogging`.

We provide a trivial `LockingLogger` implementation as an example and for
testing purposes. It is likely not ready for production usage. It
demonstrates that a backend implementing the API can do things like specify
aggregation types and report these aggregate stats via the `get_counters()`
API.
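As an illustration of the backend interface described above (not part of this patch), a counters backend could subclass `LoggerBase` and install itself with `setLogger()`. The sketch below assumes only the declarations from `torch/csrc/jit/script/logging.h`; `InfraLogger` and `my_infra::bump_counter` are made-up names standing in for an infra-specific counters API.

```cpp
// Hypothetical sketch: forward every stat event to an external counters
// service. Only LoggerBase and setLogger() come from logging.h.
#include <torch/csrc/jit/script/logging.h>

#include <cstdint>
#include <iostream>
#include <string>

namespace my_infra {
// Stand-in for an infra-specific counters API.
inline void bump_counter(const std::string& key, int64_t delta) {
  std::cerr << "counter " << key << " += " << delta << "\n";
}
} // namespace my_infra

class InfraLogger : public torch::jit::logging::LoggerBase {
 public:
  void addStatValue(const std::string& stat_name, int64_t val) override {
    // Every add_stat_value() call from TorchScript, and every runtime
    // counter bump, ends up here via the installed logger.
    my_infra::bump_counter(stat_name, val);
  }
};

torch::jit::logging::LoggerBase* installInfraLogger() {
  // setLogger() swaps in the new backend and returns the previous one,
  // which callers can keep if they want to restore it later.
  static InfraLogger logger;
  return torch::jit::logging::setLogger(&logger);
}
```

With such a backend installed, `torch.jit._logging.add_stat_value('foo', 1)` in scripted or traced code is routed through `prim::AddStatValue` to `InfraLogger::addStatValue`.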
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18235

Differential Revision: D14545060

Pulled By: jamesr66a

fbshipit-source-id: 04099543a1898cfdd411511e46e03d5dce9b4881
---

diff --git a/.gitignore b/.gitignore
index 66eb2df..b67a756 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,7 +203,6 @@ docs/dev
 *.sst
 *.ldb
 LOCK
-LOG*
 CURRENT
 MANIFEST-*
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 4a6bf37..25783d1 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -86,6 +86,8 @@ namespace c10 {
 _(prim, CreateObject) \
 _(prim, SetAttr) \
 _(prim, GetAttr) \
+ _(prim, AddStatValue) \
+ _(prim, TimePoint) \
 _(aten, append) \
 _(aten, item) \
 _(aten, format) \
diff --git a/test/test_jit.py b/test/test_jit.py
index 6229f0c..a66f847 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1,6 +1,7 @@
 from __future__ import division
 import torch
 import torch.jit
+import torch.jit._logging
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.parallel as dp
@@ -13874,6 +13875,106 @@ class TestClassType(JitTestCase):
         self.assertEqual(y, f2.y)
 
 
+class TestLogging(JitTestCase):
+    def test_bump_numeric_counter(self):
+        class ModuleThatLogs(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                for i in range(x.size(0)):
+                    x += 1.0
+                    torch.jit._logging.add_stat_value('foo', 1)
+
+                if bool(x.sum() > 0.0):
+                    torch.jit._logging.add_stat_value('positive', 1)
+                else:
+                    torch.jit._logging.add_stat_value('negative', 1)
+                return x
+
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtl = ModuleThatLogs()
+            for i in range(5):
+                mtl(torch.rand(3, 4, 5))
+
+            self.assertEqual(logger.get_counter_val('foo'), 15)
+            self.assertEqual(logger.get_counter_val('positive'), 5)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_trace_numeric_counter(self):
+        def foo(x):
+            torch.jit._logging.add_stat_value('foo', 1)
+            return x + 1.0
+
+        traced = torch.jit.trace(foo, torch.rand(3, 4))
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            traced(torch.rand(3, 4))
+
+            self.assertEqual(logger.get_counter_val('foo'), 1)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_time_measurement_counter(self):
+        class ModuleThatTimes(torch.jit.ScriptModule):
+            def forward(self, x):
+                tp_start = torch.jit._logging.time_point()
+                for i in range(30):
+                    x += 1.0
+                tp_end = torch.jit._logging.time_point()
+                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+                return x
+
+        mtm = ModuleThatTimes()
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtm(torch.rand(3, 4))
+            self.assertGreater(logger.get_counter_val('mytimer'), 0)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_time_measurement_counter_script(self):
+        class ModuleThatTimes(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                tp_start = torch.jit._logging.time_point()
+                for i in range(30):
+                    x += 1.0
+                tp_end = torch.jit._logging.time_point()
+                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+                return x
+
+        mtm = ModuleThatTimes()
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtm(torch.rand(3, 4))
+            self.assertGreater(logger.get_counter_val('mytimer'), 0)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_counter_aggregation(self):
+        def foo(x):
+            for i in range(3):
+                torch.jit._logging.add_stat_value('foo', 1)
+            return x + 1.0
+
+        traced = torch.jit.trace(foo, torch.rand(3, 4))
+        logger = torch.jit._logging.LockingLogger()
+        logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            traced(torch.rand(3, 4))
+
+            self.assertEqual(logger.get_counter_val('foo'), 1)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+
 for test in autograd_method_tests():
     add_autograd_test(*test)
diff --git a/tools/build_variables.py b/tools/build_variables.py
index a293a56..5718599 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -95,6 +95,7 @@ libtorch_sources = [
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
+    "torch/csrc/jit/script/logging.cpp",
     "torch/csrc/jit/script/final_returns.cpp",
     "torch/csrc/jit/script/schema_type_parser.cpp",
     "torch/csrc/jit/script/script_type_parser.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index deff903..9c905ff 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -185,6 +185,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 57768c0..1e993da 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -362,7 +363,10 @@ struct GraphExecutorImpl {
         optimize(optimize),
         num_inputs(this->graph->inputs().size()),
         num_flat_inputs(countFlatInputs(graph)),
-        num_outputs(this->graph->outputs().size()) {}
+        num_outputs(this->graph->outputs().size()) {
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+  }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -373,6 +377,9 @@
         " inputs, but got only ",
         stack.size());
 
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
+
     if (tracer::isTracing()) {
       return runTraced(stack);
     }
@@ -441,10 +448,15 @@
     {
       std::lock_guard lock(compile_mutex);
       auto it = plan_cache.find(spec);
-      if (it != plan_cache.end())
+      if (it != plan_cache.end()) {
+        logging::getLogger()->addStatValue(
+            logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
         return it->second;
+      }
       auto plan = compileSpec(spec);
       auto r = plan_cache.emplace(std::move(spec), std::move(plan));
+      logging::getLogger()->addStatValue(
+          logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
       return r.first->second;
     }
   }
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 87e69b5..eda0242 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -1,19 +1,20 @@
 #include
+#include
+#include
+#include
 #include
 #include
 #include
 #include
 #include
 #include
-#include
 #include
 #include
 #include
-#include
 #include
 #include
-#include
+#include
 #include
 #include
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 3a8e292..b0ef049 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -845,6 +845,8 @@ bool Node::hasSideEffects() const {
     case prim::RaiseException:
     case prim::SetAttr:
     case aten::warn:
+    case prim::AddStatValue:
+    case prim::TimePoint:
       return true;
   }
   return false;
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 7ae656a..9e22d24 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -887,7 +888,46 @@
           userObj->setSlot(slot, std::move(v));
           return 0;
         };
-     })});
+     })
+   });
+
+RegisterOperators logging_operators({
+    Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
+      auto val = pop(stack).toInt();
+      auto key = pop(stack).toString();
+
+      auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
+      // TODO: remove this custom tracing code once the custom op bugfix lands
+      if (jit::tracer::isTracing()) {
+        const auto& graph = tracer::getTracingState()->graph;
+        Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+        tracer::recordSourceLocation(node);
+        node->addInput(insertConstant(*graph, key));
+        tracer::addInputs(node, "val", val);
+        graph->insertNode(node);
+      }
+      torch::jit::logging::getLogger()->addStatValue(*key, val);
+      return 0;
+    }),
+    Operator("prim::TimePoint() -> int", [](Stack& stack) {
+      auto schema = parseSchema("prim::TimePoint() -> int");
+      Node* node = nullptr;
+      // TODO: remove this custom tracing code once the custom op bugfix lands
+      if (jit::tracer::isTracing()) {
+        const auto& graph = tracer::getTracingState()->graph;
+        // Assign to the outer `node` (rather than declaring a new local that
+        // shadows it) so addOutput() below sees the traced node.
+        node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+        tracer::recordSourceLocation(node);
+        graph->insertNode(node);
+      }
+      auto output = autograd::profiler::getTime();
+      push(stack, output);
+      if (jit::tracer::isTracing()) {
+        jit::tracer::addOutput(node, output);
+      }
+      return 0;
+    })
+});
+
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index f4d1c89..8175f35 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -16,7 +16,9 @@
 #include
 #include
 #include
+#include
 #include
+#include
 #include
@@ -27,6 +29,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -1101,6 +1104,29 @@ void initJitScriptBindings(PyObject* module) {
       .def("run", [](testing::FileCheck& f, const Graph& g) {
         return f.run(g);
       });
+
+  m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
+    return logging::setLogger(logger);
+  }, py::return_value_policy::reference);
+  py::class_>(
+      m, "LoggerBase");
+  py::enum_(m, "AggregationType")
+      .value("SUM", logging::LockingLogger::AggregationType::SUM)
+      .value("AVG", logging::LockingLogger::AggregationType::AVG)
+      .export_values();
+  py::class_<
+      logging::LockingLogger,
+      logging::LoggerBase,
+      std::shared_ptr>(m, "LockingLogger")
+      .def(py::init<>())
+      .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
+      .def("get_counter_val", &logging::LockingLogger::getCounterValue);
+  py::class_<
+      logging::NoopLogger,
+      logging::LoggerBase,
+      std::shared_ptr>(m, "NoopLogger")
+      .def(py::init<>());
+
 }
 } // namespace script
 } // namespace jit
diff --git a/torch/csrc/jit/script/logging.cpp b/torch/csrc/jit/script/logging.cpp
new file mode 100644
index 0000000..48407cc
--- /dev/null
+++ b/torch/csrc/jit/script/logging.cpp
@@ -0,0 +1,73 @@
+#include "torch/csrc/jit/script/logging.h"
"torch/csrc/jit/script/logging.h" + +#include +#include +#include + +namespace torch { +namespace jit { +namespace logging { + +// TODO: multi-scale histogram for this thing + +void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) { + std::unique_lock lk(m); + auto& raw_counter = raw_counters[stat_name]; + raw_counter.sum += val; + raw_counter.count++; +} + +TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const { + std::unique_lock lk(m); + if (!raw_counters.count(name)) { + return 0; + } + AggregationType type = agg_types.count(name) ? agg_types.at(name) + : AggregationType::SUM; + const auto &raw_counter = raw_counters.at(name); + switch (type) { + case AggregationType::SUM: { + return raw_counter.sum; + } break; + case AggregationType::AVG: { + return raw_counter.sum / raw_counter.count; + } break; + } + throw std::runtime_error("Unknown aggregation type!"); +} + +void LockingLogger::setAggregationType( + const std::string& stat_name, + AggregationType type) { + agg_types[stat_name] = type; +} + + +std::atomic global_logger{new NoopLogger()}; + +LoggerBase* getLogger() { + return global_logger.load(); +} + +LoggerBase *setLogger(LoggerBase* logger) { + LoggerBase *previous = global_logger.load(); + while (!global_logger.compare_exchange_strong(previous, logger)) { + previous = global_logger.load(); + } + return previous; +} + +JITTimePoint timePoint() { + return JITTimePoint{std::chrono::high_resolution_clock::now()}; +} + +void recordDurationSince(const std::string& name, JITTimePoint tp) { + auto end = std::chrono::high_resolution_clock::now(); + // Measurement in microseconds. + auto seconds = std::chrono::duration(end - tp.point).count() * 1e9; + logging::getLogger()->addStatValue(name, seconds); +} + +} // namespace logging +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/script/logging.h b/torch/csrc/jit/script/logging.h new file mode 100644 index 0000000..60d1bc3 --- /dev/null +++ b/torch/csrc/jit/script/logging.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace logging { + +class LoggerBase { + public: + TORCH_API virtual void addStatValue( + const std::string& stat_name, + int64_t val) = 0; + virtual ~LoggerBase() {} +}; + +TORCH_API LoggerBase* getLogger(); +TORCH_API LoggerBase* setLogger(LoggerBase* logger); + +// No-op logger. This is the default and is meant to incur almost no runtime +// overhead. + +class NoopLogger : public LoggerBase { + public: + void addStatValue(const std::string& stat_name, int64_t val) override {} + ~NoopLogger() {} +}; + +// Trivial locking logger. Pass in an instance of this to setLogger() to use it. +// This keeps track of the sum of all statistics. +// +// NOTE: this is not written in a scalable way and should probably only be used +// in the single-threaded case or for testing. 
+class LockingLogger : public LoggerBase {
+ public:
+  TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override;
+  TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
+  enum class AggregationType { SUM, AVG };
+  TORCH_API void setAggregationType(
+      const std::string& stat_name,
+      AggregationType type);
+  ~LockingLogger() {}
+
+ private:
+  mutable std::mutex m;
+  struct RawCounter {
+    RawCounter() : sum(0), count(0) {}
+    int64_t sum;
+    size_t count;
+  };
+  std::unordered_map raw_counters;
+  std::unordered_map agg_types;
+};
+
+// Make this struct so the timer internals are opaque to the user.
+struct JITTimePoint {
+  std::chrono::time_point point;
+};
+
+TORCH_API JITTimePoint timePoint();
+TORCH_API void recordDurationSince(const std::string& name, JITTimePoint tp);
+
+namespace runtime_counters {
+constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
+    "pytorch_runtime.graph_executors_constructed";
+constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
+    "pytorch_runtime.graph_executor_invocations";
+constexpr const char* EXECUTION_PLAN_CACHE_HIT =
+    "pytorch_runtime.execution_plan_cache_hit";
+constexpr const char* EXECUTION_PLAN_CACHE_MISS =
+    "pytorch_runtime.execution_plan_cache_miss";
+
+inline std::vector allRuntimeCounters() {
+  return {GRAPH_EXECUTORS_CONSTRUCTED,
+          GRAPH_EXECUTOR_INVOCATIONS,
+          EXECUTION_PLAN_CACHE_HIT,
+          EXECUTION_PLAN_CACHE_MISS};
+}
+
+} // namespace runtime_counters
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/jit/_logging.py b/torch/jit/_logging.py
new file mode 100644
index 0000000..497c342
--- /dev/null
+++ b/torch/jit/_logging.py
@@ -0,0 +1,10 @@
+import torch
+
+add_stat_value = torch.ops.prim.AddStatValue
+
+set_logger = torch._C._logging_set_logger
+LockingLogger = torch._C.LockingLogger
+AggregationType = torch._C.AggregationType
+NoopLogger = torch._C.NoopLogger
+
+time_point = torch.ops.prim.TimePoint
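The C++ timing helpers declared in `logging.h` (`timePoint()` / `recordDurationSince()`) have no call site in this patch. A minimal sketch of how runtime code could use them, assuming only the declarations above (the counter name is made up for illustration):

```cpp
// Hypothetical sketch: time a block of C++ work and report the elapsed
// wall-clock duration through whichever logger is currently installed.
#include <torch/csrc/jit/script/logging.h>

void runSomePass() {
  auto tp = torch::jit::logging::timePoint();
  // ... do the work being measured ...
  // recordDurationSince() computes now - tp (in nanoseconds, per
  // logging.cpp) and forwards it via addStatValue() under this key.
  torch::jit::logging::recordDurationSince("my_runtime.some_pass_time_ns", tp);
}
```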