[quant][graphmode][fx] Add a separate lower_to_native_backend function for relu ...
authorJerry Zhang <jerryzh@fb.com>
Wed, 25 Aug 2021 04:05:14 +0000 (21:05 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 04:07:03 +0000 (21:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62861

This PR adds a lower_to_native_backend function that lowers a quantized reference model
to a model that uses fbgemm/qnnpack ops. We'll gradually add support for more ops and
remove the fbgemm/qnnpack-specific handling in quantization_patterns.py.
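
For illustration, a minimal sketch of the intended flow, mirroring the new test below
(assumes the standard FX graph mode quantization entry points prepare_fx/convert_fx):

    import copy
    import torch
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    m = prepare_fx(M().eval(), {"": default_qconfig})
    # reference model keeps quantize -> dequantize -> relu -> quantize -> dequantize
    m_ref = convert_fx(copy.deepcopy(m), is_reference=True)
    # non-reference convert now also lowers: the relu reference pattern is
    # rewritten so relu runs on the quantized tensor directly
    m = convert_fx(m)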

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30165828

fbshipit-source-id: de1149cd7e7c1840c17c251cd4d35004afd015b7

test/quantization/fx/test_quantize_fx.py
torch/quantization/fx/_lower_to_native_backend.py [new file with mode: 0644]
torch/quantization/fx/convert.py
torch/quantization/fx/lower_to_fbgemm.py [new file with mode: 0644]
torch/quantization/fx/lower_to_qnnpack.py [new file with mode: 0644]
torch/quantization/fx/quantization_patterns.py
torch/quantization/fx/quantized_fusion_patterns_and_replacements.py [new file with mode: 0644]

diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index bf15a06..1bc6b61 100644 (file)
@@ -2861,6 +2861,28 @@ class TestQuantizeFx(QuantizationTestCase):
             if n.target == "lstm":
                 self.assertEqual(type(n.args[1]), tuple)
 
+    def test_lowering(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.relu(x)
+
+        m = M().eval()
+        m = prepare_fx(m, {"": default_qconfig})
+        m_copy = copy.deepcopy(m)
+        m = convert_fx(m)
+        m_ref = convert_fx(m_copy, is_reference=True)
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        node_occurrence_ref = {
+            ns.call_function(torch.quantize_per_tensor): 2,
+            ns.call_method("dequantize"): 2
+        }
+
+        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+        self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
diff --git a/torch/quantization/fx/_lower_to_native_backend.py b/torch/quantization/fx/_lower_to_native_backend.py
new file mode 100644 (file)
index 0000000..a551899
--- /dev/null
@@ -0,0 +1,14 @@
+from torch.fx import subgraph_rewriter
+from .graph_module import QuantizedGraphModule
+from .quantized_fusion_patterns_and_replacements import get_fbgemm_patterns_and_replacements
+
+def _lower_to_native_backend(model: QuantizedGraphModule) -> QuantizedGraphModule:
+    """ Lower a quantized reference model (with reference quantized operator patterns)
+    to the native backend in PyTorch (fbgemm/qnnpack); both backends share the same
+    operator signatures, so they can be lowered with the same function
+    """
+    module_dict = dict(model.named_modules())
+    for pattern, replacement in get_fbgemm_patterns_and_replacements():
+        subgraph_rewriter.replace_pattern(model, pattern, replacement)
+    model.graph.lint()
+    return model
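
For background, torch.fx.subgraph_rewriter.replace_pattern traces the pattern callable,
searches for matching subgraphs in the model's graph, and splices in the replacement. A
self-contained toy sketch of the mechanism (toy pattern, not one of the quantized
patterns above):

    import torch
    from torch.fx import symbolic_trace, subgraph_rewriter

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.add(x, x)

    def pattern(x):
        return torch.add(x, x)

    def replacement(x):
        return torch.mul(x, 2)

    gm = symbolic_trace(Toy())
    # rewrites every add(x, x) subgraph into mul(x, 2), mutating gm in place
    subgraph_rewriter.replace_pattern(gm, pattern, replacement)
    gm.graph.lint()
    print(gm.code)  # forward now calls torch.mul
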
diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py
index 671c270..867b0b2 100644 (file)
@@ -45,6 +45,8 @@ from ..utils import (
     activation_dtype,
 )
 
+from .lower_to_fbgemm import lower_to_fbgemm
+
 # weight prepacking ops
 WEIGHT_PREPACK_OPS = {
     torch._ops.ops.quantized.linear_prepack,
@@ -535,4 +537,5 @@ def convert(model: GraphModule, is_reference: bool = False,
     model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes)
     if not is_reference:
         model = fold_weight(model, node_name_to_scope)
+        model = lower_to_fbgemm(model)
     return model
diff --git a/torch/quantization/fx/lower_to_fbgemm.py b/torch/quantization/fx/lower_to_fbgemm.py
new file mode 100644 (file)
index 0000000..fc76d13
--- /dev/null
@@ -0,0 +1,8 @@
+from ._lower_to_native_backend import _lower_to_native_backend
+from .graph_module import QuantizedGraphModule
+
+def lower_to_fbgemm(model: QuantizedGraphModule) -> QuantizedGraphModule:
+    """ Lower a quantized reference model (with reference quantized operator patterns)
+    to fbgemm
+    """
+    return _lower_to_native_backend(model)
diff --git a/torch/quantization/fx/lower_to_qnnpack.py b/torch/quantization/fx/lower_to_qnnpack.py
new file mode 100644 (file)
index 0000000..0a0ea9c
--- /dev/null
@@ -0,0 +1,8 @@
+from ._lower_to_native_backend import _lower_to_native_backend
+from .graph_module import QuantizedGraphModule
+
+def lower_to_qnnpack(model: QuantizedGraphModule) -> QuantizedGraphModule:
+    """ Lower a quantized reference model (with reference quantized operator patterns)
+    to qnnpack
+    """
+    return _lower_to_native_backend(model)
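
Both wrappers delegate to the same _lower_to_native_backend; which kernel library
actually runs is selected at runtime via the quantized engine (a usage sketch;
available engines depend on the build):

    import torch
    # the lowered graph is identical for both backends; pick the kernels here
    torch.backends.quantized.engine = "qnnpack"  # or "fbgemm"
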
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 1ce43ca..1a7d714 100644 (file)
@@ -1496,7 +1496,9 @@ class CopyNodeQuantizeHandler(QuantizeHandler):
                 load_arg: Callable,
                 is_reference: bool = False,
                 convert_custom_config_dict: Dict[str, Any] = None) -> Node:
-        if is_reference:
+        # always produce reference pattern for relu
+        is_relu = node.op == "call_function" and node.target == torch.nn.functional.relu
+        if is_reference or is_relu:
             # when activation dtype is torch.float, the node does not require
             # observation
             # e.g. dynamic quantization or weight_only quantization
diff --git a/torch/quantization/fx/quantized_fusion_patterns_and_replacements.py b/torch/quantization/fx/quantized_fusion_patterns_and_replacements.py
new file mode 100644 (file)
index 0000000..07c109e
--- /dev/null
@@ -0,0 +1,31 @@
+import torch
+
+def relu_inplace_pattern(x, scale, zero_point):
+    x = x.dequantize()
+    x = torch.nn.functional.relu(x, inplace=True)
+    x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
+    return x
+
+def relu_non_inplace_pattern(x, scale, zero_point):
+    x = x.dequantize()
+    x = torch.nn.functional.relu(x, inplace=False)
+    x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
+    return x
+
+def relu_replacement(x, scale, zero_point):
+    x = torch.nn.functional.relu(x)
+    return x
+
+
+def _get_all_patterns_and_replacements():
+    return [
+        (relu_inplace_pattern, relu_replacement),
+        (relu_non_inplace_pattern, relu_replacement)
+    ]
+
+
+def get_fbgemm_patterns_and_replacements():
+    return _get_all_patterns_and_replacements()
+
+def get_qnnpack_patterns_and_replacements():
+    return _get_all_patterns_and_replacements()
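
Note that relu_replacement drops the dequantize/quantize pair entirely: relu operates
directly on quantized tensors and preserves their scale and zero_point, so no
requantization is needed. A quick check (a sketch; values illustrative):

    import torch

    x = torch.randn(4)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=64, dtype=torch.quint8)
    qy = torch.nn.functional.relu(qx)  # relu supports quantized tensors directly
    assert qy.q_scale() == qx.q_scale() and qy.q_zero_point() == qx.q_zero_point()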