self.run_pass('cse', graph)
FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
+    def test_expand_fakequant(self):
+        # Placeholder: will exercise _jit_pass_expand_fakequant; the C++
+        # pass currently throws "Pass not implemented yet!".
+        pass
+
+    def test_expand_propagate_qinfo(self):
+        # Placeholder: will exercise _jit_pass_propagate_qinfo (stub pass).
+        pass
+
+    def test_expand_insert_observers(self):
+        # Placeholder: will exercise _jit_pass_insert_observers (stub pass).
+        pass
+
+    def test_expand_insert_fakequant(self):
+        # Placeholder: will exercise _jit_pass_insert_fakequant (stub pass).
+        pass
+
+    def test_expand_quantlint(self):
+        # Placeholder: will exercise _jit_pass_quantlint (stub pass).
+        pass
+
+    def test_expand_fold_quant_inputs(self):
+        # Placeholder: will exercise _jit_pass_fold_quant_inputs (stub pass).
+        pass
+
def test_shape_analysis_broadcast(self):
def broadcast(a, b):
return a + b
"torch/csrc/jit/passes/lower_tuples.cpp",
"torch/csrc/jit/passes/peephole.cpp",
"torch/csrc/jit/passes/python_print.cpp",
+ "torch/csrc/jit/passes/quantization.cpp",
"torch/csrc/jit/passes/remove_expands.cpp",
"torch/csrc/jit/passes/requires_grad_analysis.cpp",
"torch/csrc/jit/passes/shape_analysis.cpp",
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/quantization.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
return EliminateCommonSubexpression(g); // overload resolution
})
.def(
+ "_jit_pass_expand_fakequant",
+ [](std::shared_ptr<Graph>& g) { return ExpandFakeQuantNodes(g); })
+ .def(
+ "_jit_pass_propagate_qinfo",
+ [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
+ .def(
+ "_jit_pass_insert_observers",
+ [](std::shared_ptr<Graph>& g) { return InsertObserverNodes(g); })
+ .def(
+ "_jit_pass_insert_fakequant",
+ [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
+ .def(
+ "_jit_pass_quantlint",
+ [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
+ .def(
+ "_jit_pass_fold_quant_inputs",
+ [](std::shared_ptr<Graph>& g) {
+ return FoldQuantNodesIntoInputsOutputs(g);
+ })
+ .def(
"_jit_pass_remove_inplace_ops",
[](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
.def("_jit_pass_constant_pooling", ConstantPooling)
.def_property_readonly(
"name", [](FunctionSchema& self) { return self.name(); })
.def_property_readonly(
- "overload_name", [](FunctionSchema& self) { return self.overload_name(); })
+ "overload_name",
+ [](FunctionSchema& self) { return self.overload_name(); })
.def_property_readonly(
"arguments", [](FunctionSchema& self) { return self.arguments(); })
.def_property_readonly(
--- /dev/null
+#include <torch/csrc/jit/passes/quantization.h>
+
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/passes/alias_analysis.h>
+
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+
+// The passes below are declared in torch/csrc/jit/passes/quantization.h; see
+// that header for the intended contract of each pass.  None of them is
+// implemented yet: each throws rather than silently returning, so a caller
+// cannot mistake a no-op for a successful run.  Parameter names are commented
+// out until the implementations land, to avoid unused-parameter warnings.
+// (Removed the empty anonymous namespace that previously sat here — it was
+// dead code; reintroduce one when file-local helpers are actually needed.)
+
+void ExpandFakeQuantNodes(std::shared_ptr<Graph>& /*graph*/) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void PropagateQuantInfo(std::shared_ptr<Graph>& /*graph*/) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void InsertObserverNodes(std::shared_ptr<Graph>& /*graph*/) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void InsertFakeQuantNodes(std::shared_ptr<Graph>& /*graph*/) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void QuantLinting(std::shared_ptr<Graph>& /*graph*/) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& /*graph*/) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+} // namespace jit
+} // namespace torch
--- /dev/null
+/** \brief This file defines passes used for quantization.
+ *
+ * The passes have python-bindings and can be invoked directly or as a part of
+ * general optimization pipeline (details TBD).
+ *
+ * NOTE(review): all passes declared here are currently stubs that throw
+ * "Pass not implemented yet!"; the doc comments describe the planned
+ * contract, not behavior that exists yet.
+ */
+#pragma once
+
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+
+/** \brief Replace all FakeQuant nodes with corresponding Quant-Dequant nodes
+ * pair. */
+TORCH_API void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Propagates QParams through nodes that are not supposed to change it.
+ *
+ * An example of such node is `Split`: even though the observed distribution
+ * might be different for input and output tensors, we would like to use input's
+ * qparams for output as well.
+ */
+TORCH_API void PropagateQuantInfo(std::shared_ptr<Graph>& graph);
+
+/** \brief Inserts observer nodes for collecting distribution of values taken by
+ * a tensor.
+ *
+ * The distribution can then be used for computing qparams for quantization.
+ */
+TORCH_API void InsertObserverNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Inserts fake-quant nodes.
+ *
+ * This actually changes the numerical semantics of the original model and thus
+ * we only run it when user explicitly wants that. This pass essentially
+ * performs quantization of the model - later passes only cleanup the IR and
+ * make sure the model runs faster/consumes less memory.
+ *
+ * TODO: This should also take a qparam-map as an input.
+ */
+TORCH_API void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Check that all expected optimizations after fake-quant nodes
+ * insertion actually happened.
+ *
+ * Even though semantically it would be correct to just execute the initial
+ * quant-dequant nodes as is, what we really wanted when we inserted them is to
+ * fuse them into adjacent non-quantized ops resulting in quantized ops. Thus,
+ * if after all the cleanups, optimizations (particularly, fusion) we find
+ * quant-dequant pair in the graph, it indicates that quantization didn't go as
+ * planned.
+ */
+TORCH_API void QuantLinting(std::shared_ptr<Graph>& graph);
+
+/** \brief Quantize model's inputs and outputs.
+ *
+ * This pass folds quant/dequant ops into the input/output tensors, essentially
+ * quantizing these tensors. It's done to reduce model's memory footprint.
+ */
+TORCH_API void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch