From: Nishant Pandit
Date: Sat, 6 Apr 2019 19:34:33 +0000 (-0700)
Subject: Quantizer pass to insert quant-dequant nodes into IR (#18446)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~359
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bcd527190a6aed12c6926624256a448d5461be98;p=platform%2Fupstream%2Fpytorch.git

Quantizer pass to insert quant-dequant nodes into IR (#18446)

Summary:
- Quantizer pass to mutate the IR by inserting quant-dequant nodes before and
  after nodes which support quantized ops. This information will be used by the
  jit compiler to substitute quantized ops
- This currently covers simple models; it will be expanded later with subgraph
  pattern matching to cover more complex patterns
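
For illustration, a minimal usage sketch (mirroring the new tests below; the
module, input shape, and the default qparams of scale=1.0, zero_point=0 all
come from this patch):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)

        @torch.jit.script_method
        def forward(self, x):
            return F.relu(self.conv1(x))

    m = M()
    # Constant propagation runs first so that type analysis can identify
    # quantizable tensors before the quant-dequant insertion pass.
    torch._C._jit_pass_constant_propagation(m.graph)
    torch._C._jit_pass_insert_quantdequant(m.graph)
    # m.graph now contains aten::quantize_linear / aten::dequantize pairs
    # around the conv2d and relu outputs.
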
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18446

Differential Revision: D14592265

Pulled By: nishantpdce

fbshipit-source-id: c9ba6c12aa96cb9c117826e386721eec83a55ea6
---

diff --git a/test/test_jit.py b/test/test_jit.py
index d3caa8c..98acb9b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1099,9 +1099,6 @@ class TestJit(JitTestCase):
         self.run_pass('cse', graph)
         FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
 
-    def test_expand_fakequant(self):
-        pass
-
     def test_expand_propagate_qinfo(self):
         pass
 
@@ -1147,8 +1144,124 @@ class TestJit(JitTestCase):
         self.assertEqual(value_stats['p'][1], x2 + y2)
         self.assertEqual(value_stats['z'][1], x2 - y2)
 
-    def test_expand_insert_fakequant(self):
-        pass
+    def test_insert_quantdequant_consecutive_qnodes_script(self):
+        class testModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                return x
+
+        trace = testModule()
+
+        # Constant propagation is performed first because this pass inserts
+        # quant-dequant nodes only for quantizable tensors; the type analysis
+        # happens as part of this jit pass
+        torch._C._jit_pass_constant_propagation(trace.graph)
+        self.run_pass('insert_quantdequant', trace.graph)
+
+        # We expect to see quant-dequant node pairs before and after both the
+        # conv and relu nodes, and at the external output since relu is the
+        # last node. Constant nodes correspond to params for the quantization
+        # nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+            .check("conv2d").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check_next("dequantize").run(str(trace.graph))
+        FileCheck().check("relu").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check_next("dequantize").check_next("return") \
+            .run(str(trace.graph))
+
+    def test_insert_quantdequant_consecutive_qnodes_trace(self):
+        class testModule(torch.nn.Module):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                return x
+
+        trace = torch.jit.trace(testModule(), (torch.rand(1, 1, 28, 28)))
+
+        self.run_pass('insert_quantdequant', trace.graph)
+        # We expect to see quant-dequant node pairs before and after both the
+        # conv and relu nodes, and at the external output since relu is the
+        # last node. Constant nodes correspond to params for the quantization
+        # nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+            .check("_convolution").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check_next("dequantize").run(str(trace.graph))
+        FileCheck().check("relu").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check_next("dequantize").check_next("return") \
+            .run(str(trace.graph))
+
+    def test_insert_quantdequant_single_qnode(self):
+        class testModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv1(x)
+                x1 = torch.add(x, 1)
+                return x1
+
+        trace = testModule()
+
+        # Constant propagation is performed first because this pass inserts
+        # quant-dequant nodes only for quantizable tensors; the type analysis
+        # happens as part of this jit pass
+        torch._C._jit_pass_constant_propagation(trace.graph)
+        self.run_pass('insert_quantdequant', trace.graph)
+
+        # We expect to see quant-dequant node pairs before and after conv,
+        # and no quant-dequant after add. Constant nodes correspond
+        # to params for the quantization nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+            .check("conv2d").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check_next("dequantize").check_next("add") \
+            .check_next("return").run(str(trace.graph))
+
+    def test_insert_quantdequant_alternate_qnode(self):
+        class testModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(testModule, self).__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5, 1)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv1(x)
+                x1 = torch.add(x, 1)
+                x2 = F.relu(x1)
+                return x2
+
+        trace = testModule()
+
+        # Constant propagation is performed first because this pass inserts
+        # quant-dequant nodes only for quantizable tensors; the type analysis
+        # happens as part of this jit pass
+        torch._C._jit_pass_constant_propagation(trace.graph)
+        self.run_pass('insert_quantdequant', trace.graph)
+
+        # We expect to see quant-dequant node pairs before and after conv,
+        # relu and add. Constant nodes correspond to params for the
+        # quantization nodes
+        FileCheck().check("quantize_linear").check_next("dequantize") \
+            .check("conv2d").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check_next("dequantize").run(str(trace.graph))
+        FileCheck().check("add").check_next("Constant") \
+            .check_next("Constant").check_next("quantize_linear") \
+            .check("dequantize").run(str(trace.graph))
 
     def test_expand_quantlint(self):
         pass
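
For illustration (not part of the patch): after the pass runs on the
single-qnode example above, dumping the graph shows roughly the shape below.
The exact dump format and value names are illustrative, but the
".quant"/".dequant" suffixes and the scale/zero-point constants come from the
pass itself:

    print(trace.graph)
    # Illustrative excerpt (not verbatim output):
    #   %3 : float = prim::Constant[value=1]()       # default scale (1.0)
    #   %4 : int = prim::Constant[value=0]()         # default zero-point
    #   %x.quant : Tensor = aten::quantize_linear(%x, %3, %4)
    #   %x.dequant : Tensor = aten::dequantize(%x.quant)
    #   %y : Tensor = aten::conv2d(%x.dequant, %weight, %bias, ...)
    #   %y.quant : Tensor = aten::quantize_linear(%y, %5, %6)
    #   %y.dequant : Tensor = aten::dequantize(%y.quant)
    #   %x1 : Tensor = aten::add(%y.dequant, ...)    # add stays unquantized
    #   return (%x1)
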
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 731a5ad..9bdb9e9 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -113,9 +113,6 @@ void initJITBindings(PyObject* module) {
             return EliminateCommonSubexpression(g); // overload resolution
           })
       .def(
-          "_jit_pass_expand_fakequant",
-          [](std::shared_ptr<Graph>& g) { return ExpandFakeQuantNodes(g); })
-      .def(
           "_jit_pass_propagate_qinfo",
          [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
       .def(
@@ -130,8 +127,8 @@ void initJITBindings(PyObject* module) {
             new_node->destroy();
           })
       .def(
-          "_jit_pass_insert_fakequant",
-          [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
+          "_jit_pass_insert_quantdequant",
+          [](std::shared_ptr<Graph>& g) { return InsertQuantDequantNodes(g); })
       .def(
           "_jit_pass_quantlint",
          [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index c0006a3..6b929f2 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -2,17 +2,104 @@
 #include 
 #include 
+#include <stack>
 #include 
 #include 
 
 namespace torch {
 namespace jit {
+namespace {
+// QuantizerUtils
 
-void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
+bool checkIfNodeQuantizable(Node* n) {
+  AT_ASSERT(n != nullptr);
+  // This is the map of quantizable nodes. It will be expanded in the
+  // future to support more ops and patterns.
+  static const OperatorSet quantnodeLookup =
+      {
+          "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] \
+stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+          "aten::relu(Tensor self) -> Tensor",
+          "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] \
+stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, \
+int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor"
+      };
+  return quantnodeLookup.find(n);
+}
+
+void insertQuantNodeParams(Node* quant, std::tuple<float, int> qparam) {
+  WithInsertPoint ins(quant);
+  Value* scale = quant->owningGraph()->insertConstant(std::get<0>(qparam));
+  Value* zeropoint = quant->owningGraph()->insertConstant(std::get<1>(qparam));
+  quant->addInput(scale);
+  quant->addInput(zeropoint);
+}
+
+// Create a quant-dequant node pair for a quantizable Value
+std::pair<Node*, Node*> createQuantDeQuantNodes(Value* v, Node* n) {
+  Node* quant =
+      n->owningGraph()->create(at::Symbol::fromQualString(
+          "aten::quantize_linear"));
+  AT_ASSERTM(quant != nullptr, "Failed to create quant node");
+  quant->output()->setUniqueName(v->uniqueName() + ".quant");
+
+  Node* dequant =
+      n->owningGraph()->create(at::Symbol::fromQualString("aten::dequantize"));
+  AT_ASSERTM(dequant != nullptr, "Failed to create dequant node");
+  dequant->output()->setUniqueName(v->uniqueName() + ".dequant");
+
+  quant->setScope(n->scope());
+  dequant->setScope(n->scope());
+
+  return std::make_pair(quant, dequant);
 }
 
+// Insert a quant-dequant node pair for a quantizable node's output
+void addQuantDeQuantNodes(Value* v) {
+  AT_ASSERT(v != nullptr);
+  Node* n = v->node();
+  auto qdq = createQuantDeQuantNodes(v, n);
+  Node* quant = qdq.first;
+  Node* dequant = qdq.second;
+
+  // Add the quant-dequant nodes and replace all uses of the Value
+  quant->insertAfter(n);
+  dequant->insertAfter(quant);
+  v->replaceAllUsesWith(dequant->output());
+
+  // Attach inputs to the quant and dequant nodes
+  quant->addInput(v);
+  // Default quant params
+  insertQuantNodeParams(quant, std::make_tuple(1.0, 0));
+  dequant->addInput(quant->output());
+}
+
+// Insert a quant-dequant node pair for a specific input of node n
+void addQuantDeQuantNodesForInput(Value* v, Node* n) {
+  AT_ASSERT(v != nullptr);
+  AT_ASSERT(n != nullptr);
+  auto qdq = createQuantDeQuantNodes(v, n);
+  Node* quant = qdq.first;
+  Node* dequant = qdq.second;
+
+  // Insert the quant-dequant nodes for the V->N
+  // pair which was identified as quantizable during
+  // graph iteration
+  dequant->insertBefore(n);
+  quant->insertBefore(dequant);
+  n->replaceInputWith(v, dequant->output());
+
+  // Attach inputs to the quant and dequant nodes
+  quant->addInput(v);
+  // Default quant params
+  insertQuantNodeParams(quant, std::make_tuple(1.0, 0));
+  dequant->addInput(quant->output());
+}
+
+} // namespace
+
+// PyBind APIs
 void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
   throw std::runtime_error("Pass not implemented yet!");
 }
@@ -48,7 +135,7 @@ static void addObserverFor(Value* v, Node* original_observer_node) {
 }
 
 static bool outputsNeedToBeObserved(Node* n) {
-  return n->kind().toQualString() != std::string("prim::Constant");
+  return n->kind() != prim::Constant;
 }
 
 void InsertObserverNodes(std::shared_ptr<Graph>& graph, Node* observer_node) {
@@ -87,8 +174,84 @@ void InsertObserverNodes(std::shared_ptr<Graph>& graph, Node* observer_node) {
 }
 
-void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
+void InsertQuantDequantNodes(std::shared_ptr<Graph>& graph) {
+  std::stack<Block*> blocks_to_visit;
+  blocks_to_visit.push(graph->block());
+  // For storing quantizable value-node pairs that are external
+  // or intermediate inputs to quantizable nodes
+  std::vector<std::pair<Value*, Node*>> quantInputs;
+  // For storing quantizable values that are outputs of quantizable nodes.
+  // Since the same value can feed multiple nodes, we use a set so that
+  // we insert quant-dequant node pairs for a value only once
+  std::vector<Value*> quantOutputs;
+  std::unordered_set<Value*> valueLookup;
+
+  while (!blocks_to_visit.empty()) {
+    Block* b = blocks_to_visit.top();
+    blocks_to_visit.pop();
+
+    for (Node* n : b->nodes()) {
+      // Schedule the sub-blocks
+      for (Block* subblock : n->blocks()) {
+        blocks_to_visit.push(subblock);
+      }
+
+      // We iterate over node inputs to identify which Values
+      // need to be quantized, depending on the node type
+      for (auto& v : n->inputs()) {
+        if (!v->type()->isSubtypeOf(TensorType::get())) {
+          // Skip quantization for non-tensors
+          continue;
+        }
+
+        if (checkIfNodeQuantizable(v->node())) {
+          // The goal of this branch is to identify the parent node of V
+          // that is quantizable and replace all uses of the Value with the
+          // quant-dequant output. Using a set helps us add a single
+          // q-dq pair for all of V's users.
+          // Example: N1 -> (V1 -> (N2), V2 -> (N3))
+          // N1 is a quantizable node, so we insert quant-dequant
+          // nodes for all outputs of N1 (V1, V2) once
+          if (!valueLookup.count(v)) {
+            valueLookup.emplace(v);
+            quantOutputs.emplace_back(v);
+          }
+        } else if (checkIfNodeQuantizable(n)) {
+          // The goal of this branch is to identify nodes that are
+          // quantizable but whose input values originate from a non
+          // quantizable node. This requires selectively inserting q-dq
+          // nodes for inputs into node N (the V, N pair) because the
+          // parent node might also feed other non quantizable nodes.
+          // Example: N1(prim::Param) -> (V1 -> (N4, N5), V2 -> (N6, N7), V3)
+          // N1 is not a quantizable node but N4 and N7 are
+          // quantizable nodes, so we add (V1, N4) and
+          // (V2, N7) as insertion points for quant-dequant nodes
+          quantInputs.emplace_back(v, n);
+        }
+      }
+    } // End loop over nodes within the block
+
+    // Since we iterate only over node inputs above, we also need to iterate
+    // over the block's output values; if they originate from a quantizable
+    // node, we push them to the quantOutputs set
+    auto outputVals = b->outputs();
+    for (auto& v : outputVals) {
+      if (checkIfNodeQuantizable(v->node()) &&
+          v->type()->isSubtypeOf(TensorType::get())) {
+        quantOutputs.emplace_back(v);
+      }
+    } // end for
+  } // end block traversal
+
+  // Insert the quant-dequant pairs for values output from quantizable nodes
+  for (auto& ele : quantOutputs) {
+    addQuantDeQuantNodes(ele);
+  }
+
+  // Insert the quant-dequant pairs for values input to quantizable nodes
+  for (auto& ele : quantInputs) {
+    addQuantDeQuantNodesForInput(ele.first, ele.second);
+  }
 }
 
 void QuantLinting(std::shared_ptr<Graph>& graph) {
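
For illustration (not part of the patch), the traversal above can be summarized
in Python. This is a rough sketch against the torch._C graph surface
(graph.nodes(), node.inputs(), value.node(), node.kind()); the tensor-type test
and the helper names are approximations of the C++ (checkIfNodeQuantizable,
TensorType::get()), not an exact port:

    QUANTIZABLE = {'aten::conv2d', 'aten::_convolution', 'aten::relu'}

    def is_quantizable(node):
        return node.kind() in QUANTIZABLE

    def classify(graph):
        """Collect insertion points the way InsertQuantDequantNodes does."""
        quant_outputs, quant_inputs, seen = [], [], set()
        for node in graph.nodes():  # the C++ also walks sub-blocks via a stack
            for v in node.inputs():
                if 'Tensor' not in str(v.type()):
                    continue  # skip quantization for non-tensors
                if is_quantizable(v.node()):
                    # Producer is quantizable: one q-dq pair per value,
                    # shared by all of the value's users.
                    if id(v) not in seen:
                        seen.add(id(v))
                        quant_outputs.append(v)
                elif is_quantizable(node):
                    # Consumer is quantizable but producer is not: insert a
                    # q-dq pair on this specific (value, node) edge only.
                    quant_inputs.append((v, node))
        # Block outputs produced by quantizable nodes also get a q-dq pair.
        for v in graph.outputs():
            if is_quantizable(v.node()) and 'Tensor' in str(v.type()):
                quant_outputs.append(v)
        return quant_outputs, quant_inputs
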
diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h
index 8094cbf..2812d61 100644
--- a/torch/csrc/jit/passes/quantization.h
+++ b/torch/csrc/jit/passes/quantization.h
@@ -10,10 +10,6 @@
 namespace torch {
 namespace jit {
 
-/** \brief Replace all FakeQuant nodes with corresponding Quant-Dequant nodes
- * pair. */
-TORCH_API void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph);
-
 /** \brief Propagates QParams through nodes that are not supposed to change it.
  *
  * An example of such node is `Split`: even though the observed distribution
@@ -34,18 +30,19 @@ TORCH_API void InsertObserverNodes(
     std::shared_ptr<Graph>& graph,
     Node* observer_node);
 
-/** \brief Inserts fake-quant nodes.
+/** \brief Inserts quant-dequant nodes.
  *
  * This actually changes the numerical semantics of the original model and thus
  * we only run it when user explicitly wants that. This pass essentially
- * performs quantization of the model - later passes only cleanup the IR and
+ * performs quantization of the model by inserting quant-dequant node pairs for
+ * quantizable tensors - later passes only clean up the IR and
  * make sure the model runs faster/consumes less memory.
  *
  * TODO: This should also take a qparam-map as an input.
 */
-TORCH_API void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph);
+TORCH_API void InsertQuantDequantNodes(std::shared_ptr<Graph>& graph);
 
-/** \brief Check that all expected optimizations after fake-quant nodes
+/** \brief Check that all expected optimizations after quant-dequant nodes
  * insertion actually happened.
  *
  * Even though semantically it would be correct to just execute the initial