Initial implementation of InsertObserverNodes pass. (#18152)
author: Mikhail Zolotukhin <mvz@fb.com>
Fri, 29 Mar 2019 22:01:36 +0000 (15:01 -0700)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 22:08:57 +0000 (15:08 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18152
ghimport-source-id: 1dd5e62c4d93394dcd8d8af2871554575c8d3d1a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18152 Initial implementation of InsertObserverNodes pass.**
* #18151 Add quant-passes stubs.

gh-metadata: pytorch pytorch 18150 gh/zolotukhinm@gmail.com/2/head

Differential Revision: D14584223

fbshipit-source-id: 30896acc1a8901d22c6a167eb87d2fbaafbbeb6f

test/test_jit.py
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/quantization.cpp
torch/csrc/jit/passes/quantization.h

index e51ed7b..6229f0c 100644 (file)
@@ -996,8 +996,47 @@ class TestJit(JitTestCase):
     def test_expand_propagate_qinfo(self):
         pass
 
-    def test_expand_insert_observers(self):
-        pass
+    def test_insert_observers(self):
+        x1 = torch.tensor([0.4, 0.3])
+        y1 = torch.tensor([0.7, 0.5])
+        x2 = torch.tensor([0.1, 0.9])
+        y2 = torch.tensor([1.1, 1.9])
+
+        # Function that we will use as a graph
+        def fn(x, y):
+            p = x + y
+            z = x - y
+            return p * z
+
+        # Custom observer function
+        value_stats = {}
+
+        def observe(x, name):
+            if name not in value_stats:
+                value_stats[name] = []
+            value_stats[name].append(x)
+            return x
+
+        m = torch.jit.script(fn)
+        # Insert observers
+        torch._C._jit_pass_insert_observers(m.graph, observe)
+
+        # Collect statistics
+        m.forward(x1, y1)
+
+        # Check what we collected
+        self.assertTrue('p' in value_stats and 'z' in value_stats)
+        self.assertEqual(len(value_stats['p']), 1)
+        self.assertEqual(len(value_stats['z']), 1)
+        self.assertEqual(value_stats['p'][0], x1 + y1)
+        self.assertEqual(value_stats['z'][0], x1 - y1)
+
+        # Run one more time and check the updated statistics
+        m.forward(x2, y2)
+        self.assertEqual(len(value_stats['p']), 2)
+        self.assertEqual(len(value_stats['z']), 2)
+        self.assertEqual(value_stats['p'][1], x2 + y2)
+        self.assertEqual(value_stats['z'][1], x2 - y2)
 
     def test_expand_insert_fakequant(self):
         pass
index bd66058..489c438 100644 (file)
@@ -120,7 +120,15 @@ void initJITBindings(PyObject* module) {
           [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
       .def(
           "_jit_pass_insert_observers",
-          [](std::shared_ptr<Graph>& g) { return InsertObserverNodes(g); })
+          [](std::shared_ptr<Graph>& g, py::function pyObserverFunction) {
+            // Create a new node that would be used in the insert observer pass:
+            // all observer nodes will be cloned from this one.
+            Node* new_node = g->createPythonOp(
+                THPObjectPtr(pyObserverFunction.release().ptr()), "dd", {});
+            InsertObserverNodes(g, new_node);
+            // We don't need this node anymore, don't forget to remove it.
+            new_node->destroy();
+          })
       .def(
           "_jit_pass_insert_fakequant",
           [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
index 1d46d3e..c0006a3 100644 (file)
@@ -4,11 +4,10 @@
 #include <torch/csrc/jit/node_hashing.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
-#include <unordered_map>
+#include <stack>
 
 namespace torch {
 namespace jit {
-namespace {} // namespace
 
 void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph) {
   throw std::runtime_error("Pass not implemented yet!");
@@ -18,8 +17,74 @@ void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
   throw std::runtime_error("Pass not implemented yet!");
 }
 
-void InsertObserverNodes(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
+static void addObserverFor(Value* v, Node* original_observer_node) {
+  Node* def = v->node();
+  WithInsertPoint ins(def);
+
+  // We need to pass the value name to observer function - create a constant
+  // holding this name.
+  Value* vname = def->owningGraph()->insertConstant(v->uniqueName());
+
+  // Create a new observer node. We just need to clone the original one.
+  Node* observerNode =
+      def->owningGraph()
+          ->createClone(
+              &*original_observer_node, [&](Value* v) { return v; }, false)
+          ->insertAfter(def);
+
+  // Set the type and the name of the output of the new observer node. It will
+  // be used instead of the original value v.
+  Value* observedValue = observerNode->addOutput();
+  observedValue->setType(v->type());
+  observedValue->setUniqueName(v->uniqueName() + ".observed");
+
+  // Replace the uses of v with observedValue. We need to do it *before* we add
+  // the inputs - otherwise we would replace the newly added inputs as well.
+  v->replaceAllUsesWith(observedValue);
+
+  // Now we can add the inputs.
+  observerNode->addInput(v);
+  observerNode->addInput(vname);
+}
+
+static bool outputsNeedToBeObserved(Node* n) {
+  return n->kind().toQualString() != std::string("prim::Constant");
+}
+
+void InsertObserverNodes(std::shared_ptr<Graph>& graph, Node* observer_node) {
+  // For storing all values that need to be instrumented with an observer call.
+  std::vector<Value*> values_to_observe;
+
+  // For traversing all blocks in the graph including subblocks.
+  std::stack<Block*> blocks_to_visit;
+
+  blocks_to_visit.push(graph->block());
+  while (!blocks_to_visit.empty()) {
+    Block* b = blocks_to_visit.top();
+    blocks_to_visit.pop();
+    for (Node* n : b->nodes()) {
+      // Skip nodes that we don't need to observe, e.g. 'prim::Constant'.
+      if (!outputsNeedToBeObserved(n)) {
+        continue;
+      }
+
+      // Record all outputs in the values_to_observe - we'll later add observers
+      // for all values from it.
+      for (Value* v : n->outputs()) {
+        values_to_observe.push_back(v);
+      }
+
+      // Schedule subblocks (if any) for visiting.
+      for (Block* subblock : n->blocks()) {
+        blocks_to_visit.push(subblock);
+      }
+    }
+  }
+
+  // Actually add observer nodes.
+  for (Value* v : values_to_observe) {
+    addObserverFor(v, observer_node);
+  }
 }
 
 void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph) {
index 37f558d..8094cbf 100644 (file)
@@ -26,8 +26,13 @@ TORCH_API void PropagateQuantInfo(std::shared_ptr<Graph>& graph);
  * a tensor.
  *
  * The distribution can then be used for computing qparams for quantization.
+ * \param graph is the graph that would be instrumented.
+ * \param observer_node is a Node representing a call to observer function. It
+ * will be cloned into all the places where we need to add instrumentation.
  */
-TORCH_API void InsertObserverNodes(std::shared_ptr<Graph>& graph);
+TORCH_API void InsertObserverNodes(
+    std::shared_ptr<Graph>& graph,
+    Node* observer_node);
 
 /** \brief Inserts fake-quant nodes.
  *