Add quant-passes stubs. (#18151)
author: Mikhail Zolotukhin <mvz@fb.com>
Tue, 26 Mar 2019 00:39:01 +0000 (17:39 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 00:48:54 +0000 (17:48 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18151
ghimport-source-id: 7d12462971bdf3e5e26a3f150f1fcad05bba1a15

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18152 Initial implementation of InsertObserverNodes pass.
* **#18151 Add quant-passes stubs.**

gh-metadata: pytorch pytorch 18149 gh/zolotukhinm@gmail.com/1/head

Differential Revision: D14584224

fbshipit-source-id: b3d0b5ff797160d5ad23f91f732e627b0129086c

test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/quantization.cpp [new file with mode: 0644]
torch/csrc/jit/passes/quantization.h [new file with mode: 0644]

index 36d51b9..1c59929 100644 (file)
@@ -991,6 +991,24 @@ class TestJit(JitTestCase):
         self.run_pass('cse', graph)
         FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
 
+    def test_expand_fakequant(self):
+        pass
+
+    def test_expand_propagate_qinfo(self):
+        pass
+
+    def test_expand_insert_observers(self):
+        pass
+
+    def test_expand_insert_fakequant(self):
+        pass
+
+    def test_expand_quantlint(self):
+        pass
+
+    def test_expand_fold_quant_inputs(self):
+        pass
+
     def test_shape_analysis_broadcast(self):
         def broadcast(a, b):
             return a + b
index 675ae7d..503eb6a 100644 (file)
@@ -81,6 +81,7 @@ libtorch_sources = [
     "torch/csrc/jit/passes/lower_tuples.cpp",
     "torch/csrc/jit/passes/peephole.cpp",
     "torch/csrc/jit/passes/python_print.cpp",
+    "torch/csrc/jit/passes/quantization.cpp",
     "torch/csrc/jit/passes/remove_expands.cpp",
     "torch/csrc/jit/passes/requires_grad_analysis.cpp",
     "torch/csrc/jit/passes/shape_analysis.cpp",
index 71f9908..c2d7783 100644 (file)
@@ -166,6 +166,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
index b0888cd..f46dd9f 100644 (file)
@@ -26,6 +26,7 @@
 #include <torch/csrc/jit/passes/onnx/peephole.h>
 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
 #include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/quantization.h>
 #include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
@@ -114,6 +115,26 @@ void initJITBindings(PyObject* module) {
             return EliminateCommonSubexpression(g); // overload resolution
           })
       .def(
+          "_jit_pass_expand_fakequant",
+          [](std::shared_ptr<Graph>& g) { return ExpandFakeQuantNodes(g); })
+      .def(
+          "_jit_pass_propagate_qinfo",
+          [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
+      .def(
+          "_jit_pass_insert_observers",
+          [](std::shared_ptr<Graph>& g) { return InsertObserverNodes(g); })
+      .def(
+          "_jit_pass_insert_fakequant",
+          [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
+      .def(
+          "_jit_pass_quantlint",
+          [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
+      .def(
+          "_jit_pass_fold_quant_inputs",
+          [](std::shared_ptr<Graph>& g) {
+            return FoldQuantNodesIntoInputsOutputs(g);
+          })
+      .def(
           "_jit_pass_remove_inplace_ops",
           [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
       .def("_jit_pass_constant_pooling", ConstantPooling)
@@ -352,7 +373,8 @@ void initJITBindings(PyObject* module) {
       .def_property_readonly(
           "name", [](FunctionSchema& self) { return self.name(); })
       .def_property_readonly(
-          "overload_name", [](FunctionSchema& self) { return self.overload_name(); })
+          "overload_name",
+          [](FunctionSchema& self) { return self.overload_name(); })
       .def_property_readonly(
           "arguments", [](FunctionSchema& self) { return self.arguments(); })
       .def_property_readonly(
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
new file mode 100644 (file)
index 0000000..1d46d3e
--- /dev/null
@@ -0,0 +1,38 @@
+#include <torch/csrc/jit/passes/quantization.h>
+
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/passes/alias_analysis.h>
+
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+namespace {} // namespace
+
+void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void InsertObserverNodes(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void QuantLinting(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h
new file mode 100644 (file)
index 0000000..37f558d
--- /dev/null
@@ -0,0 +1,63 @@
+/** \brief This file defines passes used for quantization.
+ *
+ * The passes have python-bindings and can be invoked directly or as a part of
+ * general optimization pipeline (details TBD).
+ */
+#pragma once
+
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+
+/** \brief Replace each FakeQuant node with a corresponding pair of
+ * Quant-Dequant nodes. */
+TORCH_API void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Propagates QParams through nodes that are not supposed to change it.
+ *
+ * An example of such node is `Split`: even though the observed distribution
+ * might be different for input and output tensors, we would like to use input's
+ * qparams for output as well.
+ */
+TORCH_API void PropagateQuantInfo(std::shared_ptr<Graph>& graph);
+
+/** \brief Inserts observer nodes for collecting distribution of values taken by
+ * a tensor.
+ *
+ * The distribution can then be used for computing qparams for quantization.
+ */
+TORCH_API void InsertObserverNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Inserts fake-quant nodes.
+ *
+ * This actually changes the numerical semantics of the original model and thus
+ * we only run it when user explicitly wants that. This pass essentially
+ * performs quantization of the model - later passes only cleanup the IR and
+ * make sure the model runs faster/consumes less memory.
+ *
+ * TODO: This should also take a qparam-map as an input.
+ */
+TORCH_API void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Check that all expected optimizations after fake-quant nodes
+ * insertion actually happened.
+ *
+ * Even though semantically it would be correct to just execute the initial
+ * quant-dequant nodes as is, what we really wanted when we inserted them is to
+ * fuse them into adjacent non-quantized ops resulting in quantized ops. Thus,
+ * if after all the cleanups, optimizations (particularly, fusion) we find
+ * quant-dequant pair in the graph, it indicates that quantization didn't go as
+ * planned.
+ */
+TORCH_API void QuantLinting(std::shared_ptr<Graph>& graph);
+
+/** \brief Quantize model's inputs and outputs.
+ *
+ * This pass folds quant/dequant ops into the input/output tensors, essentially
+ * quantizing these tensors. It's done to reduce model's memory footprint.
+ */
+TORCH_API void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch