From fca9d9a1003dbd96820be2cf06e338652bac8ddf Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Fri, 29 Mar 2019 15:01:36 -0700 Subject: [PATCH] Initial implementation of InsertObserverNodes pass. (#18152) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18152 ghimport-source-id: 1dd5e62c4d93394dcd8d8af2871554575c8d3d1a Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18152 Initial implementation of InsertObserverNodes pass.** * #18151 Add quant-passes stubs. gh-metadata: pytorch pytorch 18150 gh/zolotukhinm@gmail.com/2/head Differential Revision: D14584223 fbshipit-source-id: 30896acc1a8901d22c6a167eb87d2fbaafbbeb6f --- test/test_jit.py | 43 +++++++++++++++++++- torch/csrc/jit/init.cpp | 10 ++++- torch/csrc/jit/passes/quantization.cpp | 73 ++++++++++++++++++++++++++++++++-- torch/csrc/jit/passes/quantization.h | 7 +++- 4 files changed, 125 insertions(+), 8 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index e51ed7b..6229f0c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -996,8 +996,47 @@ class TestJit(JitTestCase): def test_expand_propagate_qinfo(self): pass - def test_expand_insert_observers(self): - pass + def test_insert_observers(self): + x1 = torch.tensor([0.4, 0.3]) + y1 = torch.tensor([0.7, 0.5]) + x2 = torch.tensor([0.1, 0.9]) + y2 = torch.tensor([1.1, 1.9]) + + # Function that we will use as a graph + def fn(x, y): + p = x + y + z = x - y + return p * z + + # Custom observer function + value_stats = {} + + def observe(x, name): + if name not in value_stats: + value_stats[name] = [] + value_stats[name].append(x) + return x + + m = torch.jit.script(fn) + # Insert observers + torch._C._jit_pass_insert_observers(m.graph, observe) + + # Collect statistics + m.forward(x1, y1) + + # Check what we collected + self.assertTrue('p' in value_stats and 'z' in value_stats) + self.assertEqual(len(value_stats['p']), 1) + self.assertEqual(len(value_stats['z']), 1) + self.assertEqual(value_stats['p'][0], x1 + y1) + self.assertEqual(value_stats['z'][0], x1 - y1) + + # Run one more time and check the updated statistics + m.forward(x2, y2) + self.assertEqual(len(value_stats['p']), 2) + self.assertEqual(len(value_stats['z']), 2) + self.assertEqual(value_stats['p'][1], x2 + y2) + self.assertEqual(value_stats['z'][1], x2 - y2) def test_expand_insert_fakequant(self): pass diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index bd66058..489c438 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -120,7 +120,15 @@ void initJITBindings(PyObject* module) { [](std::shared_ptr& g) { return PropagateQuantInfo(g); }) .def( "_jit_pass_insert_observers", - [](std::shared_ptr& g) { return InsertObserverNodes(g); }) + [](std::shared_ptr& g, py::function pyObserverFunction) { + // Create a new node that would be used in the insert observer pass: + // all observer nodes will be cloned from this one. + Node* new_node = g->createPythonOp( + THPObjectPtr(pyObserverFunction.release().ptr()), "dd", {}); + InsertObserverNodes(g, new_node); + // We don't need this node anymore, don't forget to remove it. + new_node->destroy(); + }) .def( "_jit_pass_insert_fakequant", [](std::shared_ptr& g) { return InsertFakeQuantNodes(g); }) diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp index 1d46d3e..c0006a3 100644 --- a/torch/csrc/jit/passes/quantization.cpp +++ b/torch/csrc/jit/passes/quantization.cpp @@ -4,11 +4,10 @@ #include #include -#include +#include namespace torch { namespace jit { -namespace {} // namespace void ExpandFakeQuantNodes(std::shared_ptr& graph) { throw std::runtime_error("Pass not implemented yet!"); @@ -18,8 +17,74 @@ void PropagateQuantInfo(std::shared_ptr& graph) { throw std::runtime_error("Pass not implemented yet!"); } -void InsertObserverNodes(std::shared_ptr& graph) { - throw std::runtime_error("Pass not implemented yet!"); +static void addObserverFor(Value* v, Node* original_observer_node) { + Node* def = v->node(); + WithInsertPoint ins(def); + + // We need to pass the value name to observer function - create a constant + // holding this name. + Value* vname = def->owningGraph()->insertConstant(v->uniqueName()); + + // Create a new observer node. We just need to clone the original one. + Node* observerNode = + def->owningGraph() + ->createClone( + &*original_observer_node, [&](Value* v) { return v; }, false) + ->insertAfter(def); + + // Set the type and the name of the output of the new observer node. It will + // be used instead of the original value v. + Value* observedValue = observerNode->addOutput(); + observedValue->setType(v->type()); + observedValue->setUniqueName(v->uniqueName() + ".observed"); + + // Replace the uses of v with observedValue. We need to do it *before* we add + // the inputs - otherwise we would replace the newly added inputs as well. + v->replaceAllUsesWith(observedValue); + + // Now we can add the inputs. + observerNode->addInput(v); + observerNode->addInput(vname); +} + +static bool outputsNeedToBeObserved(Node* n) { + return n->kind().toQualString() != std::string("prim::Constant"); +} + +void InsertObserverNodes(std::shared_ptr& graph, Node* observer_node) { + // For storing all values that need to be instrumented with an observer call. + std::vector values_to_observe; + + // For traversing all blocks in the graph including subblocks. + std::stack blocks_to_visit; + + blocks_to_visit.push(graph->block()); + while (!blocks_to_visit.empty()) { + Block* b = blocks_to_visit.top(); + blocks_to_visit.pop(); + for (Node* n : b->nodes()) { + // Skip nodes that we don't need to observe, e.g. 'prim::Constant'. + if (!outputsNeedToBeObserved(n)) { + continue; + } + + // Record all outputs in the values_to_observe - we'll later add observers + // for all values from it. + for (Value* v : n->outputs()) { + values_to_observe.push_back(v); + } + + // Schedule subblocks (if any) for visiting. + for (Block* subblock : n->blocks()) { + blocks_to_visit.push(subblock); + } + } + } + + // Actually add observer nodes. + for (Value* v : values_to_observe) { + addObserverFor(v, observer_node); + } } void InsertFakeQuantNodes(std::shared_ptr& graph) { diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h index 37f558d..8094cbf 100644 --- a/torch/csrc/jit/passes/quantization.h +++ b/torch/csrc/jit/passes/quantization.h @@ -26,8 +26,13 @@ TORCH_API void PropagateQuantInfo(std::shared_ptr& graph); * a tensor. * * The distribution can then be used for computing qparams for quantization. + * \param graph is the graph that would be instrumented. + * \param observer_node is a Node representing a call to observer function. It + * will be cloned into all the places where we need to add instrumentation. */ -TORCH_API void InsertObserverNodes(std::shared_ptr& graph); +TORCH_API void InsertObserverNodes( + std::shared_ptr& graph, + Node* observer_node); /** \brief Inserts fake-quant nodes. * -- 2.7.4