Experimental logging/counters API (#18235)
authorJames Reed <jamesreed@fb.com>
Sat, 30 Mar 2019 00:06:08 +0000 (17:06 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 00:14:03 +0000 (17:14 -0700)
Summary:
This defines a generic counters API that users can utilize to provide monitoring functionality in e.g. a production service. We expose both counters for runtime internals as well as a TorchScript API to create user-defined counters. Synopsis of the API:

- `torch/csrc/jit/script/logging.h` specifies the externally-facing API in C++
- `torch/jit/_logging.py` specifies the Python API

We use an interface, `LoggerBase`, to define the interactions between users and a logging backend. Implementing a subclass of `LoggerBase` allows the user to handle these events in a custom way, such as logging into a DB or calling into an infra-specific counters API.

From the frontend perspective, we can create log events in two ways:
1. We provide an `add_stat_value(name, val)` function. This calls into the Logger backend with a key/value pair. For example, we might call `add_stat_value('foo', 1)` to bump an event counter.
2. We provide a `time_point()` function to record a timestamp in nanoseconds. This can be used in conjunction with `add_stat_value` to record runtime wall clock durations.

Examples of frontend usage can be found in `test_jit.py TestLogging`.

We provide a trivial `LockingLogger` implementation as an example and for testing purposes. It is likely not ready for production usage. It demonstrates that a backend implementing the API can do things like specify aggregation types and report these aggregate stats via the `get_counters()` API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18235

Differential Revision: D14545060

Pulled By: jamesr66a

fbshipit-source-id: 04099543a1898cfdd411511e46e03d5dce9b4881

13 files changed:
.gitignore
aten/src/ATen/core/interned_strings.h
test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/interpreter.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/logging.cpp [new file with mode: 0644]
torch/csrc/jit/script/logging.h [new file with mode: 0644]
torch/jit/_logging.py [new file with mode: 0644]

index 66eb2df..b67a756 100644 (file)
@@ -203,7 +203,6 @@ docs/dev
 *.sst
 *.ldb
 LOCK
-LOG*
 CURRENT
 MANIFEST-*
 
index 4a6bf37..25783d1 100644 (file)
@@ -86,6 +86,8 @@ namespace c10 {
   _(prim, CreateObject)            \
   _(prim, SetAttr)                 \
   _(prim, GetAttr)                 \
+  _(prim, AddStatValue)            \
+  _(prim, TimePoint)               \
   _(aten, append)                  \
   _(aten, item)                    \
   _(aten, format)                  \
index 6229f0c..a66f847 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import division
 import torch
 import torch.jit
+import torch.jit._logging
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.parallel as dp
@@ -13874,6 +13875,106 @@ class TestClassType(JitTestCase):
         self.assertEqual(y, f2.y)
 
 
+class TestLogging(JitTestCase):
+    def test_bump_numeric_counter(self):
+        class ModuleThatLogs(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                for i in range(x.size(0)):
+                    x += 1.0
+                    torch.jit._logging.add_stat_value('foo', 1)
+
+                if bool(x.sum() > 0.0):
+                    torch.jit._logging.add_stat_value('positive', 1)
+                else:
+                    torch.jit._logging.add_stat_value('negative', 1)
+                return x
+
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+
+            mtl = ModuleThatLogs()
+            for i in range(5):
+                mtl(torch.rand(3, 4, 5))
+
+            self.assertEqual(logger.get_counter_val('foo'), 15)
+            self.assertEqual(logger.get_counter_val('positive'), 5)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_trace_numeric_counter(self):
+        def foo(x):
+            torch.jit._logging.add_stat_value('foo', 1)
+            return x + 1.0
+
+        traced = torch.jit.trace(foo, torch.rand(3, 4))
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            traced(torch.rand(3, 4))
+
+            self.assertEqual(logger.get_counter_val('foo'), 1)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_time_measurement_counter(self):
+        class ModuleThatTimes(torch.jit.ScriptModule):
+            def forward(self, x):
+                tp_start = torch.jit._logging.time_point()
+                for i in range(30):
+                    x += 1.0
+                tp_end = torch.jit._logging.time_point()
+                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+                return x
+
+        mtm = ModuleThatTimes()
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtm(torch.rand(3, 4))
+            self.assertGreater(logger.get_counter_val('mytimer'), 0)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_time_measurement_counter_script(self):
+        class ModuleThatTimes(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                tp_start = torch.jit._logging.time_point()
+                for i in range(30):
+                    x += 1.0
+                tp_end = torch.jit._logging.time_point()
+                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+                return x
+
+        mtm = ModuleThatTimes()
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtm(torch.rand(3, 4))
+            self.assertGreater(logger.get_counter_val('mytimer'), 0)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_counter_aggregation(self):
+        def foo(x):
+            for i in range(3):
+                torch.jit._logging.add_stat_value('foo', 1)
+            return x + 1.0
+
+        traced = torch.jit.trace(foo, torch.rand(3, 4))
+        logger = torch.jit._logging.LockingLogger()
+        logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            traced(torch.rand(3, 4))
+
+            self.assertEqual(logger.get_counter_val('foo'), 1)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+
 for test in autograd_method_tests():
     add_autograd_test(*test)
 
index a293a56..5718599 100644 (file)
@@ -95,6 +95,7 @@ libtorch_sources = [
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
+    "torch/csrc/jit/script/logging.cpp",
     "torch/csrc/jit/script/final_returns.cpp",
     "torch/csrc/jit/script/schema_type_parser.cpp",
     "torch/csrc/jit/script/script_type_parser.cpp",
index deff903..9c905ff 100644 (file)
@@ -185,6 +185,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
index 57768c0..1e993da 100644 (file)
@@ -32,6 +32,7 @@
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/logging.h>
 
 #include <cstdint>
 #include <iterator>
@@ -362,7 +363,10 @@ struct GraphExecutorImpl {
         optimize(optimize),
         num_inputs(this->graph->inputs().size()),
         num_flat_inputs(countFlatInputs(graph)),
-        num_outputs(this->graph->outputs().size()) {}
+        num_outputs(this->graph->outputs().size()) {
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+  }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -373,6 +377,9 @@ struct GraphExecutorImpl {
         " inputs, but got only ",
         stack.size());
 
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
+
     if (tracer::isTracing()) {
       return runTraced(stack);
     }
@@ -441,10 +448,15 @@ struct GraphExecutorImpl {
     {
       std::lock_guard<std::mutex> lock(compile_mutex);
       auto it = plan_cache.find(spec);
-      if (it != plan_cache.end())
+      if (it != plan_cache.end()) {
+        logging::getLogger()->addStatValue(
+            logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
         return it->second;
+      }
       auto plan = compileSpec(spec);
       auto r = plan_cache.emplace(std::move(spec), std::move(plan));
+      logging::getLogger()->addStatValue(
+          logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
       return r.first->second;
     }
   }
index 87e69b5..eda0242 100644 (file)
@@ -1,19 +1,20 @@
 #include <torch/csrc/jit/interpreter.h>
 
+#include <ATen/core/ivalue.h>
+#include <c10/core/thread_pool.h>
+#include <c10/util/Exception.h>
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/autograd/variable.h>
-#include <c10/util/Exception.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
-#include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
-#include <c10/core/thread_pool.h>
+#include <torch/csrc/jit/script/logging.h>
 
 #include <exception>
 #include <iostream>
index 3a8e292..b0ef049 100644 (file)
@@ -845,6 +845,8 @@ bool Node::hasSideEffects() const {
     case prim::RaiseException:
     case prim::SetAttr:
     case aten::warn:
+    case prim::AddStatValue:
+     case prim::TimePoint:
       return true;
   }
   return false;
index 7ae656a..9e22d24 100644 (file)
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/logging.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
 
@@ -887,7 +888,46 @@ RegisterOperators reg(
          userObj->setSlot(slot, std::move(v));
          return 0;
        };
-     })});
+     })
+     });
+
+RegisterOperators logging_operators({
+    Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
+          auto val = pop(stack).toInt();
+          auto key = pop(stack).toString();
+
+          auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
+          // TODO: remove this custom tracing code once the custom op bugfix lands
+          if (jit::tracer::isTracing()) {
+            const auto& graph = tracer::getTracingState()->graph;
+            Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+            tracer::recordSourceLocation(node);
+            node->addInput(insertConstant(*graph, key));
+            tracer::addInputs(node, "val", val);
+            graph->insertNode(node);
+          }
+          torch::jit::logging::getLogger()->addStatValue(*key, val);
+          return 0;
+    }),
+    Operator("prim::TimePoint() -> int", [](Stack& stack) {
+        auto schema = parseSchema("prim::TimePoint() -> int");
+        Node* node = nullptr;
+        // TODO: remove this custom tracing code once the custom op bugfix lands
+        if (jit::tracer::isTracing()) {
+            const auto& graph = tracer::getTracingState()->graph;
+            Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+            tracer::recordSourceLocation(node);
+            graph->insertNode(node);
+        }
+        auto output = autograd::profiler::getTime();
+        push(stack, output);
+        if (jit::tracer::isTracing()) {
+          jit::tracer::addOutput(node, output);
+        }
+        return 0;
+    })
+});
+
 
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
index f4d1c89..8175f35 100644 (file)
@@ -16,7 +16,9 @@
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/logging.h>
 #include <torch/csrc/jit/script/parser.h>
+#include <torch/csrc/jit/tracer.h>
 
 #include <torch/csrc/api/include/torch/ordered_dict.h>
 
@@ -27,6 +29,7 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <pybind11/stl_bind.h>
+#include <chrono>
 #include <cstddef>
 #include <memory>
 #include <sstream>
@@ -1101,6 +1104,29 @@ void initJitScriptBindings(PyObject* module) {
       .def("run", [](testing::FileCheck& f, const Graph& g) {
         return f.run(g);
       });
+
+  m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
+    return logging::setLogger(logger);
+  }, py::return_value_policy::reference);
+  py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
+      m, "LoggerBase");
+  py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
+      .value("SUM", logging::LockingLogger::AggregationType::SUM)
+      .value("AVG", logging::LockingLogger::AggregationType::AVG)
+      .export_values();
+  py::class_<
+      logging::LockingLogger,
+      logging::LoggerBase,
+      std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
+      .def(py::init<>())
+      .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
+      .def("get_counter_val", &logging::LockingLogger::getCounterValue);
+  py::class_<
+      logging::NoopLogger,
+      logging::LoggerBase,
+      std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
+      .def(py::init<>());
+
 }
 } // namespace script
 } // namespace jit
diff --git a/torch/csrc/jit/script/logging.cpp b/torch/csrc/jit/script/logging.cpp
new file mode 100644 (file)
index 0000000..48407cc
--- /dev/null
@@ -0,0 +1,73 @@
+#include "torch/csrc/jit/script/logging.h"
+
+#include <atomic>
+#include <mutex>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+namespace logging {
+
+// TODO: multi-scale histogram for this thing
+
+void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
+  std::unique_lock<std::mutex> lk(m);
+  auto& raw_counter = raw_counters[stat_name];
+  raw_counter.sum += val;
+  raw_counter.count++;
+}
+
+TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
+  std::unique_lock<std::mutex> lk(m);
+  if (!raw_counters.count(name)) {
+    return 0;
+  }
+  AggregationType type = agg_types.count(name) ? agg_types.at(name)
+                                               : AggregationType::SUM;
+  const auto &raw_counter = raw_counters.at(name);
+  switch (type) {
+    case AggregationType::SUM: {
+      return raw_counter.sum;
+    } break;
+    case AggregationType::AVG: {
+      return raw_counter.sum / raw_counter.count;
+    } break;
+  }
+  throw std::runtime_error("Unknown aggregation type!");
+}
+
+void LockingLogger::setAggregationType(
+    const std::string& stat_name,
+    AggregationType type) {
+  agg_types[stat_name] = type;
+}
+
+
+std::atomic<LoggerBase*> global_logger{new NoopLogger()};
+
+LoggerBase* getLogger() {
+  return global_logger.load();
+}
+
+LoggerBase *setLogger(LoggerBase* logger) {
+  LoggerBase *previous = global_logger.load();
+  while (!global_logger.compare_exchange_strong(previous, logger)) {
+    previous = global_logger.load();
+  }
+  return previous;
+}
+
+JITTimePoint timePoint() {
+  return JITTimePoint{std::chrono::high_resolution_clock::now()};
+}
+
+void recordDurationSince(const std::string& name, JITTimePoint tp) {
+  auto end = std::chrono::high_resolution_clock::now();
+  // Measurement in microseconds.
+  auto seconds = std::chrono::duration<double>(end - tp.point).count() * 1e9;
+  logging::getLogger()->addStatValue(name, seconds);
+}
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/logging.h b/torch/csrc/jit/script/logging.h
new file mode 100644 (file)
index 0000000..60d1bc3
--- /dev/null
@@ -0,0 +1,90 @@
+#pragma once
+
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+namespace torch {
+namespace jit {
+namespace logging {
+
+class LoggerBase {
+ public:
+  TORCH_API virtual void addStatValue(
+      const std::string& stat_name,
+      int64_t val) = 0;
+  virtual ~LoggerBase() {}
+};
+
+TORCH_API LoggerBase* getLogger();
+TORCH_API LoggerBase* setLogger(LoggerBase* logger);
+
+// No-op logger. This is the default and is meant to incur almost no runtime
+// overhead.
+
+class NoopLogger : public LoggerBase {
+ public:
+  void addStatValue(const std::string& stat_name, int64_t val) override {}
+  ~NoopLogger() {}
+};
+
+// Trivial locking logger. Pass in an instance of this to setLogger() to use it.
+// This keeps track of the sum of all statistics.
+//
+// NOTE: this is not written in a scalable way and should probably only be used
+// in the single-threaded case or for testing.
+class LockingLogger : public LoggerBase {
+ public:
+  TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override;
+  TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
+  enum class AggregationType { SUM, AVG };
+  TORCH_API void setAggregationType(
+      const std::string& stat_name,
+      AggregationType type);
+  ~LockingLogger() {}
+
+ private:
+  mutable std::mutex m;
+  struct RawCounter {
+    RawCounter() : sum(0), count(0) {}
+    int64_t sum;
+    size_t count;
+  };
+  std::unordered_map<std::string, RawCounter> raw_counters;
+  std::unordered_map<std::string, AggregationType> agg_types;
+};
+
+// Make this struct so the timer internals are opaque to the user.
+struct JITTimePoint {
+  std::chrono::time_point<std::chrono::high_resolution_clock> point;
+};
+
+TORCH_API JITTimePoint timePoint();
+TORCH_API void recordDurationSince(const std::string& name, JITTimePoint tp);
+
+namespace runtime_counters {
+constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
+    "pytorch_runtime.graph_executors_constructed";
+constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
+    "pytorch_runtime.graph_executor_invocations";
+constexpr const char* EXECUTION_PLAN_CACHE_HIT =
+    "pytorch_runtime.execution_plan_cache_hit";
+constexpr const char* EXECUTION_PLAN_CACHE_MISS =
+    "pytorch_runtime.execution_plan_cache_miss";
+
+inline std::vector<const char*> allRuntimeCounters() {
+  return {GRAPH_EXECUTORS_CONSTRUCTED,
+          GRAPH_EXECUTOR_INVOCATIONS,
+          EXECUTION_PLAN_CACHE_HIT,
+          EXECUTION_PLAN_CACHE_MISS};
+}
+
+} // namespace runtime_counters
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/jit/_logging.py b/torch/jit/_logging.py
new file mode 100644 (file)
index 0000000..497c342
--- /dev/null
@@ -0,0 +1,10 @@
+import torch
+
+add_stat_value = torch.ops.prim.AddStatValue
+
+set_logger = torch._C._logging_set_logger
+LockingLogger = torch._C.LockingLogger
+AggregationType = torch._C.AggregationType
+NoopLogger = torch._C.NoopLogger
+
+time_point = torch.ops.prim.TimePoint