Quantizer pass to insert quant-dequant nodes into IR (#18446)
authorNishant Pandit <npandit@fb.com>
Sat, 6 Apr 2019 19:34:33 +0000 (12:34 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 6 Apr 2019 19:39:26 +0000 (12:39 -0700)
Summary:
- Quantizer pass to mutate IR by inserting quant-dequant nodes
before and after nodes which support quantized ops. This information
will be used by jit compiler to substitute with quantized ops

- This currently covers simple model. It will be expanded later
for subgraph pattern matching to cover more complex patterns
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18446

Differential Revision: D14592265

Pulled By: nishantpdce

fbshipit-source-id: c9ba6c12aa96cb9c117826e386721eec83a55ea6

test/test_jit.py
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/quantization.cpp
torch/csrc/jit/passes/quantization.h

index d3caa8c..98acb9b 100644 (file)
@@ -1099,9 +1099,6 @@ class TestJit(JitTestCase):
         self.run_pass('cse', graph)
         FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
 
-    def test_expand_fakequant(self):
-        pass
-
     def test_expand_propagate_qinfo(self):
         pass
 
@@ -1147,8 +1144,124 @@ class TestJit(JitTestCase):
         self.assertEqual(value_stats['p'][1], x2 + y2)
         self.assertEqual(value_stats['z'][1], x2 - y2)
 
-    def test_expand_insert_fakequant(self):
-        pass
+    def test_insert_quantdequant_consecutive_qnodes_script(self):
+        class testModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                return x
+
+        trace = testModule()
+
+        # Constant Propagation step is performed because this pass is intended
+        # to insert quant-dequant nodes for quantizable tensors. The type analysis
+        # happens as part of this jit pass
+        torch._C._jit_pass_constant_propagation(trace.graph)
+        self.run_pass('insert_quantdequant', trace.graph)
+
+        # We expect to see quant-dequant node before and after
+        # both conv and relu nodes and at external output since relu
+        # is last node. Constant nodes correspond to params for the
+        # quantization nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+                   .check("conv2d").check_next("Constant") \
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check_next("dequantize").run(str(trace.graph))
+        FileCheck().check("relu").check_next("Constant") \
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check_next("dequantize").check_next("return") \
+                   .run(str(trace.graph))
+
+    def test_insert_quantdequant_consecutive_qnodes_trace(self):
+        class testModule(torch.nn.Module):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                return x
+
+        trace = torch.jit.trace(testModule(), (torch.rand(1, 1, 28, 28)))
+
+        self.run_pass('insert_quantdequant', trace.graph)
+        # We expect to see quant-dequant node before and after
+        # both conv and relu nodes and at external output since relu
+        # is last node. Constant nodes correspond to params for the
+        # quantization nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+                   .check("_convolution").check_next("Constant") \
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check_next("dequantize").run(str(trace.graph))
+        FileCheck().check("relu").check_next("Constant") \
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check_next("dequantize").check_next("return") \
+                   .run(str(trace.graph))
+
+    def test_insert_quantdequant_single_qnode(self):
+        class testModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv1(x)
+                x1 = torch.add(x, 1)
+                return x1
+
+        trace = testModule()
+
+        # Constant Propagation step is performed because this pass is intended
+        # to insert quant-dequant nodes for quantizable tensors. The type analysis
+        # happens as part of this jit pass
+        torch._C._jit_pass_constant_propagation(trace.graph)
+        self.run_pass('insert_quantdequant', trace.graph)
+
+        # We expect to see quant-dequant node before and after
+        # both conv and no quant-dequant after add. Constant nodes correspond
+        # to params for the quantization nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+                   .check("conv2d").check_next("Constant") \
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check_next("dequantize").check_next("add") \
+                   .check_next("return").run(str(trace.graph))
+
+    def test_insert_quantdequant_alternate_qnode(self):
+        class testModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv1(x)
+                x1 = torch.add(x, 1)
+                x2 = F.relu(x1)
+                return x2
+
+        trace = testModule()
+
+        # Constant Propagation step is performed because this pass is intended
+        # to insert quant-dequant nodes for quantizable tensors. The type analysis
+        # happens as part of this jit pass
+        torch._C._jit_pass_constant_propagation(trace.graph)
+        self.run_pass('insert_quantdequant', trace.graph)
+
+        # We expect to see quant-dequant node before and after
+        # conv, relu and add. Constant nodes correspond to params for the
+        # quantization nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+                   .check("conv2d").check_next("Constant") \
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check_next("dequantize").run(str(trace.graph))
+        FileCheck().check("add").check_next("Constant")\
+                   .check_next("Constant").check_next("quantize_linear") \
+                   .check("dequantize").run(str(trace.graph))
 
     def test_expand_quantlint(self):
         pass
index 731a5ad..9bdb9e9 100644 (file)
@@ -113,9 +113,6 @@ void initJITBindings(PyObject* module) {
             return EliminateCommonSubexpression(g); // overload resolution
           })
       .def(
-          "_jit_pass_expand_fakequant",
-          [](std::shared_ptr<Graph>& g) { return ExpandFakeQuantNodes(g); })
-      .def(
           "_jit_pass_propagate_qinfo",
           [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
       .def(
@@ -130,8 +127,8 @@ void initJITBindings(PyObject* module) {
             new_node->destroy();
           })
       .def(
-          "_jit_pass_insert_fakequant",
-          [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
+          "_jit_pass_insert_quantdequant",
+          [](std::shared_ptr<Graph>& g) { return InsertQuantDequantNodes(g); })
       .def(
           "_jit_pass_quantlint",
           [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
index c0006a3..6b929f2 100644 (file)
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
 #include <stack>
 
 namespace torch {
 namespace jit {
+namespace {
+// QuantizerUtils
 
-void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
+bool checkIfNodeQuantizable(Node* n) {
+  AT_ASSERT(n != nullptr);
+  // This is map for quantizable nodes. It will be expanded in future to
+  // support more ops and patterns.
+  static const OperatorSet quantnodeLookup =
+   {
+     "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] \
+stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+     "aten::relu(Tensor self) -> Tensor",
+     "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] \
+stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, \
+int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor"
+   };
+  return quantnodeLookup.find(n);
+}
+
+void insertQuantNodeParams(Node* quant, std::tuple<float, int> qparam) {
+  WithInsertPoint ins(quant);
+  Value* scale = quant->owningGraph()->insertConstant(std::get<0>(qparam));
+  Value* zeropoint = quant->owningGraph()->insertConstant(std::get<1>(qparam));
+  quant->addInput(scale);
+  quant->addInput(zeropoint);
+}
+
+// Create Quant-Dequant node pair for quantizable Value
+std::pair<Node*, Node*> createQuantDeQuantNodes(Value* v, Node* n) {
+  Node* quant =
+      n->owningGraph()->create(at::Symbol::fromQualString(
+        "aten::quantize_linear"));
+  AT_ASSERTM(quant != nullptr, "Failed to create quant node");
+  quant->output()->setUniqueName(v->uniqueName() + ".quant");
+
+  Node* dequant =
+      n->owningGraph()->create(at::Symbol::fromQualString("aten::dequantize"));
+  AT_ASSERTM(dequant != nullptr, "Failed to create dequant node");
+  dequant->output()->setUniqueName(v->uniqueName() + ".dequant");
+
+  quant->setScope(n->scope());
+  dequant->setScope(n->scope());
+
+  return std::make_pair(quant, dequant);
 }
 
+// Insert Quant-Dequant node pair for quantizable node outputs
+void addQuantDeQuantNodes(Value* v) {
+  AT_ASSERT(v != nullptr);
+  Node* n = v->node();
+  auto qdq = createQuantDeQuantNodes(v, n);
+  Node* quant = qdq.first;
+  Node* dequant = qdq.second;
+
+  // Add quant-dequant nodes and replace for all uses of Value
+  quant->insertAfter(n);
+  dequant->insertAfter(quant);
+  v->replaceAllUsesWith(dequant->output());
+
+  // Attach inputs to quant and dequant nodes
+  quant->addInput(v);
+  // Default Quant Params <Scale:1.0, ZeroPoint:0>
+  insertQuantNodeParams(quant, std::make_tuple(1.0, 0));
+  dequant->addInput(quant->output());
+}
+
+// Insert Quant-Dequant node pair for specific input to node n
+void addQuantDeQuantNodesForInput(Value* v, Node* n) {
+  AT_ASSERT(v != nullptr);
+  AT_ASSERT(n != nullptr);
+  auto qdq = createQuantDeQuantNodes(v, n);
+  Node* quant = qdq.first;
+  Node* dequant = qdq.second;
+
+  // Insert the quant-dequant node for the V->N
+  // pair which is identified as quantizable during
+  // graph iteration
+  dequant->insertBefore(n);
+  quant->insertBefore(dequant);
+  n->replaceInputWith(v, dequant->output());
+
+  // Attach inputs to quant and dequant nodes
+  quant->addInput(v);
+  // Default Quant Params <Scale:1.0, ZeroPoint:0>
+  insertQuantNodeParams(quant, std::make_tuple(1.0, 0));
+  dequant->addInput(quant->output());
+}
+
+} // namespace
+
+// PyBind APIs
 void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
   throw std::runtime_error("Pass not implemented yet!");
 }
@@ -48,7 +135,7 @@ static void addObserverFor(Value* v, Node* original_observer_node) {
 }
 
 static bool outputsNeedToBeObserved(Node* n) {
-  return n->kind().toQualString() != std::string("prim::Constant");
+  return n->kind() != prim::Constant;
 }
 
 void InsertObserverNodes(std::shared_ptr<Graph>& graph, Node* observer_node) {
@@ -87,8 +174,84 @@ void InsertObserverNodes(std::shared_ptr<Graph>& graph, Node* observer_node) {
   }
 }
 
-void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
+void InsertQuantDequantNodes(std::shared_ptr<Graph>& graph) {
+  std::stack<Block*> blocks_to_visit;
+  blocks_to_visit.push(graph->block());
+  // For storing quantizable values - node pairs that are external
+  // or intermediate inputs to quantizable nodes
+  std::vector<std::pair<Value*, Node*>> quantInputs;
+  // For storing quantizable values that are output of quantizable nodes
+  // Since same value can go to multiple nodes, we use set so that
+  // we insert quant-dequant node pairs for value only once
+  std::vector<Value*> quantOutputs;
+  std::unordered_set<Value*> valueLookup;
+
+  while (!blocks_to_visit.empty()) {
+    Block* b = blocks_to_visit.top();
+    blocks_to_visit.pop();
+
+    for (Node* n : b->nodes()) {
+      // Schedule the sub blocks
+      for (Block* subblock : n->blocks()) {
+        blocks_to_visit.push(subblock);
+      }
+
+      // We iterate over node inputs to identify which Values
+      // need to be quantized depending on node type
+      for (auto &v : n->inputs()) {
+        if (!v->type()->isSubtypeOf(TensorType::get())) {
+          // Skip quantization for non tensors
+          continue;
+        }
+
+        if (checkIfNodeQuantizable(v->node())) {
+          // Goal of this iteration is to identify the parent node for V
+          // that is quantizable and replace all uses of Value with
+          // quant-dequant output. Usage of set helps adding single
+          // q-dq nodes for all V->users
+          // Example N1 -> (V1 -> (N2), V2 -> (N3))
+          //         N1 is quantizable node. So we insert quant-dequant
+          //         nodes for all outputs of N1 (V1, V2) once
+          if (!valueLookup.count(v)) {
+            valueLookup.emplace(v);
+            quantOutputs.emplace_back(v);
+          }
+        } else if (checkIfNodeQuantizable(n)) {
+          // Goal of this iteration is to identify nodes that are
+          // quantizable but input value originate from non quantizable
+          // node. This requires selectively inserting q-dq nodes for
+          // inputs into node N(V, N pair) because parent node might
+          // also have input into other non quantizable nodes
+          // Example : N1(prim::Param) -> (V1 -> (N4, N5), V2 -> (N6, N7), V3)
+          //           N1 is not quantizable node but N4 and N7 are
+          //           quantizable nodes. So we add the (V1, N4) and
+          //           (V2, N7) as insertion points for quant-dequant nodes
+          quantInputs.emplace_back(v, n);
+        }
+      }
+    } // End Loop for nodes within block
+
+    // Since we are iterating node inputs only above, we need to iterate
+    // over block outputs values and if they originate from quantizable
+    // node, we push to quantOutputs set
+    auto outputVals = b->outputs();
+    for (auto& v : outputVals) {
+      if (checkIfNodeQuantizable(v->node()) &&
+        v->type()->isSubtypeOf(TensorType::get())) {
+        quantOutputs.emplace_back(v);
+      }
+    } //end for
+  } // end Block traversal
+
+  // Insert the quant-dequant pair for values output from quantizable nodes
+  for (auto& ele : quantOutputs) {
+    addQuantDeQuantNodes(ele);
+  }
+
+  // Insert the quant-dequant pair for values inputs to quantizable nodes
+  for (auto& ele : quantInputs) {
+    addQuantDeQuantNodesForInput(ele.first, ele.second);
+  }
 }
 
 void QuantLinting(std::shared_ptr<Graph>& graph) {
index 8094cbf..2812d61 100644 (file)
 namespace torch {
 namespace jit {
 
-/** \brief Replace all FakeQuant nodes with corresponding Quant-Dequant nodes
- * pair. */
-TORCH_API void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph);
-
 /** \brief Propagates QParams through nodes that are not supposed to change it.
  *
  * An example of such node is `Split`: even though the observed distribution
@@ -34,18 +30,19 @@ TORCH_API void InsertObserverNodes(
     std::shared_ptr<Graph>& graph,
     Node* observer_node);
 
-/** \brief Inserts fake-quant nodes.
+/** \brief Inserts quant-dequant nodes.
  *
  * This actually changes the numerical semantics of the original model and thus
  * we only run it when user explicitly wants that. This pass essentially
- * performs quantization of the model - later passes only cleanup the IR and
+ * performs quantization of the model by inserting quant-dequant node pairs for
+ * quantizatable tensors - later passes only cleanup the IR and
  * make sure the model runs faster/consumes less memory.
  *
  * TODO: This should also take a qparam-map as an input.
  */
-TORCH_API void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph);
+TORCH_API void InsertQuantDequantNodes(std::shared_ptr<Graph>& graph);
 
-/** \brief Check that all expected optimizations after fake-quant nodes
+/** \brief Check that all expected optimizations after quant-dequant nodes
  * insertion actually happened.
  *
  * Even though semantically it would be correct to just execute the initial