From 13b95eac550ddf28ac67625a2a1ba3fdaec3c923 Mon Sep 17 00:00:00 2001
From: Mikhail Zolotukhin
Date: Mon, 25 Mar 2019 17:39:01 -0700
Subject: [PATCH] Add quant-passes stubs. (#18151)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18151
ghimport-source-id: 7d12462971bdf3e5e26a3f150f1fcad05bba1a15

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18152 Initial implementation of InsertObserverNodes pass.
* **#18151 Add quant-passes stubs.**

gh-metadata: pytorch pytorch 18149 gh/zolotukhinm@gmail.com/1/head

Differential Revision: D14584224

fbshipit-source-id: b3d0b5ff797160d5ad23f91f732e627b0129086c
---
 test/test_jit.py                       | 18 ++++++++++
 tools/build_variables.py               |  1 +
 torch/CMakeLists.txt                   |  1 +
 torch/csrc/jit/init.cpp                | 24 ++++++++++++-
 torch/csrc/jit/passes/quantization.cpp | 38 ++++++++++++++++++++
 torch/csrc/jit/passes/quantization.h   | 63 ++++++++++++++++++++++++++++++++++
 6 files changed, 144 insertions(+), 1 deletion(-)
 create mode 100644 torch/csrc/jit/passes/quantization.cpp
 create mode 100644 torch/csrc/jit/passes/quantization.h

diff --git a/test/test_jit.py b/test/test_jit.py
index 36d51b9..1c59929 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -991,6 +991,24 @@ class TestJit(JitTestCase):
         self.run_pass('cse', graph)
         FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
 
+    def test_expand_fakequant(self):
+        pass
+
+    def test_expand_propagate_qinfo(self):
+        pass
+
+    def test_expand_insert_observers(self):
+        pass
+
+    def test_expand_insert_fakequant(self):
+        pass
+
+    def test_expand_quantlint(self):
+        pass
+
+    def test_expand_fold_quant_inputs(self):
+        pass
+
     def test_shape_analysis_broadcast(self):
         def broadcast(a, b):
             return a + b
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 675ae7d..503eb6a 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -81,6 +81,7 @@ libtorch_sources = [
     "torch/csrc/jit/passes/lower_tuples.cpp",
     "torch/csrc/jit/passes/peephole.cpp",
     "torch/csrc/jit/passes/python_print.cpp",
+    "torch/csrc/jit/passes/quantization.cpp",
     "torch/csrc/jit/passes/remove_expands.cpp",
     "torch/csrc/jit/passes/requires_grad_analysis.cpp",
     "torch/csrc/jit/passes/shape_analysis.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 71f9908..c2d7783 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -166,6 +166,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index b0888cd..f46dd9f 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/passes/quantization.h>
 #include
 #include
 #include
@@ -114,6 +115,26 @@ void initJITBindings(PyObject* module) {
             return EliminateCommonSubexpression(g); // overload resolution
           })
       .def(
+          "_jit_pass_expand_fakequant",
+          [](std::shared_ptr<Graph>& g) { return ExpandFakeQuantNodes(g); })
+      .def(
+          "_jit_pass_propagate_qinfo",
+          [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
+      .def(
+          "_jit_pass_insert_observers",
+          [](std::shared_ptr<Graph>& g) { return InsertObserverNodes(g); })
+      .def(
+          "_jit_pass_insert_fakequant",
+          [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
+      .def(
+          "_jit_pass_quantlint",
+          [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
+      .def(
+          "_jit_pass_fold_quant_inputs",
+          [](std::shared_ptr<Graph>& g) {
+            return FoldQuantNodesIntoInputsOutputs(g);
+          })
+      .def(
           "_jit_pass_remove_inplace_ops",
           [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
       .def("_jit_pass_constant_pooling", ConstantPooling)
@@ -352,7 +373,8 @@ void initJITBindings(PyObject* module) {
       .def_property_readonly(
           "name", [](FunctionSchema& self) { return self.name(); })
       .def_property_readonly(
-          "overload_name", [](FunctionSchema& self) { return self.overload_name(); })
+          "overload_name",
+          [](FunctionSchema& self) { return self.overload_name(); })
       .def_property_readonly(
           "arguments", [](FunctionSchema& self) { return self.arguments(); })
       .def_property_readonly(
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
new file mode 100644
index 0000000..1d46d3e
--- /dev/null
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -0,0 +1,38 @@
+#include <torch/csrc/jit/passes/quantization.h>
+
+#include
+#include
+#include
+
+#include
+
+namespace torch {
+namespace jit {
+namespace {} // namespace
+
+void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void InsertObserverNodes(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void QuantLinting(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h
new file mode 100644
index 0000000..37f558d
--- /dev/null
+++ b/torch/csrc/jit/passes/quantization.h
@@ -0,0 +1,63 @@
+/** \brief This file defines passes used for quantization.
+ *
+ * The passes have Python bindings and can be invoked directly or as a part of
+ * the general optimization pipeline (details TBD).
+ */
+#pragma once
+
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+
+/** \brief Replace all FakeQuant nodes with the corresponding Quant-Dequant
+ * node pairs. */
+TORCH_API void ExpandFakeQuantNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Propagates QParams through nodes that are not supposed to change them.
+ *
+ * An example of such a node is `Split`: even though the observed distribution
+ * might be different for input and output tensors, we would like to use input's
+ * qparams for output as well.
+ */
+TORCH_API void PropagateQuantInfo(std::shared_ptr<Graph>& graph);
+
+/** \brief Inserts observer nodes for collecting the distribution of values
+ * taken by a tensor.
+ *
+ * The distribution can then be used for computing qparams for quantization.
+ */
+TORCH_API void InsertObserverNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Inserts fake-quant nodes.
+ *
+ * This actually changes the numerical semantics of the original model and thus
+ * we only run it when the user explicitly wants that. This pass essentially
+ * performs quantization of the model - later passes only clean up the IR and
+ * make sure the model runs faster/consumes less memory.
+ *
+ * TODO: This should also take a qparam-map as an input.
+ */
+TORCH_API void InsertFakeQuantNodes(std::shared_ptr<Graph>& graph);
+
+/** \brief Check that all expected optimizations after fake-quant nodes
+ * insertion actually happened.
+ *
+ * Even though semantically it would be correct to just execute the initial
+ * quant-dequant nodes as is, what we really wanted when we inserted them is to
+ * fuse them into adjacent non-quantized ops resulting in quantized ops. Thus,
+ * if after all the cleanups and optimizations (particularly fusion) we find a
+ * quant-dequant pair in the graph, it indicates that quantization didn't go as
+ * planned.
+ */
+TORCH_API void QuantLinting(std::shared_ptr<Graph>& graph);
+
+/** \brief Quantize model's inputs and outputs.
+
+ *
+ * This pass folds quant/dequant ops into the input/output tensors, essentially
+ * quantizing these tensors. It's done to reduce the model's memory footprint.
+ */
+TORCH_API void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
-- 
2.7.4
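
For context on how these stubs are intended to be driven: the bindings registered above live on torch._C and take a JIT graph, the same calling convention as existing passes such as _jit_pass_cse. Below is a minimal, hypothetical sketch of exercising one of them from Python. The scripted function fn is illustrative, and since every pass in this patch currently throws "Pass not implemented yet!", a RuntimeError is the expected outcome until the real implementations land.

    import torch

    @torch.jit.script
    def fn(x, y):
        return x + y

    # IR graph of the scripted function; this is what the pass operates on.
    graph = fn.graph

    try:
        # Binding added in this patch; the C++ stub currently just throws.
        torch._C._jit_pass_insert_observers(graph)
    except RuntimeError as e:
        print("pass not implemented yet:", e)

The placeholder tests added to test_jit.py (test_expand_fakequant and friends) would presumably follow the same pattern once the passes stop being stubs, asserting on the transformed graph with FileCheck as the existing pass tests do.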